# Source code for custEM.post.plot_utils

# -*- coding: utf-8 -*-
"""
@author: Rochlitz.R
"""

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as clrs
from custEM.misc import pyhed_calculations as phc
import custEM.misc
import sys
import os
import json
from matplotlib import cm
from matplotlib.colors import LogNorm
from matplotlib import rcParams


"""
Utility functions for visualization of custEM results
"""

class PlotBase:
    """
    Shared import and bookkeeping helpers for plotting interpolated results.

    Imported datasets are kept in dictionaries (``line_data``,
    ``slice_data``, ``point_data``) together with the matching coordinate
    dictionaries (``line_coords``, ``slice_coords``, ``point_coords``).
    Several methods read ``self.mod``, ``self.mesh`` and ``self.approach``,
    which are not set in ``__init__`` and have to be provided elsewhere
    (NOTE(review): presumably by a derived class - confirm).
    """

    def __init__(self):
        """
        Create empty data/coordinate/error containers and default settings.
        """

        if sys.version_info < (3, 0):
            # Python 2 has no FileNotFoundError; alias it for this module
            global FileNotFoundError
            FileNotFoundError = IOError

        self.line_data, self.line_coords = dict(), dict()
        self.point_data, self.point_coords = dict(), dict()
        self.slice_data, self.slice_coords = dict(), dict()
        self.rel_comp_errors, self.rel_mag_errors = dict(), dict()
        self.abs_comp_errors, self.abs_mag_errors = dict(), dict()
        self.ratio_comp_errors, self.ratio_mag_errors = dict(), dict()
        self.missing = []              # names of datasets that failed to load
        self.fig_size = None
        self.dpi = 300
        self.fs = 12
        self.dg_space = False          # True: drop duplicated coordinates
        self.label_color = '#000000'
        self.c_idx = dict({'x': 0, 'y': 1, 'z': 2})
        self.r_dir = 'results'
        self.s_dir = 'plots'

    def _result_dir(self, approach, mesh):
        """
        Return the result directory (with trailing '/') for *approach* and
        *mesh*, falling back to the deprecated '<mesh>_results' layout if
        the current '<mesh>' directory does not exist.
        """

        current = self.r_dir + '/' + approach + '/' + mesh
        if os.path.isdir(current):
            return current + '/'
        return current + '_results/'

    def init_main_parameters(self, mod, mesh, approach, path):
        """
        Initializes the most important class attributes.

        Every argument that is None is replaced by a default: ``mod``,
        ``mesh`` and ``approach`` by the attributes of the same name,
        ``path`` by the '<mod>_interpolated/' sub-directory of the result
        directory.

        Returns the tuple ``(mod, mesh, approach, path)``.
        """

        if mod is None:
            mod = self.mod
        if mesh is None:
            mesh = self.mesh
        if approach is None:
            approach = self.approach
        if path is None:
            # compatible with the deprecated directory structure
            path = self._result_dir(approach, mesh) + mod + '_interpolated/'
        return mod, mesh, approach, path

    def general_import(self, path, quantity, mesh, key, comp=None,
                       stride=1, ckey=None):
        """
        General import function for interpolated results.

        Loads ``path + quantity + '_' + mesh + '.npy'`` (or the deprecated
        ``'_on_'`` file name) and stores the first three columns as data and
        the remaining columns as coordinates.

        - ``comp is None``: point data, coordinates stored under ``mesh``
        - ``comp`` is a list: slice data, reduced with ``comp[0]``,
          ``comp[1]`` and ``stride``, coordinates stored under ``ckey``
        - otherwise: line data sorted along column ``comp``, coordinates
          stored under ``mesh``

        The data key is ``key + '_' + quantity`` or, if ``key`` is None,
        ``self.mod + '_' + quantity + '_' + mesh``.

        If no file is found, a warning is printed, the dataset name is
        appended to ``self.missing`` and nothing is imported.
        """

        if key is None:
            data_key = self.mod + '_' + quantity + '_' + mesh
        else:
            data_key = key + '_' + quantity

        data = None
        # current name convention first, deprecated '_on_' convention second
        for file_name in (quantity + '_' + mesh + '.npy',
                          quantity + '_on_' + mesh + '.npy'):
            try:
                data = np.load(path + file_name)
            except FileNotFoundError:
                continue
            break

        if data is None:
            # the warning announces 'continuing...', so do not fall through
            # to the code below with an undefined dataset
            self.print_import_error(quantity + '_', mesh, path)
            return

        if comp is None:
            data_s = self.rearange_point_data(data)
            self.point_data.update({data_key: data_s[:, :3]})
            # the original referenced an undefined name here; the mesh name
            # is used as key, consistent with the line import below
            self.point_coords.update({mesh: data_s[:, 3:]})
        elif isinstance(comp, list):
            data_s = self.reduce_slice_data(data, comp[0], comp[1], stride)
            self.slice_data.update({data_key: data_s[:, :3]})
            self.slice_coords.update({ckey: data_s[:, 3:]})
        else:
            data_s = self.arrange_line_data(data, comp, stride)
            self.line_data.update({data_key: data_s[:, :3]})
            self.line_coords.update({mesh: data_s[:, 3:]})

    def init_component_integers(self, string, line=True):
        """
        Initializes integer mapping to handle x-, y-, and z-directed line
        and slice coordinates.

        Returns a single column index for lines (``line=True``) and a tuple
        of two column indices for slices. Prints an error and raises
        SystemExit if ``string`` is not 'x', 'y' or 'z'.
        """

        if string == 'x':
            return 3 if line else (5, 4)
        if string == 'y':
            return 4 if line else (5, 3)
        if string == 'z':
            return 5 if line else (3, 4)

        print("Error!, don't know how to define coordinates! Support for "
              "lines or slices without coordinate ending in the name "
              "(x,y,z) is not supportet yet!")
        raise SystemExit

    def calc_pyhed_reference(self, key, EH='E', line=None, mod=None,
                             config_file=None, mesh=None):
        """
        Deprecated! Might be edited for future purposes.

        Evaluates ``phc.PHC(config_file).calc_reference`` on the
        coordinates of ``line`` (default: first imported line). ``key`` is
        not used. If ``config_file`` is None, '<mod>_config.json' in the
        result directory of ``self.approach`` and ``mesh`` is used.
        """

        if mod is None:
            mod = self.mod
        if mesh is None:
            mesh = self.mesh
        if config_file is None:
            # compatible with the deprecated directory structure
            config_file = (self._result_dir(self.approach, mesh) +
                           mod + '_config.json')
        if line is None:
            line = list(self.line_coords)[0]

        coords, cc = self.get_coords(line, for_pyhed=True)
        Calculator = phc.PHC(config_file)
        return Calculator.calc_reference(coords, EH)

    def load_model_parameters(self, mod=None, mesh=None, approach=None):
        """
        Deprecated! Might be edited for future purposes.

        Reads '<mod>_config.json' and '<mod>_resource.json' from the result
        directory and adds their entries to the instance attributes. Prints
        a warning if one of the files is not found; entries of a config
        file that was read before the failure are kept.
        """

        if mod is None:
            mod = self.mod
        if mesh is None:
            mesh = self.mesh
        if approach is None:
            approach = self.approach

        # compatible with the deprecated directory structure
        directory = self._result_dir(approach, mesh)

        try:
            for suffix in ('_config.json', '_resource.json'):
                with open(directory + mod + suffix) as json_file:
                    self.__dict__.update(json.load(json_file))
        except FileNotFoundError:
            print('Warning! Config files not found')

    def eval_properties(self, mod, key, label, EH, sf, mesh,
                        line=True, points=False):
        """
        Evaluate names, labels, field quantities etc. for all kinds of
        plots.

        Returns ``(keys, label, EH_list, coords, cc)``:

        - ``keys``: data keys for the E and/or H field, built from ``key``
          or, if ``key`` is None, from ``mod`` and ``mesh``; 's' is inserted
          if ``sf`` is true, 't' otherwise
        - ``label``: ``label`` or, if None, ``mod``
        - ``EH_list``: ['E', 'H'], or a single entry if ``EH`` is 'E',
          ['E'], 'H' or ['H']
        - ``coords``, ``cc``: result of ``get_coords`` for lines, the slice
          coordinate pair and the last character of ``mesh`` for slices, or
          station indices and 'Station No.' for points

        If ``mesh`` is None, the first imported line or slice is used.
        """

        if mod is None:
            mod = self.mod
        if label is None:
            label = mod
        ts = 's' if sf else 't'

        if not points:
            if line:
                if mesh is None:
                    # list() instead of a comprehension over 'key', which
                    # would overwrite the 'key' argument on Python 2
                    mesh = list(self.line_coords)[0]
                coords, cc = self.get_coords(mesh)
            else:
                if mesh is None:
                    mesh = list(self.slice_coords)[0]
                coords = self.get_coords(mesh)
                cc = mesh[-1]

        EH_list = ['E', 'H']
        if key is None:
            keys = [mod + '_' + 'E' + '_' + ts + '_' + mesh,
                    mod + '_' + 'H' + '_' + ts + '_' + mesh]
        else:
            keys = [key + '_' + 'E' + '_' + ts,
                    key + '_' + 'H' + '_' + ts]

        if EH == 'E' or EH == ['E']:
            EH_list = [EH_list[0]]
            keys = [keys[0]]
        if EH == 'H' or EH == ['H']:
            EH_list = [EH_list[1]]
            keys = [keys[1]]

        if points:
            coords = np.arange(len(self.point_data[keys[0]]))
            cc = 'Station No.'
        return keys, label, EH_list, coords, cc

    def reduce_slice_data(self, data, comp1, comp2, step):
        """
        Apply a stride for importing slice datasets to significantly reduce
        the amount of data to be plotted later on, if the interpolation
        mesh was set too dense.

        The rows are sorted by column ``comp1``, cut into consecutive
        chunks whose length is the number of unique values in column
        ``comp2``, every ``step``-th chunk is sorted by column ``comp2``
        and every ``step``-th row of it is kept. The result is a complex
        array with six columns.
        """

        if self.dg_space:
            # keep only the first row for every unique coordinate triple
            _, first_rows = np.unique(data[:, 3:], axis=0,
                                      return_index=True)
            data = data[first_rows, :]

        n_x = len(np.unique(data[:, comp1]))
        n_y = len(np.unique(data[:, comp2]))
        data = data[data[:, comp1].argsort()]

        # the complex seed row fixes dtype and column count of the result;
        # all chunks are stacked once instead of once per loop iteration
        chunks = [np.zeros((1, 6)) + 1j * np.zeros((1, 6))]
        for j in range(0, n_x, step):
            chunk = np.copy(data[j * n_y: (j + 1) * n_y])
            chunks.append(chunk[chunk[:, comp2].argsort()][::step, :])
        return np.vstack(chunks)[1:, :]

    def arrange_line_data(self, data, comp, step):
        """
        Sort line data by column ``comp`` and keep every ``step``-th row.
        Rows with duplicated coordinates are removed first if
        ``self.dg_space`` is set.
        """

        if self.dg_space:
            _, first_rows = np.unique(data[:, 3:], axis=0,
                                      return_index=True)
            data = data[first_rows, :]
        data_s = np.array(data[data[:, comp].argsort()])
        return data_s[::step, :]

    def rearange_point_data(self, data):
        """
        Return point data unchanged or, if ``self.dg_space`` is set, with
        rows of duplicated coordinates removed.
        """

        if not self.dg_space:
            return data
        _, first_rows = np.unique(data[:, 3:], axis=0, return_index=True)
        return data[first_rows, :]

    def print_import_error(self, field_type, name, path):
        """
        Print an import warning if a specified dataset for visualization
        could not be found and remember its name in ``self.missing``.
        """

        print('Warning! ' + field_type + name +
              ' could not be found in path: \n' + path + ', continuing...')
        self.missing.append(field_type + name)

    def get_coords(self, name, for_pyhed=False):
        """
        Initialize coordinates for visualization, depending on the
        direction of the input line- or slice-mesh.

        - names containing 'line': ``(coordinates, axis)`` with the real
          part of the coordinate column selected by the last character of
          ``name`` (values as stored, not rescaled)
        - names containing 'slice': ``(M, N)``, the unique real in-plane
          coordinates divided by 1e3
        - ``for_pyhed=True``: ``(all real coordinates, 'xyz')``

        Prints an error and raises SystemExit if ``name`` contains neither
        'line' nor 'slice'.

        NOTE(review): names that do not end with 'x', 'y' or 'z' are not
        handled; lines return None and slices raise an UnboundLocalError.
        """

        if 'line' in name:
            if for_pyhed:
                return self.line_coords[name].real, 'xyz'
            if name[-1] == 'x':
                return self.line_coords[name][:, 0].real, 'x'
            if name[-1] == 'y':
                return self.line_coords[name][:, 1].real, 'y'
            if name[-1] == 'z':
                return self.line_coords[name][:, 2].real, 'z'
        elif 'slice' in name:
            if for_pyhed:
                return self.slice_coords[name].real, 'xyz'
            if name[-1] == 'x':
                M = np.unique(self.slice_coords[name][:, 1].real)/1e3
                N = np.unique(self.slice_coords[name][:, 2].real)/1e3
            elif name[-1] == 'y':
                M = np.unique(self.slice_coords[name][:, 2].real)/1e3
                N = np.unique(self.slice_coords[name][:, 0].real)/1e3
            elif name[-1] == 'z':
                M = np.unique(self.slice_coords[name][:, 0].real)/1e3
                N = np.unique(self.slice_coords[name][:, 1].real)/1e3
            return M, N
        else:
            print('Error! Cannot identify coordinates!')
            raise SystemExit

    def init_cmap(self, cmap=None, var='E'):
        """
        Initialize colormaps for 2D / 3D data visualization.

        ``cmap`` may be 'magma', 'viridis' or 'RdBu_r'. If ``cmap`` is
        None, the colormap is chosen by ``var``: 'E' -> magma,
        'H' -> viridis, 'err' -> RdBu_r. Any other value prints an error
        and raises SystemExit.
        """

        if cmap is not None:
            if cmap == 'magma':
                return plt.cm.magma
            if cmap == 'viridis':
                return plt.cm.viridis
            if cmap == 'RdBu_r':
                return plt.cm.RdBu_r
            print('Error! cmap name "' + cmap + '" is not supported yet!')
            raise SystemExit

        if var == 'E':
            return plt.cm.magma
        if var == 'H':
            return plt.cm.viridis
        if var == 'err':
            return plt.cm.RdBu_r
        print('Error! If "cmap=None", "var" must be specified!')
        print('(error in init_cmap method from Plot class')
        raise SystemExit
def make_subfigure_box(height=3, width=2, fs=12, var=['E', 'E', 'E'],
                       sf=False, log_scale=True, cc='x', ylim=None,
                       err_plot=None, sharex=True, sharey=False,
                       xlim=None, grid=True, sliceplot=False, ap=False,
                       add_km=True, clr=None):
    """
    Initialize default fig/ax objects with six or eight subfigures for all
    kinds of line- and slice-data plots.

    Parameters (as used below)
    --------------------------
    height, width : number of subplot rows and columns; the axis labels are
        defined for two columns and up to four rows
    fs : font size of labels, titles and tick labels
    var : 'E' / 'H' or a sequence containing 'E' / 'H'; selects the field
        symbol and unit of the labels. The default list is never modified.
    sf : use secondary-field superscripts in the labels
    log_scale : logarithmic y-axes (the second column stays linear if
        ``ap`` is set)
    cc : coordinate name for the x-label; for slice plots the normal
        direction, from which the two in-plane axis names are derived
    ylim, xlim : axis limits applied to all subplots if not None
    err_plot : None, 'diff' or 'ratio'; any other value prints a message
        and raises SystemExit
    sharex, sharey : passed on to ``plt.subplots``
    grid : draw a major grid (always switched off for slice plots)
    sliceplot : label axes with coordinates and put the quantity in the
        subplot title
    ap : amplitude / phase labels instead of real / imaginary part
    add_km : append ' (km)' to the x-label
    clr : not used

    Returns
    -------
    (f, ax) : figure and 2D array of axes of shape (height, width)
    """

    # one field symbol per row; the fourth row (magnitude) needs an entry
    # as well, a fixed three-element list failed for ``height=4``
    n_symbols = max(height, 3)
    if var == 'E' or 'E' in var:
        var = ['E'] * n_symbols
        var2 = ' (V/m)'
    elif var == 'H' or 'H' in var:
        var = ['H'] * n_symbols
        var2 = ' (A/m)'
    else:
        var2 = ' WHICH UNIT ???'

    # squeeze=False: always a 2D axes array, also for one row or column
    f, ax = plt.subplots(height, width, sharex=sharex, sharey=sharey,
                         squeeze=False)

    if sliceplot:
        grid = False
        if cc == 'x':
            cc, cc2 = 'y', 'z'
        elif cc == 'y':
            cc, cc2 = 'x', 'z'
        elif cc == 'z':
            cc, cc2 = 'x', 'y'

    if sf:
        list2 = ['$^s_x$', '$^s_y$', '$^s_z$', '$^s$']
    else:
        list2 = ['$_x$', '$_y$', '$_z$', '']
    list3 = ['x', 'y', 'z', '']

    for i in range(height):
        # later definitions deliberately override earlier ones
        ax_labels = [
            r'$\Re$(' + var[i] + list2[i] + ')' + var2,
            r'$\Im$(' + var[i] + list2[i] + ')' + var2]
        if err_plot is not None:
            ax_labels = [r'$\epsilon_{\Re(\mathrm{' + var[i] + '}' +
                         list3[i] + ')}$ (%)',
                         r'$\epsilon_{\Im(\mathrm{' + var[i] + '}' +
                         list3[i] + ')}$ (%)']
        if ap:
            ax_labels = [r'||$\mathbf{' + var[i] + '}' + list2[i] +
                         '||$ ' + var2,
                         r'$\phi$($\mathbf{' + var[i] + '}' + list2[i] +
                         ')$' + ' (°)']
        if ap is True and err_plot == 'diff':
            ax_labels = [r'$\mathbf{\epsilon}$ (||$\mathbf{' + var[i] +
                         '}' + list2[i] + '||$)' + var2,
                         r'$\Delta\phi$($\mathbf{' + var[i] + '}' +
                         list2[i] + ')$' + ' (°)']
        if err_plot is not None and i == 3:
            ax_labels = [
                r'$\Re$($\mathbf{|' + var[i] + '|}' + list2[i] + ')$' +
                var2,
                r'$\Im$($\mathbf{|' + var[i] + '|}' + list2[i] + ')$' +
                var2]

        for j in range(width):
            if add_km:
                ax[-1, j].set_xlabel(cc + ' (km)', size=fs)
            else:
                ax[-1, j].set_xlabel(cc, size=fs)

            if log_scale:
                ax[i, j].set_yscale('log')
                if ap:
                    # phase column stays linear
                    ax[i, 1].set_yscale('linear')
            if ylim is not None:
                ax[i, j].set_ylim(ylim)
                if ap:
                    # amplitude limits are meaningless for the phase column
                    ax[i, 1].set_ylim([-180.1, 180.1])
            if xlim is not None:
                ax[i, j].set_xlim(xlim)

            if not sliceplot:
                ax[i, j].set_ylabel(ax_labels[j], size=fs)
            else:
                ax[i, j].set_ylabel(cc2 + ' (km)', size=fs)
                ax[i, j].set_title(ax_labels[j], size=fs)

            if j == 1:
                ax[i, 1].yaxis.set_label_position("right")
                ax[i, 1].yaxis.tick_right()

            if err_plot is None:
                if ap:
                    ax[i, 1].set_yticks([-180., -90., 0., 90., 180.])
                    # first label fixed: the tick at -180 was shown as 180
                    ax[i, 1].yaxis.set_ticklabels(['-180', '-90', '0',
                                                   '90', '180'])
            elif err_plot == 'ratio':
                # 'diff' keeps the labels set above
                if not sliceplot:
                    ax[i, j].set_ylabel(ax_labels[j] + ' [ratio]',
                                        size=fs)
            elif err_plot != 'diff':
                print('oO, this should not happen!')
                raise SystemExit

            ax[i, j].tick_params(labelsize=fs)
            if grid:
                ax[i, j].grid(which='major', color='0.8', ls=':')
    return f, ax
def make_plain_subfigure_box(height=3, width=2, fs=12, log_scale=True,
                             sharex=True, sharey=True, coord_axis=True):
    """
    Create a plain subfigure box for custom illustrations, also used by the
    default plot functions.

    The bottom row gets the x-label 'x (km)'. If ``coord_axis`` is set, the
    first column gets the y-label 'y (km)'. If ``log_scale`` is set, all
    y-axes are logarithmic. ``sharex`` and ``sharey`` are passed on to
    ``plt.subplots``; ``fs`` is the label font size.

    Returns the figure and a 2D array of axes of shape (height, width),
    also for a single row and/or a single column.
    """

    # squeeze=False always yields a 2D array; the former special case for
    # ``height == 1`` did not cover ``width == 1``
    f, ax = plt.subplots(height, width, sharex=sharex, sharey=sharey,
                         squeeze=False)

    for j in range(width):
        ax[height - 1, j].set_xlabel('x (km)', size=fs)

    for i in range(height):
        if coord_axis:
            ax[i, 0].set_ylabel('y (km)', size=fs)
        if log_scale:
            for j in range(width):
                ax[i, j].set_yscale('log')
    return f, ax
def adjust_subfigure_box_axes(ax):
    """
    Copy the x-tick positions and x-tick label texts of ``ax[1, 0]`` to the
    y-axes of all subplots in the first row of the 2D axes array ``ax``.

    NOTE(review): marked by the author as not working yet. The column loop
    formerly ran over the number of rows and is fixed here; the label texts
    may additionally be empty before the figure is drawn - to be confirmed.
    """

    ticks = ax[1, 0].get_xticks()
    labs = [label.get_text() for label in ax[1, 0].get_xticklabels()]

    # iterate over the columns of the first row (not over the row count)
    for j in range(ax.shape[1]):
        ax[0, j].set_yticks(ticks)
        ax[0, j].set_yticklabels(labs)
def adjust_ticks_and_labels(ax, remove_top_right=False):
    """
    Place five equally spaced ticks between the current limits of the x-
    and y-axis of every subplot in the 2D axes array ``ax``.

    If ``remove_top_right`` is set, the y-labels of all columns except the
    first one are moved to the left side and cleared.
    """

    n_rows, n_cols = ax.shape[0], ax.shape[1]

    for row, col in np.ndindex(n_rows, n_cols):
        subplot = ax[row, col]
        x_min, x_max = subplot.get_xlim()[0], subplot.get_xlim()[1]
        subplot.set_xticks(np.linspace(x_min, x_max, 5))
        y_min, y_max = subplot.get_ylim()[0], subplot.get_ylim()[1]
        subplot.set_yticks(np.linspace(y_min, y_max, 5))

    if not remove_top_right:
        return

    for row in range(n_rows):
        for col in range(1, n_cols):
            subplot = ax[row, col]
            subplot.yaxis.set_label_position("left")
            subplot.set_ylabel('')
def adjust_axes(ax, equal_axis, x_lim=None, y_lim=None, size=3, width=2):
    """
    Adjust axis properties (equal resolution for x- and y-coordinates).

    For the first ``size`` rows and ``width`` columns of the 2D axes array
    ``ax``, ``x_lim`` / ``y_lim`` are applied if not None. If
    ``equal_axis`` is set, equal axis scaling is requested with adjustable
    'box'; if 'box' is rejected with a ValueError, 'box-forced' is used
    instead (NOTE(review): presumably needed for old matplotlib versions
    with shared axes - confirm).
    """

    for j in range(size):
        for k in range(width):
            subplot = ax[j, k]
            if x_lim is not None:
                subplot.set_xlim(x_lim)
            if y_lim is not None:
                subplot.set_ylim(y_lim)
            if equal_axis:
                subplot.axis('equal')
                try:
                    subplot.set_adjustable('box')
                except ValueError:
                    # was 'except exception', an undefined name which
                    # turned every failure into a NameError
                    subplot.set_adjustable('box-forced')
def adjust_log_ticks(ax, minlog, maxlog, symlog=False, additional=0):
    """
    Improve style of ticks and labels of a logarithmic colorbar.

    One tick per integer step is created for the exponents
    ``np.arange(minlog, maxlog + additional, 1)`` at ``10**exponent`` with
    the label ``10^exponent``. With ``symlog``, the mirrored negative ticks
    are placed in front. Ticks and labels are passed to ``ax.set_ticks``
    and ``ax.set_ticklabels``.
    """

    exponents = np.arange(minlog, maxlog + additional, 1)
    # float base: numpy refuses negative powers of integer arrays, which
    # made integer ``minlog`` / ``maxlog`` arguments fail
    ticks = 10.0 ** exponents
    labels = ['${10^{%d}}$' % (exp) for exp in exponents]

    if symlog:
        ticks = np.append(-ticks[::-1], ticks)
        labels = ['-${10^{%d}}$' % (exp)
                  for exp in exponents[::-1]] + labels

    ax.set_ticks(ticks)
    ax.set_ticklabels(labels)
def eval_colors(data=None, n_colors=101, c_lim=None, quant=99,
                symlog=False):
    """
    Utility function for evaluating colorbar ranges, by default from the
    99 % quantiles of the data.

    If ``c_lim`` is None, the limits are taken from the absolute values of
    the real and imaginary parts of ``data``: the upper limit is their
    ``quant``-th percentile, the lower limit the negated ``quant``-th
    percentile of the negated absolute values. The logarithmic bounds are
    then the log10 of the limits, shifted by -0.5 / +0.5 and rounded. If
    ``c_lim`` is given, the bounds are the plain log10 of its two entries.

    ``n_colors`` levels are spaced logarithmically between the bounds; with
    ``symlog``, ``n_colors`` is halved first and the mirrored negative
    levels are placed in front.

    Returns ``(minlog, maxlog, colors, c_lim[0], c_lim[1])``.
    """

    if symlog:
        n_colors = int(n_colors/2)

    if c_lim is not None:
        minlog = np.log10(c_lim[0])
        maxlog = np.log10(c_lim[1])
    else:
        magnitudes = np.hstack((np.abs(data.real), np.abs(data.imag)))
        lower = -np.percentile(-magnitudes, quant)
        upper = np.percentile(magnitudes, quant)
        c_lim = [lower, upper]
        minlog = np.round(np.log10(lower) - 0.5)
        maxlog = np.round(np.log10(upper) + 0.5)

    colors = np.logspace(minlog, maxlog, n_colors)
    if symlog:
        # mirror the levels for a symmetric logarithmic colorbar
        colors = np.append(-colors[::-1], colors)

    return minlog, maxlog, colors, c_lim[0], c_lim[1]