# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as clrs
from custEM.misc import pyhed_calculations as phc
import custEM.misc
import sys
import os
import json
from matplotlib import cm
from matplotlib.colors import LogNorm
from matplotlib import rcParams

# Utility functions for visualization of custEM results.

class PlotBase:

    """
    Base class holding imported line-, point- and slice-data together with
    the helper methods used to import and arrange them for plotting.

    The attributes ``mod``, ``mesh`` and ``approach`` are read by several
    methods but are not set in ``__init__``; they have to be provided by the
    user of this class before those methods are called.
    """

    def __init__(self):

        """
        Set up empty data containers and default plot settings.
        """

        if sys.version_info < (3, 0):
            # Python 2 has no FileNotFoundError, alias it to IOError
            global FileNotFoundError
            FileNotFoundError = IOError

        # imported field data and the matching coordinates
        self.line_data, self.line_coords = {}, {}
        self.point_data, self.point_coords = {}, {}
        self.slice_data, self.slice_coords = {}, {}

        # containers for misfits between datasets
        self.rel_comp_errors, self.rel_mag_errors = {}, {}
        self.abs_comp_errors, self.abs_mag_errors = {}, {}
        self.ratio_comp_errors, self.ratio_mag_errors = {}, {}

        # names of datasets which could not be imported
        self.missing = []

        # default figure settings
        self.fig_size = None
        self.dpi = 300
        self.fs = 12
        self.dg_space = False
        self.label_color = '#000000'
        self.c_idx = {'x': 0, 'y': 1, 'z': 2}

        # default result and plot directories
        self.r_dir = 'results'
        self.s_dir = 'plots'

    def init_main_parameters(self, mod, mesh, approach, path):

        """
        Initializes the most important class attributes.

        Every argument which is None is replaced by the corresponding
        instance attribute (``mod``, ``mesh``, ``approach``) or, for
        ``path``, by the default directory of the interpolated results.

        Returns the tuple ``(mod, mesh, approach, path)``.
        """

        if mod is None:
            mod = self.mod
        if mesh is None:
            mesh = self.mesh
        if approach is None:
            approach = self.approach

        if path is None:
            # hack to be compatible with deprecated directory structure
            base = self.r_dir + '/' + approach + '/' + mesh
            if os.path.isdir(base):
                path = base + '/' + mod + '_interpolated/'
            else:
                path = base + '_results/' + mod + '_interpolated/'
        return(mod, mesh, approach, path)

    def general_import(self, path, quantity, mesh, key, comp=None, stride=1,
                       ckey=None):

        """
        General import functions for interpolated results.

        Loads ``<path><quantity>_<mesh>.npy`` (or the deprecated name
        ``<quantity>_on_<mesh>.npy``) and stores columns 0-2 as data and the
        remaining columns as coordinates:

        - ``comp`` is None: point data, coordinates stored under *mesh*
        - ``comp`` is a single column index: line data sorted along that
          column, coordinates stored under *mesh*
        - ``comp`` is a list/tuple of two column indices: slice data,
          coordinates stored under *ckey* (*mesh* if *ckey* is None)

        If neither file exists, the dataset is registered in
        ``self.missing`` and nothing is stored.
        """

        if key is None:
            data_key = self.mod + '_' + quantity + '_' + mesh
        else:
            data_key = key + '_' + quantity

        data = None
        # current naming first, deprecated name convention second
        for file_name in (quantity + '_' + mesh + '.npy',
                          quantity + '_on_' + mesh + '.npy'):
            try:
                data = np.load(path + file_name)
            except FileNotFoundError:
                continue
            break

        if data is None:
            # nothing to store, otherwise *data* would be used unbound below
            self.print_import_error(quantity + '_', mesh, path)
            return

        if comp is None:
            data_s = self.rearange_point_data(data)
            self.point_data.update({data_key: data_s[:, :3]})
            # NOTE(review): the original used the undefined name 'points' as
            # key here; *mesh* is used in analogy to the line branch.
            self.point_coords.update({mesh: data_s[:, 3:]})
        elif not isinstance(comp, (list, tuple)):
            data_s = self.arrange_line_data(data, comp, stride)
            self.line_data.update({data_key: data_s[:, :3]})
            self.line_coords.update({mesh: data_s[:, 3:]})
        else:
            if ckey is None:
                # a None key could never be resolved by get_coords()
                ckey = mesh
            data_s = self.reduce_slice_data(data, comp[0], comp[1], stride)
            self.slice_data.update({data_key: data_s[:, :3]})
            self.slice_coords.update({ckey: data_s[:, 3:]})

    def init_component_integers(self, string, line=True):

        """
        Initializes integer mapping to handle x-, y-, and z-directed line
        and slice coordinates.

        Returns a single column index for lines and a tuple of two column
        indices for slices. Raises SystemExit for any other *string*.
        """

        columns = {'x': (3, (5, 4)),
                   'y': (4, (5, 3)),
                   'z': (5, (3, 4))}
        if string not in columns:
            print("Error!, don't know how to define coordinates! Support for "
                  "lines or slices without coordinate ending in the name "
                  "(x,y,z) is not supportet yet!")
            raise SystemExit
        line_column, slice_columns = columns[string]
        if line:
            return(line_column)
        return(slice_columns)

    def calc_pyhed_reference(self, key, EH='E', line=None, mod=None,
                             config_file=None, mesh=None):

        """
        Deprecated! Might be edited for future purposes.

        Evaluates a reference solution with the ``PHC`` calculator on the
        coordinates of *line* (first imported line if None). *key* is not
        used.
        """

        if mod is None:
            mod = self.mod
        if mesh is None:
            mesh = self.mesh

        if config_file is None:
            base = self.r_dir + '/' + self.approach + '/' + mesh
            config_file = base + '/' + mod + '_config.json'
            # hack to be compatible with deprecated directory structure:
            # only switch if the current location does not provide the file
            if not os.path.isfile(config_file):
                deprecated = base + '_results/' + mod + '_config.json'
                if os.path.isfile(deprecated):
                    config_file = deprecated

        if line is None:
            line = next(iter(self.line_coords))

        coords = self.get_coords(line, for_pyhed=True)[0]
        calculator = phc.PHC(config_file)
        return(calculator.calc_reference(coords, EH))

    def load_model_parameters(self, mod=None, mesh=None, approach=None):

        """
        Deprecated! Might be edited for future purposes.

        Reads ``<mod>_config.json`` and ``<mod>_resource.json`` from the
        result directory and adds their entries as instance attributes.
        Prints a warning if a file is missing.
        """

        if mod is None:
            mod = self.mod
        if mesh is None:
            mesh = self.mesh
        if approach is None:
            approach = self.approach

        base = self.r_dir + '/' + approach + '/' + mesh
        directory = base + '/'
        # hack to be compatible with deprecated directory structure
        if not os.path.isdir(directory) and os.path.isdir(base + '_results'):
            directory = base + '_results/'

        try:
            for suffix in ('_config.json', '_resource.json'):
                with open(directory + mod + suffix) as json_file:
                    self.__dict__.update(json.load(json_file))
        except FileNotFoundError:
            print('Warning! Config files not found')

    def eval_properties(self, mod, key, label, EH, sf, mesh,
                        line=True, points=False):

        """
        Evaluate names, labels, field quantities etc. for all kinds of
        plots.

        Returns ``(keys, label, EH_list, coords, cc)`` with the data keys,
        the plot label, the list of field types, the coordinates and the
        coordinate label.
        """

        if mod is None:
            mod = self.mod
        if label is None:
            label = mod
        ts = 's' if sf else 't'

        if not points:
            if line:
                if mesh is None:
                    # next(iter()) does not leak a loop variable into *key*
                    mesh = next(iter(self.line_coords))
                coords, cc = self.get_coords(mesh)
            else:
                if mesh is None:
                    mesh = next(iter(self.slice_coords))
                coords = self.get_coords(mesh)
                cc = mesh[-1]

        EH_list = ['E', 'H']
        if key is None:
            keys = [mod + '_' + field + '_' + ts + '_' + mesh
                    for field in EH_list]
        else:
            keys = [key + '_' + field + '_' + ts for field in EH_list]

        if EH == 'E' or EH == ['E']:
            EH_list, keys = EH_list[:1], keys[:1]
        if EH == 'H' or EH == ['H']:
            EH_list, keys = EH_list[1:], keys[1:]

        if points:
            coords = np.arange(len(self.point_data[keys[0]]))
            cc = 'Station No.'
        return(keys, label, EH_list, coords, cc)

    def reduce_slice_data(self, data, comp1, comp2, step):

        """
        Apply a stride for importing slice datasets to significantly reduce
        the amount of data to be plotted later on, if the interpolation mesh
        was set too dense.

        The data are sorted along column *comp1*, split into chunks of one
        *comp1* value each, and every chunk is sorted along *comp2*. Every
        *step*-th chunk and every *step*-th row within a chunk is kept.
        Returns a complex-valued array.
        """

        if self.dg_space:
            # drop rows with duplicated coordinates
            unique_rows = np.unique(data[:, 3:], axis=0, return_index=True)[1]
            data = data[unique_rows, :]

        n_first = len(np.unique(data[:, comp1]))
        n_second = len(np.unique(data[:, comp2]))
        data = data[data[:, comp1].argsort()]

        # collect the chunks first and stack once instead of in every loop
        kept = []
        for j in range(0, n_first, step):
            chunk = data[j * n_second: (j + 1) * n_second]
            kept.append(chunk[chunk[:, comp2].argsort()][::step, :])

        if not kept:
            return(np.zeros((0, 6), dtype=complex))
        return(np.vstack(kept).astype(complex))

    def arrange_line_data(self, data, comp, step):

        """
        Sort line data along column *comp* and keep every *step*-th row.
        """

        if self.dg_space:
            # drop rows with duplicated coordinates
            unique_rows = np.unique(data[:, 3:], axis=0, return_index=True)[1]
            data = data[unique_rows, :]
        data_s = np.array(data[data[:, comp].argsort()])
        return(data_s[::step, :])

    def rearange_point_data(self, data):

        """
        Return point data, without rows of duplicated coordinates if
        ``self.dg_space`` is set.
        """

        if not self.dg_space:
            return(data)
        unique_rows = np.unique(data[:, 3:], axis=0, return_index=True)[1]
        return(data[unique_rows, :])

    def print_import_error(self, field_type, name, path):

        """
        Print an import error if a specified dataset for visualization
        could not be found and remember it in ``self.missing``.
        """

        print('Warning! ' + field_type + name +
              ' could not be found in path: \n' + path + ', continuing...')
        self.missing.append(field_type + name)

    def get_coords(self, name, for_pyhed=False):

        """
        Initialize coordinates for visualization, depending on the
        direction of the input line- or slice-mesh.

        - line, ``for_pyhed=False``: ``(coords, direction)`` with the
          stored coordinates along the direction given by the last
          character of *name* (no scaling is applied)
        - slice, ``for_pyhed=False``: ``(M, N)``, the unique in-plane
          coordinates divided by 1e3
        - ``for_pyhed=True``: ``(all_coords, 'xyz')``

        Raises SystemExit if *name* contains neither 'line' nor 'slice' or
        does not end with 'x', 'y' or 'z'.
        """

        direction = name[-1]
        if 'line' in name:
            coords = self.line_coords[name]
            if for_pyhed:
                return(coords.real, 'xyz')
            column = {'x': 0, 'y': 1, 'z': 2}.get(direction)
            if column is not None:
                return(coords[:, column].real, direction)
        elif 'slice' in name:
            coords = self.slice_coords[name]
            if for_pyhed:
                return(coords.real, 'xyz')
            in_plane = {'x': (1, 2), 'y': (2, 0), 'z': (0, 1)}.get(direction)
            if in_plane is not None:
                M = np.unique(coords[:, in_plane[0]].real)/1e3
                N = np.unique(coords[:, in_plane[1]].real)/1e3
                return(M, N)

        # unknown mesh type or unknown direction: fail explicitly instead of
        # returning None or using unbound variables
        print('Error! Cannot identify coordinates!')
        raise SystemExit
# NOTE(review): method of PlotBase (takes ``self``) that selects a colormap
# either from an explicit ``cmap`` name ('magma', 'viridis', 'RdBu_r') or, if
# ``cmap`` is None, from ``var`` ('E', 'H', 'err'). Any other value prints an
# error message and raises SystemExit. The expression inside every
# ``return(`` is missing in this copy of the source, so the returned objects
# cannot be determined from here -- TODO restore them from the original file
# before this definition can be parsed.
[docs] def init_cmap(self, cmap=None, var='E'): """ Initialize colormaps for 2D / 3D data visualization. """ if cmap is not None: if cmap == 'magma': return( elif cmap == 'viridis': return( elif cmap == 'RdBu_r': return( else: print('Error! cmap name "' + cmap + '" is not supported yet!') raise SystemExit else: if var == 'E': return( elif var == 'H': return( elif var == 'err': return( else: print('Error! If "cmap=None", "var" must be specified!') print('(error in init_cmap method from Plot class') raise SystemExit
def make_subfigure_box(height=3, width=2, fs=12, var=['E', 'E', 'E'], sf=False,
                       log_scale=True, cc='x', ylim=None, err_plot=None,
                       sharex=True, sharey=False, xlim=None, grid=True,
                       sliceplot=False, ap=False, add_km=True, clr=None):

    """
    Initialize default fig/ax objects with six or eight subfigures for all
    kinds of line- and slice-data plots.

    Row *i* shows component x, y, z (and, as fourth row, the one without
    component suffix); the two columns show real/imaginary parts or, with
    ``ap=True``, amplitude and phase. ``err_plot`` may be None, 'diff' or
    'ratio', any other value raises SystemExit. *var* is only read, never
    modified; *clr* is not used. At most four rows are supported by the
    label lists.

    Returns ``(f, ax)`` with *ax* always being a 2D array of axes.
    """

    # one entry per row is required, also for the fourth (i == 3) row
    n_var = max(height, 3)
    if var == 'E' or 'E' in var:
        var = ['E'] * n_var
        var2 = ' (V/m)'
    elif var == 'H' or 'H' in var:
        var = ['H'] * n_var
        var2 = ' (A/m)'
    else:
        var2 = ' WHICH UNIT ???'

    # squeeze=False keeps *ax* two-dimensional for a single row or column
    f, ax = plt.subplots(height, width, sharex=sharex, sharey=sharey,
                         squeeze=False)

    if sliceplot:
        grid = False
        # a slice normal to *cc* is spanned by the two other directions
        if cc == 'x':
            cc, cc2 = 'y', 'z'
        elif cc == 'y':
            cc, cc2 = 'x', 'z'
        elif cc == 'z':
            cc, cc2 = 'x', 'y'

    if sf:
        list2 = ['$^s_x$', '$^s_y$', '$^s_z$', '$^s$']
    else:
        list2 = ['$_x$', '$_y$', '$_z$', '']
    list3 = ['x', 'y', 'z', '']

    # x-labels of the bottom row do not depend on the row index
    x_label = cc + ' (km)' if add_km else cc
    for j in range(width):
        ax[-1, j].set_xlabel(x_label, size=fs)

    for i in range(height):
        # later assignments override earlier ones on purpose
        ax_labels = [
            r'$\Re$(' + var[i] + list2[i] + ')' + var2,
            r'$\Im$(' + var[i] + list2[i] + ')' + var2]
        if err_plot is not None:
            ax_labels = [r'$\epsilon_{\Re(\mathrm{' + var[i] + '}' +
                         list3[i] + ')}$ (%)',
                         r'$\epsilon_{\Im(\mathrm{' + var[i] + '}' +
                         list3[i] + ')}$ (%)']
        if ap:
            ax_labels = [r'||$\mathbf{' + var[i] + '}' + list2[i] +
                         '||$ ' + var2,
                         r'$\phi$($\mathbf{' + var[i] + '}' + list2[i] +
                         ')$' + ' (°)']
        if ap is True and err_plot == 'diff':
            ax_labels = [r'$\mathbf{\epsilon}$ (||$\mathbf{' + var[i] + '}' +
                         list2[i] + '||$)' + var2,
                         r'$\Delta\phi$($\mathbf{' + var[i] + '}' +
                         list2[i] + ')$' + ' (°)']
        if err_plot is not None and i == 3:
            ax_labels = [
                r'$\Re$($\mathbf{|' + var[i] + '|}' + list2[i] + ')$' + var2,
                r'$\Im$($\mathbf{|' + var[i] + '|}' + list2[i] + ')$' + var2]

        for j in range(width):
            axis = ax[i, j]
            if log_scale:
                axis.set_yscale('log')
                if ap:
                    # phases are never plotted on a log scale
                    ax[i, 1].set_yscale('linear')
            if ylim is not None:
                axis.set_ylim(ylim)
                # NOTE(review): nesting below 'ylim is not None' was
                # reconstructed from a flattened source -- confirm
                if ap:
                    ax[i, 1].set_ylim([-180.1, 180.1])
            if xlim is not None:
                axis.set_xlim(xlim)

            if not sliceplot:
                axis.set_ylabel(ax_labels[j], size=fs)
            else:
                axis.set_ylabel(cc2 + ' (km)', size=fs)
                axis.set_title(ax_labels[j], size=fs)

            if j == 1:
                ax[i, 1].yaxis.set_label_position("right")
                ax[i, 1].yaxis.tick_right()

            if err_plot is None:
                if ap:
                    ax[i, 1].set_yticks([-180., -90., 0., 90., 180.])
                    # first label was '180' for the tick at -180 before
                    ax[i, 1].yaxis.set_ticklabels(['-180', '-90', '0',
                                                   '90', '180'])
            else:
                if err_plot == 'diff':
                    if not sliceplot:
                        axis.set_ylabel(ax_labels[j], size=fs)
                    else:
                        axis.set_ylabel(cc2 + ' (km)', size=fs)
                elif err_plot == 'ratio':
                    if not sliceplot:
                        axis.set_ylabel(ax_labels[j] + ' [ratio]', size=fs)
                    else:
                        axis.set_ylabel(cc2 + ' (km)', size=fs)
                else:
                    print('oO, this should not happen!')
                    raise SystemExit

            axis.tick_params(labelsize=fs)
            if grid:
                axis.grid(which='major', color='0.8', ls=':')
    return(f, ax)
def make_plain_subfigure_box(height=3, width=2, fs=12, log_scale=True,
                             sharex=True, sharey=True, coord_axis=True):

    """
    Create a plain subfigure box for custom illustrations, also used by
    the default plot functions.

    The bottom row gets the x-label 'x (km)', the first column the y-label
    'y (km)' if *coord_axis* is set, and all axes a logarithmic y-scale if
    *log_scale* is set.

    Returns ``(f, ax)`` with *ax* always being a 2D array of shape
    ``(height, width)``.
    """

    # squeeze=False always returns a 2D array, also for a single row or a
    # single column, which the [i, j] indexing below relies on
    f, ax = plt.subplots(height, width, sharex=sharex, sharey=sharey,
                         squeeze=False)

    for j in range(width):
        ax[height - 1, j].set_xlabel('x (km)', size=fs)

    for i in range(height):
        if coord_axis:
            ax[i, 0].set_ylabel('y (km)', size=fs)
        if log_scale:
            for j in range(width):
                ax[i, j].set_yscale('log')
    return(f, ax)
def adjust_subfigure_box_axes(ax):

    """
    Copy the x-ticks and x-tick labels of ``ax[1, 0]`` to the y-axes of all
    subfigures in the first row. *ax* is expected to be a 2D array of axes.
    """

    # The loop used to run over len(ax), i.e., the number of rows, while
    # indexing columns, which raised an IndexError for the default 3 x 2 box.
    # NOTE(review): tick label texts may still be empty if the figure has not
    # been drawn yet -- confirm when using this function.
    labels = ax[1, 0].get_xticklabels()
    ticks = ax[1, 0].get_xticks()
    labs = [label.get_text() for label in labels]

    for j in range(ax.shape[1]):
        ax[0, j].set_yticks(ticks)
        ax[0, j].set_yticklabels(labs)
def adjust_ticks_and_labels(ax, remove_top_right=False):

    """
    Place five equally spaced ticks between the current limits of the x-
    and y-axis of every subfigure in the 2D axes array *ax*.

    If *remove_top_right* is set, the y-label of every subfigure except
    for the first column is moved to the left and cleared.
    """

    n_rows, n_cols = ax.shape[0], ax.shape[1]

    for row in range(n_rows):
        for col in range(n_cols):
            x_min, x_max = ax[row, col].get_xlim()[:2]
            ax[row, col].set_xticks(np.linspace(x_min, x_max, 5))
            y_min, y_max = ax[row, col].get_ylim()[:2]
            ax[row, col].set_yticks(np.linspace(y_min, y_max, 5))

    if not remove_top_right:
        return

    for row in range(n_rows):
        for col in range(1, n_cols):
            ax[row, col].yaxis.set_label_position("left")
            ax[row, col].set_ylabel('')
def adjust_axes(ax, equal_axis, x_lim=None, y_lim=None, size=3, width=2):

    """
    Adjust axis properties (equal resolution for x- and y-coordinates).

    Applies *x_lim* / *y_lim* (if not None) to the first *size* rows and
    *width* columns of the 2D axes array *ax*. If *equal_axis* is set, the
    axes are set to 'equal' scaling with adjustable 'box'.
    """

    for j in range(size):
        for k in range(width):
            axis = ax[j, k]
            if x_lim is not None:
                axis.set_xlim(x_lim)
            if y_lim is not None:
                axis.set_ylim(y_lim)
            if equal_axis:
                axis.axis('equal')
                try:
                    axis.set_adjustable('box')
                except ValueError:
                    # The handler used to name the undefined 'exception',
                    # so this fallback raised a NameError instead.
                    # NOTE(review): 'box-forced' presumably targets older
                    # matplotlib versions -- confirm
                    axis.set_adjustable('box-forced')
def adjust_log_ticks(ax, minlog, maxlog, symlog=False, additional=0):

    """
    Improve style of ticks and labels of a logarithmic colorbar.

    Ticks are placed at ``10**e`` for every exponent *e* in
    ``np.arange(minlog, maxlog + additional, 1)``. With ``symlog=True``,
    the mirrored negative ticks are added in front. *ax* only needs to
    provide ``set_ticks`` and ``set_ticklabels``.
    """

    # float exponents: integer arrays cannot be raised to negative powers
    exponents = np.arange(minlog, maxlog + additional, 1).astype(float)
    ticks = 10.0**exponents
    labels = ['${10^{%d}}$' % (exp) for exp in exponents]

    if symlog:
        ticks = np.append(-ticks[::-1], ticks)
        labels = ['-${10^{%d}}$' % (exp) for exp in exponents[::-1]] + labels

    ax.set_ticks(ticks)
    ax.set_ticklabels(labels)
def eval_colors(data=None, n_colors=101, c_lim=None, quant=99, symlog=False):

    """
    Utility function for properly evaluating colorbar ranges using 99 %
    quantiles.

    If *c_lim* is None, the limits are taken from the *quant* percentile of
    the absolute real and imaginary parts of *data*, and the logarithmic
    limits are rounded outwards. Otherwise, the given limits are used as
    they are. With ``symlog=True``, half the number of colors is used for
    the positive levels, which are then mirrored to negative values.

    Returns ``(minlog, maxlog, colors, c_lim[0], c_lim[1])``.
    """

    if symlog:
        n_colors = int(n_colors/2)

    if c_lim is not None:
        minlog = np.log10(c_lim[0])
        maxlog = np.log10(c_lim[1])
    else:
        magnitudes = np.hstack((np.abs(data.real), np.abs(data.imag)))
        # percentile of the negated values gives the lower bound
        lower = -np.percentile(-magnitudes, quant)
        upper = np.percentile(magnitudes, quant)
        c_lim = [lower, upper]
        minlog = np.round(np.log10(lower) - 0.5)
        maxlog = np.round(np.log10(upper) + 0.5)

    colors = np.logspace(minlog, maxlog, n_colors)
    if symlog:
        colors = np.append(-colors[::-1], colors)
    return(minlog, maxlog, colors, c_lim[0], c_lim[1])