# -*- coding: utf-8 -*-
import os
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np
from scipy.optimize import curve_fit
import scipy.stats as st
from scipy.stats.qmc import LatinHypercube
import torch
from softsensor.autoreg_models import (_postprocess_sens, _reshape_array)
from softsensor.metrics import picp
plt.rcParams.update({
# "font.family": "CMU Serif",
"axes.xmargin": 0.01,
})
# change this to the desired base path of the project
BASE_PATH = 'c:/Users/FSL5FE/MA/Coding/SoftSensorACC/use_cases/Fischer_MA'
[docs]
def plot_uncertainty(dataframe, mean, var, ground_truth=None, t_start=None, t_end=None,
n_std=2, show_legend=False):
"""
Plots the prediction with distributional uncertainty
Parameters
----------
mean : torch.Tensor
Point prediction
std : torch.Tensor
Uncertainty estimate
t_start: int, optional
Start of plotting window
t_end: int, optional
End of plotting window
ground_truth : torch.Tensor, optional
Ground truth of point prediction
fs : int, optional
Sampling rate. The default is 10
n_std : int, optional
Amount of standard deviations to plot. The default is 2
show_legend : bool, optional
Show legend in the plot. The default is False
title: string, optional
Title for the plot. The default is None
title_info: dict[string,string], optional
Must contain keys ['dataset', 'model', 'track', 'sensor']. The default is None.
fig_path : string, optional
Path to save fig. The default is None
is_duffing: bool, optional
Whether it is position prediction (like Duffing dataset) or acceleration prediction. The default is True.
show: bool, optional
If the plot should be displayed. The default is True.
Returns
-------
None.
"""
df = dataframe[t_start:t_end]
mean_val = np.array(df[[mean]])
var_val = np.array(df[[var]])
std = np.sqrt(var_val)
fig, ax = plt.subplots(1, 1)
if ground_truth is not None:
ax.plot(df.index, df[[ground_truth]], c='k', label=ground_truth)
ax.plot(df.index, mean_val, c='b', label=mean)
for i in range(1,n_std+1):
alpha = 0.1+0.4*(n_std+1-i)/(n_std+1)
ax.fill_between(df.index, (mean_val-i*std).squeeze(), (mean_val+i*std).squeeze(), alpha=alpha, color='b',
label=rf"Uncertainty {i}$\sigma$")
if show_legend:
ax.legend()
return fig, ax
[docs]
def plot_uncertainty_both(dataframe, mean, epistemic_var, aleatoric_var, ground_truth=None, t_start=None, t_end=None,
n_std=2, show_legend=False):
"""
Plots the prediction with aleatoric and epistemic uncertainty
Parameters
----------
mean : torch.Tensor
Point prediction
epistemic_std : torch.Tensor
Epistemic uncertainty estimate (e.g. ensemble, mcdo)
aleatoric_std : torch.Tensor
Aleatoric uncertainty estimate (e.g. mve)
t_start: int, optional
Start of plotting window
t_end: int, optional
End of plotting window
ground_truth : torch.Tensor, optional
Ground truth of point prediction
fs : int, optional
Sampling rate. The default is 10
n_std : int, optional
Amount of standard deviations to plot. The default is 2
show_legend : bool, optional
Show legend in the plot. The default is False
title: string, optional
Title for the plot. The default is None
title_info: dict[string,string], optional
Must contain keys ['dataset', 'model', 'track', 'sensor']. The default is None.
fig_path : string, optional
Path to save fig
Returns
-------
None.
"""
df = dataframe[t_start:t_end]
mean_val = np.array(df[[mean]])
ep_var = np.array(df[[epistemic_var]])
al_var = np.array(df[[aleatoric_var]])
ep_std = np.sqrt(ep_var)
al_std = np.sqrt(al_var)
fig, ax = plt.subplots(1, 1)
if ground_truth is not None:
ax.plot(df.index, df[[ground_truth]], c='k', label=ground_truth)
ax.plot(df.index, mean_val, c='b', label=mean)
std = ep_std
for i in range(1,n_std+1):
alpha = 0.1+0.4*(n_std+1-i)/(n_std+1)
ax.fill_between(df.index, (mean_val-i*std).squeeze(), (mean_val+i*std).squeeze(), alpha=alpha, color='r',
label=rf"Epistemic {i}$\sigma$")
std = ep_std + al_std
for i in range(1,n_std+1):
alpha = 0.1+0.4*(n_std+1-i)/(n_std+1)
ax.fill_between(df.index, (mean_val-i*std).squeeze(), (mean_val+i*std).squeeze(), alpha=alpha, color='b',
label=rf"Total {i}$\sigma$")
if show_legend:
ax.legend()
return fig, ax
[docs]
def plot_calibration_curve(dataframe, ground_truth, model_names, quantiles=np.arange(0.0, 1.05, 0.05), show_legend=False):
""" Runs the model prediction on track and plots the calibration at different quantile levels
Assumes that Quantile Regression models have QR in their name and that all other models predict mean and variance
Parameters
----------
models: list[uncertainty models]
Models to use for prediction
track: torch.Dataloader
Single track
track_number: int
Number of track in the test set
out_sens: list[string]
Names of output sensors
output: int, optional
Output sensor to plot. The default is 0.
quantiles: list[x], x in (0,1)
Quantile levels to analyze for calibration curve. The default is np.arange(0.0, 1.05, 0.05)
fig_path : string, optional
Path to save fig. The default is None
Returns
-------
None.
"""
fig, ax = plt.subplots(1, 1)
ax.plot(quantiles, quantiles, c="gray")
for name in model_names:
mean_val = torch.tensor(np.array(dataframe[[f'{ground_truth}_{name}']]))
var_val = torch.tensor(np.array(dataframe[[f'{ground_truth}_{name}_var']]))
targets = torch.tensor(np.array(dataframe[[ground_truth]]))
z_scores = [st.norm.ppf(1-(1-p)/2) for p in quantiles]
picp_scores = np.array([picp(mean_val, targets, var_val, z) for z in z_scores])
ax.plot(quantiles, picp_scores, marker='o', label=name)
ax.set_xlabel("Expected Confidence Level")
ax.set_ylabel("Observed Confidence Level")
if show_legend:
ax.legend()
return fig, ax
def _convert_to_latex(ch_names, ch_str, use_case):
""" Convert the channel names to LaTeX scientific notation if applicable."""
ch_dict = {
'Duffing_MVE': ['$F(t)$', '$x$'],
'Duffing_Enhanced': ['$F(t)$', '$x_{\\text{lin}}$', '$x$'],
'Duffing_Delta': ['$F(t)$', '$x_{\\text{lin}}$', '$\\Delta x$'],
'Duffing_Comparison': ['$F(t)$', '$x_{\\text{lin}}$', '$(\\Delta) x$'],
'Steering_MVE': ['$M2_x$', '$M2_y$', '$M2_z$', '$M3_x$', '$M3_y$', '$M3_z$', '$M8_x$', '$M8_y$', '$M8_z$'],
'Steering_Enhanced': ['$M2_x$', '$M2_y$', '$M2_z$', '$M3_x$', '$M3_y$', '$M3_z$', '$M8_{x,\\text{lin}}$',
'$M8_{y,\\text{lin}}$', '$M8_{z,\\text{lin}}$', '$M8_x$', '$M8_y$', '$M8_z$'],
'Steering_Delta': ['$M2_x$', '$M2_y$', '$M2_z$', '$M3_x$', '$M3_y$', '$M3_z$', '$M8_{x,\\text{lin}}$',
'$M8_{y,\\text{lin}}$', '$M8_{z,\\text{lin}}$', '$\\Delta M8_{x}$', '$\\Delta M8_{y}$',
'$\\Delta M8_{z}$'],
'Steering_Comparison': ['$M2_x$', '$M2_y$', '$M2_z$', '$M3_x$', '$M3_y$', '$M3_z$', '$M8_{x,\\text{lin}}$',
'$M8_{y,\\text{lin}}$', '$M8_{z,\\text{lin}}$', '$(\\Delta) M8_{x}$', '$(\\Delta) M8_{y}$',
'$(\\Delta) M8_{z}$'],
}
if ch_names:
ch_str = ch_dict[use_case] if use_case in ch_dict else ch_str
return ch_str
[docs]
def plot_sens_analysis(model, sensitivity, ch_names=None, title=None, use_case=None, annotations=False, orig_length=None, save_path=None):
"""
Visualize different metrics/insights of the post-processed arrays,
coming from differently averaged sensitivities.
Parameters
----------
model : Model consisting of nn.Modules
sensitivity : torch.Tensor
Sensitivity tensor that contains the gradients of the output with respect to the inputs
ch_names : list of strings, optional
List of channel names. The default is None, i.e. the channel names are not provided and
will be generated in the style of 'Inp_ch_i' and 'Rec_ch_i'.
title : string, optional
Title for the plot series. The default is None, i.e. no title is displayed.
annotations : bool, optional
If the plot should contain the percentages of the channel-wise mean sensitivities.
The default is False.
orig_length : int, optional
Original length of the time series. Only needed when sensitivity analysis is performed on
subset of the data. The default is None, i.e. the original length is not provided.
save_path : string, optional
Path to save the plots. The default is None, i.e. the plots are not saved.
"""
# Postprocess the sensitivity tensor
sum_mean_std_feature, sum_mean_std_inp_channels, out_ch_sens = _postprocess_sens(model, sensitivity)
rms_out_ch_sens = out_ch_sens[0]
if title is not None:
print(title)
else:
print(f'### Sensitivity analysis results for {model.__class__.__name__} model ###')
if save_path:
base_path = os.path.abspath(BASE_PATH)
use_case_str = use_case if use_case else ''
save_path = os.path.join(base_path, 'Evaluation_plots', use_case_str, save_path)
os.makedirs(save_path, exist_ok=True)
m_type = model.Type
input_channels = model.input_channels
pred_size = model.pred_size
num_timesteps, num_features = rms_out_ch_sens.shape
if m_type in ['AR', 'AR_RNN']:
win_size = max(model.window_size, model.rnn_window)
ch_size = input_channels + pred_size
rec_start_idx = input_channels*model.window_size
flatten_size = rec_start_idx + pred_size*model.rnn_window
ch_str = [f'Inp_ch_{i}' for i in range(input_channels)] + \
[f'Rec_ch_{i}' for i in range(pred_size)] if ch_names is None else ch_names
x_tks_feat = [i for i in np.arange(0, rec_start_idx, model.window_size)] + \
[i for i in np.arange(rec_start_idx, num_features+1, model.rnn_window)]
x_center = [i for i in np.arange(model.window_size//2, rec_start_idx, model.window_size)] + \
[i for i in np.arange(rec_start_idx+model.rnn_window//2, num_features+1, model.rnn_window)]
else:
win_size = model.window_size
ch_size = input_channels
rec_start_idx = input_channels*model.window_size
flatten_size = rec_start_idx
ch_str = [f'Inp_ch_{i}' for i in range(input_channels)] if ch_names is None else ch_names
x_tks_feat = [i for i in np.arange(0, num_features+1, model.window_size)]
# LaTeX scientific notation for channel names
new_ticks_pos = x_tks_feat[:-1] + (x_tks_feat[1] - x_tks_feat[0]) / 2
ch_str = _convert_to_latex(ch_names, ch_str, use_case)
out_str = [f'Out_ch_{i}' for i in range(pred_size)] if ch_names is None else ch_names[-pred_size:]
out_str = _convert_to_latex(ch_names, out_str, use_case)[-pred_size:]
angle = 65 if ch_size > 5 else 0
halign = 'left' # if ch_size < 5 else 'center'
font_size = 10 if ch_size < 5 else 8
ax_zero_label = r'RMS of Mean Sensitivities $[-]$'
ax_one_label = r'RMS of Std. Dev. of Sensitivities $[-]$'
# Plot the sum_mean_inp_channels, sum_std_inp_channels as heatmaps
sum_inp_channels, std_inp_channels = sum_mean_std_inp_channels # unpack tuple
fig, axs = plt.subplots(1, 2, figsize=(12,5))
if pred_size == 1:
categories = ch_str
mean_values, std_values = sum_inp_channels[0], std_inp_channels[0]
mean_norm = plt.Normalize(0, max(mean_values))
std_norm = plt.Normalize(0, max(std_values))
color_scheme = plt.get_cmap('Reds')
mean_colors = color_scheme(mean_norm(mean_values))
std_colors = color_scheme(std_norm(std_values))
axs[0].bar(categories, mean_values, color=mean_colors, width=0.5, alpha=0.8, edgecolor='black', linewidth=.75)
axs[1].bar(categories, std_values, color=std_colors, width=0.5, alpha=0.8, edgecolor='black', linewidth=.75)
axs[0].set_ylabel(ax_zero_label, labelpad=10)
axs[1].set_ylabel(ax_one_label, labelpad=10)
else:
pos0 = axs[0].imshow(sum_inp_channels, cmap='Reds', interpolation='none',
aspect='auto', extent=[0, ch_size, pred_size, 0], vmin=0)
cbar0 = fig.colorbar(pos0, ax=axs[0])
cbar0.formatter.set_powerlimits((0, 0))
cbar0.formatter.set_useMathText(True)
pos1 = axs[1].imshow(std_inp_channels, cmap='Reds', interpolation='none',
aspect='auto', extent=[0, ch_size, pred_size, 0], vmin=0)
cbar1 = fig.colorbar(pos1, ax=axs[1])
cbar1.formatter.set_powerlimits((0, 0))
cbar1.formatter.set_useMathText(True)
for ax in axs:
ax.set_xlabel('Input Channels', labelpad=10, fontsize=font_size)
ax.set_xticks(np.arange(ch_size))
ax.set_xticklabels(ch_str)
if pred_size > 1:
ax.set_ylabel('Output Channels', labelpad=10)
ax.set_xticklabels(ch_str, rotation=angle, ha=halign)
ax.set_yticks(np.arange(pred_size))
ax.set_yticklabels(out_str, rotation=angle, va='top', fontsize=font_size)
ax.grid(which='major', color='black', linestyle='-', linewidth=.5)
if m_type in ['AR', 'AR_RNN']:
ax.axvline(input_channels, color='black', linewidth=2)
else:
ax.ticklabel_format(axis='y', style='sci', scilimits=(0,0), useMathText=True)
ax.set_xlim(-0.5, len(categories)-0.5)
if save_path:
fig.tight_layout()
fig.subplots_adjust(wspace=0.25)
plt.savefig(os.path.join(save_path, 'sens_imshow.pdf'), bbox_inches='tight')
else:
plt.suptitle('Summed Sensitivities across Input & Output Channels', fontsize=13)
axs[0].set_title('RMS of Mean Sensitivities', pad=10)
axs[1].set_title('RMS of Std. Dev. of Sensitivities', pad=10)
fig.tight_layout()
fig.subplots_adjust(wspace=0.25)
plt.show()
# Plot the sum_mean_feature, sum_std_feature as bar plots
plt.rcParams.update({"axes.xmargin": 0})
sum_mean_feature, sum_std_feature = sum_mean_std_feature # unpack tuple
ch_mean = _reshape_array(model, sum_mean_feature, aggregation='mean', repeat=True)
ch_std_mean = _reshape_array(model, sum_std_feature, aggregation='mean', repeat=True)
means = _reshape_array(model, sum_mean_feature, aggregation='mean')
ch_median = _reshape_array(model, sum_mean_feature, aggregation='median', repeat=True)
ch_std_median = _reshape_array(model, sum_std_feature, aggregation='median', repeat=True)
fig, axs = plt.subplots(1, 2, figsize=(12,5))
edgecolor = 'black' if ch_size < 5 else 'None'
axs[0].bar(np.arange(num_features)+0.5, sum_mean_feature, width=0.9, alpha=0.8, edgecolor=edgecolor, linewidth=.5)
axs[1].bar(np.arange(num_features)+0.5, sum_std_feature, width=0.9, alpha=0.8, edgecolor=edgecolor, linewidth=.5)
# add a step plot for the mean value of each input/recurrent channel
axs[0].step(np.arange(num_features+1), ch_mean, lw=2, color='red')
axs[1].step(np.arange(num_features+1), ch_std_mean, lw=2, color='red')
axs[0].step(np.arange(num_features+1), ch_median, lw=2, color='green')
axs[1].step(np.arange(num_features+1), ch_std_median, lw=2, color='green')
# create manual legend
legend_entries = [Patch(edgecolor='black', label='Local' + '\n' + r'Sensitivity $s_{ik}$'),
Line2D([0], [0], color='red', label='Channel' + '\n' + r'Mean $\mu_k$'),
Line2D([0], [0], color='green', label='Channel' + '\n' + r'Median $P_{50, k}$')]
# plot annotations for the channel-wise percentages if flag is set
if annotations:
perc_mean = means / np.sum(means)
x_positions = [i+3 for i in x_tks_feat] if ch_size < 5 else x_center
y_positions = means*1.25 if ch_size < 5 else means*2.5
for x_pos, y_pos, perc in zip(x_positions, y_positions, perc_mean):
axs[0].text(x_pos, y_pos, f'{perc:.2f}', fontsize=11, ha=halign, va='bottom', color='red')
axs[0].set_ylabel(ax_zero_label, labelpad=10)
axs[1].set_ylabel(ax_one_label, labelpad=10)
for ax in axs:
ax.legend(handles=legend_entries, loc='upper left')
ax.set_xlabel('Input Channels', labelpad=15)
ax.set_xticks(x_tks_feat)
if ch_size < 5:
ax.set_xticklabels([])
for pos, label in zip(new_ticks_pos, ch_str):
ax.text(pos, ax.get_ylim()[0] - 0.06*ax.get_ylim()[1], label, ha='center', va='bottom')
else:
ax.set_xticklabels(ch_str + [''], rotation=angle, ha=halign, fontsize=font_size)
ax.ticklabel_format(axis='y', style='sci', scilimits=(0,0), useMathText=True)
ax.xaxis.grid(True)
if m_type in ['AR', 'AR_RNN']:
ax.axvline(rec_start_idx, color='black', linewidth=2)
if save_path:
fig.tight_layout()
fig.subplots_adjust(wspace=0.25)
plt.savefig(os.path.join(save_path, 'sens_barplot.pdf'), bbox_inches='tight')
else:
plt.suptitle('Sensitivities across all Channels, RMS over Output Channels', fontsize=13)
axs[0].set_title('RMS of Mean Sensitivities', pad=10)
axs[1].set_title('RMS of Std. Dev. of Sensitivities', pad=10)
fig.tight_layout()
fig.subplots_adjust(wspace=0.25)
plt.show()
plt.rcParams.update({"axes.xmargin": 0.01})
# Plot the averages across the time windows for each channel as line plots
rms_sum_out = _reshape_array(model, sum_mean_feature)
std_sum_out = _reshape_array(model, sum_std_feature)
rms_sum_inp, rms_sum_rec = np.split(rms_sum_out, [input_channels], axis=0)
std_sum_inp, std_sum_rec = np.split(std_sum_out, [input_channels], axis=0)
if use_case and ('Enhanced' in use_case or 'Delta' in use_case): # in case a linear pre-computed solution exists
rms_sum_inp, rms_sum_lin = np.split(rms_sum_inp, [input_channels-pred_size], axis=0)
std_sum_inp, std_sum_lin = np.split(std_sum_inp, [input_channels-pred_size], axis=0)
mean_rms_lin = np.mean(rms_sum_lin, axis=0)
mean_std_lin = np.mean(std_sum_lin, axis=0)
std_dev_rms_lin = np.std(rms_sum_lin, axis=0)
std_dev_std_lin = np.std(std_sum_lin, axis=0)
mean_rms_inp = np.mean(rms_sum_inp, axis=0)
mean_std_inp = np.mean(std_sum_inp, axis=0)
std_dev_rms_inp = np.std(rms_sum_inp, axis=0)
std_dev_std_inp = np.std(std_sum_inp, axis=0)
if m_type in ['AR', 'AR_RNN']:
mean_rms_rec = np.mean(rms_sum_rec, axis=0)
mean_std_rec = np.mean(std_sum_rec, axis=0)
std_dev_rms_rec = np.std(rms_sum_rec, axis=0)
std_dev_std_rec = np.std(std_sum_rec, axis=0)
fig, axs = plt.subplots(1, 2, figsize=(12,5))
alpha_val = 0.5 if (input_channels > 2 or pred_size > 2) else 1
lw = 1.5 if (input_channels > 2 or pred_size > 2) else 1.8
if ch_size <= 3:
for i in range(ch_size):
if i < input_channels:
axs[0].plot(rms_sum_out[i], lw=lw, alpha=alpha_val, label=ch_str[i])
axs[1].plot(std_sum_out[i], lw=lw, alpha=alpha_val, label=ch_str[i])
else:
axs[0].plot(rms_sum_out[i], linestyle='--', lw=lw, alpha=alpha_val, label=ch_str[i])
axs[1].plot(std_sum_out[i], linestyle='--', lw=lw, alpha=alpha_val, label=ch_str[i])
else:
axs[0].plot(mean_rms_inp, lw=2, color='darkred', label=r'$\mu_\text{inp}$')
axs[1].plot(mean_std_inp, lw=2, color='darkred', label=r'$\mu_\text{inp}$')
axs[0].fill_between(np.arange(mean_rms_inp.shape[0]), mean_rms_inp-std_dev_rms_inp,
mean_rms_inp+std_dev_rms_inp, color='red', alpha=0.3, label=r'$\sigma_\text{inp}$')
axs[1].fill_between(np.arange(mean_std_inp.shape[0]), mean_std_inp-std_dev_std_inp,
mean_std_inp+std_dev_std_inp, color='red', alpha=0.3, label=r'$\sigma_\text{inp}$')
if use_case and ('Enhanced' in use_case or 'Delta' in use_case):
axs[0].plot(mean_rms_lin, lw=2, color='green', label=r'$\mu_\text{lin}$')
axs[1].plot(mean_std_lin, lw=2, color='green', label=r'$\mu_\text{lin}$')
axs[0].fill_between(np.arange(mean_rms_lin.shape[0]), mean_rms_lin-std_dev_rms_lin,
mean_rms_lin+std_dev_rms_lin, color='palegreen', alpha=0.4, label=r'$\sigma_\text{lin}$')
axs[1].fill_between(np.arange(mean_std_lin.shape[0]), mean_std_lin-std_dev_std_lin,
mean_std_lin+std_dev_std_lin, color='palegreen', alpha=0.4, label=r'$\sigma_\text{lin}$')
if m_type in ['AR', 'AR_RNN']:
axs[0].plot(mean_rms_rec, lw=2, color='blue', linestyle='--', label=r'$\mu_\text{rec}$')
axs[1].plot(mean_std_rec, lw=2, color='blue', linestyle='--', label=r'$\mu_\text{rec}$')
axs[0].fill_between(np.arange(mean_rms_rec.shape[0]), mean_rms_rec-std_dev_rms_rec,
mean_rms_rec+std_dev_rms_rec, color='cornflowerblue', alpha=0.4, label=r'$\sigma_\text{rec}$')
axs[1].fill_between(np.arange(mean_std_rec.shape[0]), mean_std_rec-std_dev_std_rec,
mean_std_rec+std_dev_std_rec, color='cornflowerblue', alpha=0.4, label=r'$\sigma_\text{rec}$')
time_deltas = {
10: 2,
20: 5,
25: 5,
30: 5,
40: 10,
50: 10,
75: 15,
100: 20,
125: 25,
150: 30,
175: 35,
200: 40
}
time_delta = time_deltas[win_size]
time_str = [f't-{i}' for i in np.arange(win_size, -1, -time_delta)][:-1]
time_str.append('t')
axs[0].set_ylabel(ax_zero_label, labelpad=10)
axs[1].set_ylabel(ax_one_label, labelpad=10)
for ax in axs:
ax.set_xlabel(r'Sliding Time Window $[-]$', labelpad=10)
ax.set_xticks(np.arange(win_size+1, step=time_delta))
ax.set_xticklabels(time_str)
ax.ticklabel_format(axis='y', style='sci', scilimits=(0,0), useMathText=True)
if use_case and (use_case == 'Duffing_Enhanced' or use_case == 'Duffing_Delta'):
ax.get_lines()[0].set_color('blue')
ax.get_lines()[1].set_color('green')
ax.get_lines()[2].set_color('purple')
elif use_case and use_case == 'Duffing_MVE':
ax.get_lines()[0].set_color('blue')
ax.get_lines()[1].set_color('purple')
if ch_size < 5:
ax.legend(title='Channels', ncol=1)
else:
handles, labels = ax.get_legend_handles_labels()
new_handles = [handles[i] for i in range(0, len(handles), 2)] + [handles[i] for i in range(1, len(handles), 2)]
new_labels = [labels[i] for i in range(0, len(labels), 2)] + [labels[i] for i in range(1, len(labels), 2)]
ax.legend(new_handles, new_labels, ncol=2, loc='upper left')
if save_path:
fig.tight_layout()
fig.subplots_adjust(wspace=0.25)
plt.savefig(os.path.join(save_path, 'sens_sliding_time_window.pdf'), bbox_inches='tight')
else:
plt.suptitle('Sensitivities across Time Window for all Channels, RMS over Output Channels', fontsize=13)
axs[0].set_title('RMS of Mean Sensitivities', pad=10)
axs[1].set_title('RMS of Std of Sensitivities', pad=10)
fig.tight_layout()
fig.subplots_adjust(wspace=0.25)
plt.show()
# Plot the temporal development for one selected channel as line plots
ch_idx = ch_size-1 # to be chosen between 0 and ch_size-1
fig, ax = plt.subplots(figsize=(9,5))
# Define a custom colormap from light blue to light green
cmap = LinearSegmentedColormap.from_list('blue_to_green', ['blue', 'green'])
# Randomly select 50 time steps (if applicable) and sort them to avoid aliasing effects
if num_timesteps > 50:
samples = LatinHypercube(d=1).random(n=50).flatten()
random_indices = np.sort(np.floor(samples * num_timesteps)).astype(int)
else:
random_indices = np.arange(num_timesteps)
temp_dev_ch = _reshape_array(model, rms_out_ch_sens, remove_nans=True)
temp_dev_ch = np.stack([np.append(np.nan, x[ch_idx]) for x in temp_dev_ch])
mean = np.mean(temp_dev_ch[random_indices], axis=0)
std_dev = np.std(temp_dev_ch[random_indices], axis=0)
for i in random_indices:
color = cmap(i / num_timesteps)
if i == random_indices[0]:
ax.plot(temp_dev_ch[i], alpha=0.5, color=color, label='Time Values')
else:
ax.plot(temp_dev_ch[i], alpha=0.2, color=color)
ax.plot(mean, color='black', lw=2, label=r'Mean $\mu$')
ax.fill_between(np.arange(mean.shape[0]), mean-std_dev, mean+std_dev,
color='red', alpha=0.5, label=r'Std. Dev. $\pm\sigma$')
wind_size = model.window_size if ch_idx < input_channels else model.rnn_window
time_delta = time_deltas[wind_size]
time_str = [f't-{i}' for i in np.arange(wind_size, -1, -time_delta)][:-1]
time_str.append('t')
ax.set_xlabel(r'Sliding Time Window $[-]$', labelpad=10)
ax.set_ylabel(ax_zero_label, labelpad=10)
ax.set_xticks(np.arange(wind_size+1, step=time_delta))
ax.set_xticklabels(time_str)
ax.ticklabel_format(axis='y', style='sci', scilimits=(0,0), useMathText=True)
ax.legend(loc='upper left')
if save_path:
fig.tight_layout()
plt.savefig(os.path.join(save_path, f'sens_window_ch_{ch_idx+1}.pdf'), bbox_inches='tight')
else:
ch_name = 'Input' if ch_idx < input_channels else 'Recurrent'
ax.set_title(f'Temporal Development across Time Window for {ch_name} Channel "{ch_str[ch_idx]}"', pad=10)
fig.tight_layout()
plt.show()
# Plot the temporal development for all channels as line plots
# 'mean' for correlation with output signal, 'rms' for effective-valued aggregation
temp_seasonal = _reshape_array(model, rms_out_ch_sens, aggregation='rms').transpose()
x_length = num_timesteps/10 if not orig_length else orig_length
fig, ax = plt.subplots(figsize=(9,5))
filter_size = win_size//2 if flatten_size < 5*win_size else win_size
for i in range(ch_size):
# compute the running average for each channel
if flatten_size < 4*win_size:
temp_season = temp_seasonal[i]
else:
temp_season = np.convolve(temp_seasonal[i], np.ones(filter_size)/filter_size, mode='valid')
temp_season = np.concatenate([np.full((filter_size-1)//2, np.nan), temp_season])
if i < input_channels:
if i == 0:
ax.plot(temp_season, alpha=0.7, label=ch_str[i])
else:
ax.plot(temp_season, lw=2, alpha=alpha_val, label=ch_str[i])
else:
ax.plot(temp_season, linestyle='--', lw=2, alpha=alpha_val, label=ch_str[i])
ax.set_xlabel(r'Time $[s]$', labelpad=10)
ax.set_ylabel(ax_zero_label, labelpad=10)
ax.set_xticks(np.arange(num_timesteps+1, step=num_timesteps//5))
ax.set_xticklabels(np.arange(x_length+1, step=x_length//5, dtype=int))
ax.ticklabel_format(axis='y', style='sci', scilimits=(0,0), useMathText=True)
if use_case and (use_case == 'Duffing_Enhanced' or use_case == 'Duffing_Delta'):
ax.get_lines()[0].set_color('blue')
ax.get_lines()[1].set_color('green')
ax.get_lines()[2].set_color('purple')
elif use_case and use_case == 'Duffing_MVE':
ax.get_lines()[0].set_color('blue')
ax.get_lines()[1].set_color('purple')
if ch_size < 5:
ax.legend(title='Channels', loc='upper left')
else:
ax.legend(title='Channels', loc='upper left', bbox_to_anchor=(1,1))
if save_path:
plt.savefig(os.path.join(save_path, 'sens_seasonalities.pdf'), bbox_inches='tight')
fig.tight_layout()
else:
ax.set_title('Global Development of Channel Sensitivities, RMS over Output Channels', pad=10)
fig.tight_layout()
plt.show()
def _fft_signal(signal, sampling_rate=10):
N = len(signal)
magnitudes = np.abs(np.fft.fft(signal))
frequencies = np.fft.fftfreq(N, 1/sampling_rate)
dominant_freq = frequencies[np.argmax(magnitudes[:N//2])]
return dominant_freq, frequencies[:N//2], magnitudes[:N//2]
# Plot the frequency spectra of the sensitivities for each channel (DUFFING oscillator only)
if use_case and 'Duffing' in use_case:
fig, ax = plt.subplots(figsize=(9,5))
for i in range(ch_size):
_, freqs, mags = _fft_signal(temp_seasonal[i])
ax.plot(freqs, mags, alpha=0.8, label=f'{ch_str[i]}') #: {dominant_freq:.1f} Hz')
ax.get_lines()[0].set_alpha(0.5)
if use_case == 'Duffing_Enhanced' or use_case == 'Duffing_Delta':
ax.get_lines()[0].set_color('blue')
ax.get_lines()[1].set_color('green')
ax.get_lines()[2].set_color('purple')
elif use_case == 'Duffing_MVE':
ax.get_lines()[0].set_color('blue')
ax.get_lines()[1].set_color('purple')
ax.set_yscale('log')
ax.set_xlabel(r'Frequency $[Hz]$', labelpad=10)
ax.set_ylabel(r'Magnitude $[-]$', labelpad=10)
ax.legend(title='Channels', loc='upper right')
if save_path:
plt.savefig(os.path.join(save_path, 'sens_ffts.pdf'), bbox_inches='tight')
fig.tight_layout()
else:
ax.set_title('Frequency Spectra of Channel Sensitivities', pad=10)
fig.tight_layout()
plt.show()
# Plot the temporal fluctuations of the sensitivities
plt.rcParams.update({"axes.xmargin": 0})
fig, ax = plt.subplots(figsize=(9,5))
sum_out_ch_sensi = np.concatenate((np.full((num_timesteps, 1), np.nan), rms_out_ch_sens), axis=1)
step_size = 2 if flatten_size > 100 else 1
for i in random_indices:
color = cmap(i / num_timesteps)
if i == random_indices[0]:
ax.plot(sum_out_ch_sensi[i, ::step_size], alpha=0.5, color=color, label='Time Values')
else:
ax.plot(sum_out_ch_sensi[i, ::step_size], alpha=0.2, color=color)
mean = np.mean(sum_out_ch_sensi[random_indices, ::step_size], axis=0)
std = np.std(sum_out_ch_sensi[random_indices, ::step_size], axis=0)
ax.plot(mean, lw=2, color='black', label=r'Mean $\mu$')
ax.fill_between(np.arange(mean.shape[0]), mean-std, mean+std, color='red', alpha=0.5, label=r'Std. Dev. $\pm\sigma$')
temp_dev_rms_out = _reshape_array(model, rms_out_ch_sens, aggregation='mean')
means = np.mean(temp_dev_rms_out[random_indices], axis=0)
percentages = means / np.sum(means)
temp1 = means[:model.input_channels].repeat(model.window_size)
if m_type in ['AR', 'AR_RNN']:
temp2 = means[model.input_channels:].repeat(model.rnn_window)
mean_steps = np.hstack((np.append(np.nan, temp1), temp2))[::step_size]
else:
mean_steps = np.append(np.nan, temp1)[::step_size]
ax.step(np.arange(len(mean_steps)), mean_steps, lw=2, color='red', label='Channel Mean')
# plot annotations for the channel-wise percentages if flag is set
if annotations:
x_positions = [i/step_size+5 for i in x_tks_feat] if ch_size < 5 else [i/step_size for i in x_center]
y_positions = means*1.25 if ch_size < 5 else means*2.5
for x_pos, y_pos, perc in zip(x_positions, y_positions, percentages):
ax.text(x_pos, y_pos, f'{perc:.2f}', fontsize=11, ha=halign, va='bottom', color='red')
if m_type in ['AR', 'AR_RNN']:
ax.axvline(rec_start_idx/step_size, color='black', linewidth=2)
ax.set_xlabel('Input Channels', labelpad=18)
ax.set_ylabel(r'RMS of Sensitivities $[-]$', labelpad=10)
ax.set_xticks([i/step_size for i in x_tks_feat])
if ch_size < 5:
ax.set_xticklabels([])
for pos, label in zip(new_ticks_pos, ch_str):
ax.text(pos, ax.get_ylim()[0] - 0.06*ax.get_ylim()[1], label, ha='center', va='bottom')
else:
ax.set_xticklabels(ch_str + [''], rotation=angle, ha=halign, fontsize=font_size)
ax.ticklabel_format(axis='y', style='sci', scilimits=(0,0), useMathText=True)
ax.xaxis.grid(True)
ax.legend(loc='upper left')
if save_path:
fig.tight_layout()
plt.savefig(os.path.join(save_path, 'sens_temp_dev.pdf'), bbox_inches='tight')
else:
ax.set_title('Temporal Fluctuations of Sensitivities, RMS over Output Channels', pad=10)
fig.tight_layout()
plt.show()
plt.rcParams.update({"axes.xmargin": 0.01})
[docs]
def plot_feature_importance(model, sensitivity, time_step, ch_names=None, use_case=None, save_path=None):
"""
Visualize the feature importance of the sensitivity tensor for a given model, at a specific time step.
"""
if save_path:
base_path = os.path.abspath(BASE_PATH)
use_case_str = use_case if use_case else ''
save_path = os.path.join(base_path, 'Evaluation_plots', use_case_str, save_path)
os.makedirs(save_path, exist_ok=True)
# Postprocess the sensitivity tensor
sens_str = 'Mean' if (model.Pred_Type == 'Mean_Var' or model.Ensemble) else 'Point'
if isinstance(sensitivity[sens_str], list):
sensitivity = torch.stack(sensitivity[sens_str]).mean(dim=0)
else:
sensitivity = sensitivity[sens_str]
mean_out_ch_sens = _postprocess_sens(model, sensitivity)[2][1]
m_type = model.Type
input_channels = model.input_channels
pred_size = model.pred_size
num_features = mean_out_ch_sens.shape[1]
if m_type in ['AR', 'AR_RNN']:
ch_size = input_channels + pred_size
rec_start_idx = input_channels*model.window_size
ch_str = [f'Inp_ch_{i}' for i in range(input_channels)] + \
[f'Rec_ch_{i}' for i in range(pred_size)] if ch_names is None else ch_names
x_tks_feat = [i for i in np.arange(0, rec_start_idx, model.window_size)] + \
[i for i in np.arange(rec_start_idx, num_features+1, model.rnn_window)]
else:
ch_size = input_channels
ch_str = [f'Inp_ch_{i}' for i in range(input_channels)] if ch_names is None else ch_names
x_tks_feat = [i for i in np.arange(0, num_features+1, model.window_size)]
# LaTeX scientific notation for channel names
ch_str = _convert_to_latex(ch_names, ch_str, use_case)
out_str = [f'Out_ch_{i}' for i in range(pred_size)] if ch_names is None else ch_names[-pred_size:]
out_str = _convert_to_latex(ch_names, out_str, use_case)[-pred_size:]
angle = 65 if ch_size > 5 else 0
halign = 'left' # if ch_size < 5 else 'center'
font_size = 10 if ch_size < 5 else 8
# Plot the mean_feature as bar plots for one specific time step
plt.rcParams.update({"axes.xmargin": 0})
time_steps = [time_step] if isinstance(time_step, int) else time_step
fig, ax = plt.subplots(figsize=(9,5))
# alpha = 0.8 if len(time_steps) == 1 else 0.5
colors = ['#1f77b4', '#2ca02c', '#9467bd', '#8c564b'] # i.e. blue, green, purple, brown
alphas = [0.9, 0.65, 0.5, 0.4] if len(time_steps) != 1 else [0.8]
for step, color, alpha in (zip(time_steps, colors, alphas)):
feat_imp = mean_out_ch_sens[step]
ch_mean = _reshape_array(model, feat_imp, aggregation='mean', repeat=True)
ax.bar(np.arange(num_features)+0.5, feat_imp, width=1, alpha=alpha, color=color, edgecolor='black', linewidth=.5)
if len(time_steps) == 1:
ax.step(np.arange(num_features+1), ch_mean, lw=2, color='red')
# create manual legend
legend_entries = []
for step, color in zip(time_steps, colors):
legend_entries.append(Patch(facecolor=color, edgecolor='black', label='Local Sensitivities' + '\n' + rf' $s_{{ik}}\vert_{{t={step}}}$'))
if len(time_steps) == 1:
legend_entries.append(Line2D([0], [0], color='red', label='Channel Mean' + '\n' + rf'Sensitivities $\mu_k\vert_{{t={time_steps[0]}}}$'))
ax.set_ylabel(r'Mean Sensitivities $[-]$', labelpad=10)
ax.legend(handles=legend_entries, loc='best')
ax.set_xlabel('Input Channels', labelpad=10)
ax.set_xticks(x_tks_feat)
ax.set_xticklabels(ch_str + [''], rotation=angle, ha=halign, fontsize=font_size)
ax.ticklabel_format(axis='y', style='sci', scilimits=(0,0), useMathText=True)
ax.xaxis.grid(True)
if m_type in ['AR', 'AR_RNN']:
ax.axvline(rec_start_idx, color='black', linewidth=2)
if save_path:
fig.tight_layout()
fig.subplots_adjust(wspace=0.2)
plt.savefig(os.path.join(save_path, f'sens_barplot_timestep_{time_step}.pdf'), bbox_inches='tight')
else:
ax.set_title(f'Sensitivities across all Channels for Time Step {time_step}', pad=10)
fig.tight_layout()
fig.subplots_adjust(wspace=0.2)
plt.show()
plt.rcParams.update({"axes.xmargin": 0.01})
[docs]
def plot_uncertainty_sens(model, sens_uq, eps, ch_names, use_case=None, amplification=1, save_path=None):
"""
Visualize the uncertainty of the sensitivity tensor for a given model.
Parameters
----------
model : Model consisting of nn.Modules
sens_uq : torch.Tensor
Sensitivity tensor that contains the gradients of the output with respect to the inputs
eps : float
Epsilon value for the uncertainty quantification
ch_names : list of strings
List of channel names
use_case : string, optional
Use case for the channel names. The default is use_case.
amplification : int, optional
Amplification factor for the uncertainty quantification of sensitivities of MVE models. The default is 1.
save_path : string, optional
Path to save the plots. The default is None, i.e. the plots are not saved.
"""
if save_path:
base_path = os.path.abspath(BASE_PATH)
use_case_str = use_case if use_case else ''
save_path = os.path.join(base_path, 'Evaluation_plots', use_case_str, save_path)
os.makedirs(save_path, exist_ok=True)
# Postprocess the sensitivity tensor
sens_uq, eps = sens_uq.numpy(), eps.numpy()
sum_win_size_sens = _reshape_array(model, sens_uq, aggregation='rms') # shape [sens_length*random_samples, pred_size, input_channels+pred_size]
pred_size = model.pred_size
ch_size = model.input_channels + pred_size
ch_str = [f'Inp_ch_{i}' for i in range(model.input_channels)] + \
[f'Rec_ch_{i}' for i in range(pred_size)] if ch_names is None else ch_names
ch_str = _convert_to_latex(ch_names, ch_str, use_case)
out_str = [f'Out_ch_{i}' for i in range(pred_size)] if ch_names is None else ch_names[-pred_size:]
out_str = _convert_to_latex(ch_names, out_str, use_case)[-pred_size:]
colors = ['blue', 'purple'] if (use_case and use_case == 'Duffing_MVE') else ['blue', 'green', 'purple']
# scale the step_size appropriately, such that always 1000 data points are plotted
step_size = max(1, eps.shape[0]//1000)
print('Shapes:', eps[::step_size,:].shape, sum_win_size_sens[::step_size,...].shape)
# Formulate fitting function for parabola
def fit_func(x, a, b, c):
return a*x**2 + b*x + c
def compute_uncertainty(x, y, num_sects):
x_sections = np.array_split(np.sort(x), num_sects)
stds = []
for section in x_sections:
sect_idc = np.where((x >= section.min()) & (x < section.max()))
stds.append(np.std(y[sect_idc]))
return np.array(stds)
# Plot the uncertainty of the sensitivity tensor
x_lim = 3 * amplification
fig, axs = plt.subplots(1, pred_size, figsize=(12,6))
axs = [axs] if pred_size == 1 else axs
for i, ax in enumerate(axs):
x_ = eps[::step_size,i]
x = x_[(x_ > -x_lim) & (x_ < x_lim)] # remove outliers outside the 3-sigma confidence interval
x_fit = np.linspace(x.min(), x.max(), 20)
divider = make_axes_locatable(ax)
ax_histx = divider.append_axes("top", size=1.05, pad=0.2, sharex=ax)
ax_histy = divider.append_axes("right", size=1.05, pad=0.2, sharey=ax)
ax_histx.hist(x, bins=30, density=True, alpha=0.5, color='red', orientation='vertical')
if amplification > 1:
ax.axvspan(-3, 3, color='gray', alpha=0.2)
for j, color in zip(range(ch_size), colors):
y = sum_win_size_sens[::step_size,i,j]
y = y[(x_ > -x_lim) & (x_ < x_lim)]
ax.scatter(x, y, s=15, alpha=0.3, edgecolors='black', linewidth=1, color=color, label=f'{ch_str[j]}')
ax_histy.hist(y, bins=10, density=True, weights=np.sqrt(y), alpha=0.5, color=color, orientation='horizontal')
# fit a quadratic function to the data, weighted by the square root of the output sensitivity
popt, _ = curve_fit(fit_func, x, y, sigma=1.0/np.sqrt(y))
print(f'Fitted parabola for Channel {j+1}: {popt[0]:.3f}*x^2 + {popt[1]:.3f}*x + {popt[2]:.3f}')
ax.plot(x_fit, fit_func(x_fit, *popt), color=color, linestyle='-', lw=2)
stds = compute_uncertainty(x, y, len(x_fit))
ax.fill_between(x_fit, fit_func(x_fit, *popt)-stds, fit_func(x_fit, *popt)+stds, color=color, alpha=0.4)
ax_histx.xaxis.set_tick_params(labelbottom=False)
ax_histy.yaxis.set_tick_params(labelleft=False)
ax.set_xlim(-x_lim, x_lim)
ax.set_xlabel(r'$\frac{x-\mu}{\sigma}$ $[-]$', labelpad=5)
ax.set_ylabel(r'Output Sensitivity $[-]$', labelpad=10)
ax.legend(loc='upper left', bbox_to_anchor=(1.0125,0,1,1.29))
if save_path:
fig.tight_layout()
plt.savefig(os.path.join(save_path, 'uncertainty_sens.pdf'), bbox_inches='tight')
else:
plt.suptitle('Sensitivity Range Across the Aleatoric Uncertainty in the Predicted Output', fontsize=13)
for i, ax in enumerate(axs):
ax_histx.set_title(f'Recurrent Channel {out_str[i]}', pad=10)
fig.tight_layout()
plt.show()
[docs]
def plot_grad_sens_comparison(vals_dict, models, sum_mean_std_list, ch_names=None, use_case=None, scientific_notation=False, save_path=None):
"""
Visualize the comparison of the sensitivity tensors of different models
for the same prediction type.
Parameters
----------
vals_dict: dict
Dict, containing the mode to compare, e.g. alpha, excitation amplitude/frequency;
as well as the values to compare, e.g. [0.1, 0.2, 0.3]
models : list[Model consisting of nn.Modules]
Models consisting of nn.Modules
sum_mean_std_list : list of torch.Tensor
List of sensitivity tensors that contain the gradients of the output with respect to the inputs
ch_names : list of strings, optional
List of channel names. The default is None, i.e. the channel names are not provided and
will be generated in the style of 'Inp_ch_i' and 'Rec_ch_i'.
scientific_notation : bool, optional
If the plot legends should be displayed in LaTeX scientific notation. The default is False.
save_path : string, optional
Path to save the plots. The default is None, i.e. the plots are not saved.
"""
if save_path:
folder, file = os.path.split(save_path)
save_path = os.path.join(os.getcwd(), 'Evaluation_plots', folder)
os.makedirs(save_path, exist_ok=True)
win_size = min([max(model.window_size, model.rnn_window) for model in models])
model = models[0]
same_ch_sizes = all([(model.input_channels == m.input_channels) for m in models])
# find the model with the maximum number of input channels
for m in models:
if m.input_channels > model.input_channels:
diff = abs(m.input_channels - model.input_channels)
model = m
max_inp_chs = model.input_channels
num_chs = model.input_channels + model.pred_size
num_features = win_size * num_chs
ch_str = [f'Inp_ch_{i}' for i in range(model.input_channels)] + \
[f'Rec_ch_{i}' for i in range(model.pred_size)] if ch_names is None else ch_names
ch_str = _convert_to_latex(ch_names, ch_str, use_case)
ax_zero_label = r'RMS of Sensitivities $[-]$'
ax_one_label = r'RMS of Std. Dev. of Sensitivities $[-]$'
x_tks_feat = [i for i in np.arange(num_features+1, step=win_size)]
if scientific_notation:
colors = ['blue', 'green', 'red', 'purple', 'purple', 'cyan']
alphas = [1.0 for _ in range(len(models))]
patterns = ['' for _ in range(len(models))]
else:
colors = ['blue', 'green', 'red', 'blue', 'green', 'red']
alphas = [1.0 if i < 3 else 0.5 for i in range(len(models))]
patterns = ['', '', '', '//', '\\\\', '//']
# Plot the sum_mean_feature, sum_std_feature as bar plots
fig, axs = plt.subplots(1, 2, figsize=(15,6))
bar_width = 0.1 if len(models) > 3 else 0.25
bar1 = np.arange(num_chs)
bars = [[x + i*bar_width for x in bar1] for i in range(1, len(models)+1)]
bars.insert(0, bar1)
for model, x, br, color, alpha, pat, value in zip(models, sum_mean_std_list, bars, colors, alphas, patterns, vals_dict['values']):
sum_mean_feature, sum_std_feature = x # unpack tuple
if num_chs <= 3:
rms_sum_out_inp_ch = _reshape_array(model, sum_mean_feature, aggregation='mean')
std_sum_out_inp_ch = _reshape_array(model, sum_std_feature, aggregation='mean')
if not same_ch_sizes and (model.input_channels < max_inp_chs):
insert = np.full(diff, np.nan)
rms_sum_out_inp_ch = np.insert(rms_sum_out_inp_ch, model.input_channels, insert)
std_sum_out_inp_ch = np.insert(std_sum_out_inp_ch, model.input_channels, insert)
axs[0].bar(br, rms_sum_out_inp_ch, width=bar_width, color=color, edgecolor='black', alpha=alpha, hatch=pat, label=f'{value}')
axs[1].bar(br, std_sum_out_inp_ch, width=bar_width, color=color, edgecolor='black', alpha=alpha, hatch=pat, label=f'{value}')
else:
rms_sum_out_inp_ch = _reshape_array(model, sum_mean_feature, aggregation='mean',
repeat=True, repeat_size=win_size)
std_sum_out_inp_ch = _reshape_array(model, sum_std_feature, aggregation='mean',
repeat=True, repeat_size=win_size)
if not same_ch_sizes and (model.input_channels < max_inp_chs):
rec_start_idx = model.input_channels*win_size + 1
insert = np.full(diff, np.nan).repeat(win_size)
rms_sum_out_inp_ch = np.insert(rms_sum_out_inp_ch, rec_start_idx, insert)
std_sum_out_inp_ch = np.insert(std_sum_out_inp_ch, rec_start_idx, insert)
axs[0].step(np.arange(num_features+1), rms_sum_out_inp_ch, alpha=0.8, color=color, label=f'{value}')
axs[1].step(np.arange(num_features+1), std_sum_out_inp_ch, alpha=0.8, color=color, label=f'{value}')
angle = 65 if num_chs > 3 else 0
halign = 'center' if num_chs < 5 else 'left'
axs[0].set_ylabel(ax_zero_label, labelpad=10)
axs[1].set_ylabel(ax_one_label, labelpad=10)
for ax in axs:
ax.set_xlabel('Channels', labelpad=10)
if num_chs <= 3:
ax.set_xticks([x + (len(models)/2 - 0.5) * bar_width for x in bar1])
ax.set_xticklabels(ch_str, rotation=angle, ha=halign)
else:
ax.xaxis.grid(True)
ax.set_xticks(x_tks_feat)
ax.set_xticklabels(ch_str + [''], rotation=angle, ha=halign)
ax.axvline(model.input_channels*win_size, color='black', linewidth=2)
if scientific_notation:
leg = ax.legend(title=rf'$\{vals_dict["mode"]}$', loc='upper left')
else:
leg = ax.legend(title=f'{vals_dict["mode"]}', loc='upper left')
leg.get_title().set_bbox(dict(facecolor='none', edgecolor='none', pad=20))
if save_path:
fig.tight_layout()
fig.subplots_adjust(wspace=0.2)
plt.savefig(os.path.join(save_path, f'{file}_comp_bar_plot.pdf'), bbox_inches='tight')
else:
plt.suptitle('Sensitivities across Input Features, RMS over Output Channels', fontsize=13)
axs[0].set_title('RMS of Mean Sensitivities', pad=10)
axs[1].set_title('RMS of Std of Sensitivities', pad=10)
fig.tight_layout()
fig.subplots_adjust(wspace=0.2)
plt.show()
# Plot the averages across the time windows for each channel as line plots
fig, axs = plt.subplots(1, 2, figsize=(15,6))
if scientific_notation:
alphas = [1.0 if (i+1) % 2 == 0 else 0.3 for i in range(len(models))]
else:
alphas = [1.0 for _ in range(len(models))]
for model, x, color, alpha, value in zip(models, sum_mean_std_list, colors, alphas, vals_dict['values']):
sum_mean_feature, sum_std_feature = x # unpack tuple
rms_sum_out = _reshape_array(model, sum_mean_feature)
std_sum_out = _reshape_array(model, sum_std_feature)
mean_rms_inp = np.mean(rms_sum_out[:model.input_channels,:], axis=0)
mean_rms_rec = np.mean(rms_sum_out[model.input_channels:,:], axis=0)
mean_std_inp = np.mean(std_sum_out[:model.input_channels,:], axis=0)
mean_std_rec = np.mean(std_sum_out[model.input_channels:,:], axis=0)
if scientific_notation:
axs[0].plot(np.concatenate([[np.nan], mean_rms_inp[-win_size:]]), color=color,
alpha=alpha, lw=1.75)
axs[0].plot(np.concatenate([[np.nan], mean_rms_rec[-win_size:]]), color=color,
alpha=alpha, lw=1.75, linestyle='-.')
axs[1].plot(np.concatenate([[np.nan], mean_std_inp[-win_size:]]), color=color,
alpha=alpha, lw=1.75)
axs[1].plot(np.concatenate([[np.nan], mean_std_rec[-win_size:]]), color=color,
alpha=alpha, lw=1.75, linestyle='-.')
else:
axs[0].plot(np.concatenate([[np.nan], mean_rms_inp[-win_size:]]), color=color,
alpha=0.5, lw=1.75)
axs[0].plot(np.concatenate([[np.nan], mean_rms_rec[-win_size:]]), color=color,
alpha=1.0, lw=1.75, linestyle='-.')
axs[1].plot(np.concatenate([[np.nan], mean_std_inp[-win_size:]]), color=color,
alpha=0.5, lw=1.75)
axs[1].plot(np.concatenate([[np.nan], mean_std_rec[-win_size:]]), color=color,
alpha=1.0, lw=1.75, linestyle='-.')
time_deltas = {
10: 2,
20: 5,
25: 5,
30: 5,
40: 10,
50: 10,
75: 15,
100: 20,
125: 25,
150: 30,
175: 35,
200: 40
}
time_delta = time_deltas[win_size]
time_str = [f't-{i}' for i in np.arange(win_size, -1, -time_delta)][:-1]
time_str.append('t')
if scientific_notation:
legend_entries = [Line2D([0], [0], color=color, label=value) \
for color, value in zip(colors, vals_dict['values'])]
else:
legend_entries = [Line2D([0], [0], color=color, label=value, linestyle='-') \
for color, value in zip(colors, vals_dict['values'])]
if 'Comparison' in use_case:
legend_anchor = (0.275, 1.0)
elif 'Steering' in use_case:
legend_anchor = (0.21, 1.0)
else:
legend_anchor = (0.213, 1.0)
axs[0].set_ylabel(ax_zero_label, labelpad=10)
axs[1].set_ylabel(ax_one_label, labelpad=10)
for ax in axs:
ax.set_xlabel(r'Sliding Time Window $[-]$', labelpad=10)
ax.set_xticks(np.arange(win_size+1, step=time_delta))
ax.set_xticklabels(time_str)
if scientific_notation:
leg_one = ax.legend(handles=legend_entries, title=rf'$\{vals_dict["mode"]}$', loc='upper left')
leg_two = ax.legend(handles=[Line2D([0], [0], color='gray', label='Input'),
Line2D([0], [0], color='gray', linestyle='-.', label='Recurrent')],
title='Channels', loc='upper left', bbox_to_anchor=(0.15, 1.0))
else:
leg_one = ax.legend(handles=legend_entries, title=vals_dict['mode'], loc='upper left')
leg_two = ax.legend(handles=[Line2D([0], [0], color='gray', label='Input'),
Line2D([0], [0], color='gray', linestyle='-.', label='Recurrent')],
title='Channels', loc='upper left', bbox_to_anchor=legend_anchor)
leg_one.get_title().set_bbox(dict(facecolor='none', edgecolor='none', pad=20))
leg_two.get_title().set_bbox(dict(facecolor='none', edgecolor='none', pad=20))
ax.add_artist(leg_one)
ax.add_artist(leg_two)
if save_path:
fig.tight_layout()
fig.subplots_adjust(wspace=0.2)
plt.savefig(os.path.join(save_path, f'{file}_comp_sliding_time_window.pdf'), bbox_inches='tight')
else:
plt.suptitle('Sensitivities across Time Window for all Input Channels, RMS over Output Channels', fontsize=13)
axs[0].set_title('RMS of Mean Sensitivities', pad=10)
axs[1].set_title('Sum of Std of Sensitivities', pad=10)
fig.tight_layout()
fig.subplots_adjust(wspace=0.2)
plt.show()
[docs]
def plot_first_weight_matrix(model, abs_values=False, ch_names=None, use_case=None, save_path=None):
"""
Visualize the model's first weight matrix and subsequent slices of it
as heatmaps, to compare with the sensitivity plots.
Parameters
----------
model : nn.module
Model consisting of nn.Modules
abs_values : bool, optional
If the absolute values of the weight matrix entries should be plotted.
The default is False.
save_path : string, optional
Path to save the plots. The default is None, i.e. the plots are not saved.
"""
if save_path:
base_path = os.path.abspath(BASE_PATH)
use_case_str = use_case if use_case else ''
save_path = os.path.join(base_path, 'Evaluation_plots', use_case_str, save_path)
os.makedirs(save_path, exist_ok=True)
for name, param in model.named_parameters():
if 'weight' in name:
# print(name)
# print(param.shape)
weight_matrix = param.detach().numpy()
break
input_channels, pred_size = model.input_channels, model.pred_size
out_features, in_features = weight_matrix.shape
rec_start_idx = input_channels*model.window_size
x_tks_feat = list(range(0, rec_start_idx, model.window_size)) + \
list(range(rec_start_idx, in_features+1, model.rnn_window))
y_tks_feat = list(range(0, out_features+1, out_features//4))
ch_size = input_channels + pred_size
ch_str = [f'Inp_ch_{i}' for i in range(input_channels)] + \
[f'Rec_ch_{i}' for i in range(pred_size)] if ch_names is None else ch_names
inp_str = _convert_to_latex(ch_names, ch_str, use_case)[:input_channels]
out_str = [f'Out_ch_{i}' for i in range(pred_size)] if ch_names is None else ch_names[-pred_size:]
out_str = _convert_to_latex(ch_names, out_str, use_case)[-pred_size:]
angle = 65 if ch_size > 3 else 0
fig_size = (15, 9) if ch_size < 5 else (15, 11)
colormap = 'coolwarm'
abs_str = 'signed values'
if abs_values:
weight_matrix = np.abs(weight_matrix)
colormap = 'jet'
abs_str = 'absolute values'
input_ch_matrix = weight_matrix[:, :rec_start_idx]
rec_ch_matrix = weight_matrix[:, rec_start_idx:]
fig, axs = plt.subplots(2,2, figsize=fig_size)
pos0 = axs[0,0].imshow(weight_matrix, cmap=colormap, interpolation='none',
aspect='auto', extent=[0, in_features, 0, out_features])
fig.colorbar(pos0, ax=axs[0,0])
axs[0,0].axvline(rec_start_idx, color='black', linewidth=2)
axs[0,0].set_xticks(x_tks_feat)
axs[0,0].set_yticks(y_tks_feat)
axs[0,0].grid(axis='x', color='black', lw=0.5)
pos1 = axs[0,1].imshow(input_ch_matrix, cmap=colormap, interpolation='none',
aspect='auto', extent=[0, rec_start_idx, 0, out_features])
fig.colorbar(pos1, ax=axs[0,1])
axs[0,1].set_xticks(range(0, rec_start_idx+1, model.window_size))
if ch_names:
axs[0,1].set_xticklabels(inp_str + [''], rotation=angle, ha='center')
axs[0,1].set_yticks(y_tks_feat)
axs[0,1].grid(axis='x', color='black', lw=0.5)
pos2 = axs[1,0].imshow(rec_ch_matrix, cmap=colormap, interpolation='none',
aspect='auto', extent=[0, in_features-rec_start_idx, 0, out_features])
fig.colorbar(pos2, ax=axs[1,0])
axs[1,0].set_xticks(range(0, in_features-rec_start_idx+1, model.rnn_window))
if ch_names:
axs[1,0].set_xticklabels(out_str + [''], ha='left')
axs[1,0].set_yticks(y_tks_feat)
axs[1,0].grid(axis='x', color='black', lw=0.5)
# delete empty subplot
fig.delaxes(axs[1,1])
for ax in axs.flatten():
ax.set_xlabel(r'Input Features $[-]$', labelpad=5)
ax.set_ylabel(r'Output Features $[-]$', labelpad=10)
if save_path:
fig.tight_layout()
fig.subplots_adjust(wspace=0.2, hspace=0.35)
plt.savefig(os.path.join(save_path, 'weight_matrix_one.pdf'), bbox_inches='tight')
else:
plt.suptitle(f'1st Weight Matrix and Slices of Input & Recurrent Channels ({abs_str})', fontsize=13)
axs[0,0].set_title(r'$1^{st}$ Weight Matrix (full)')
axs[0,1].set_title('Weight Matrix slice of Input Channels')
axs[1,0].set_title('Weight Matrix slice of Recurrent Channels')
fig.tight_layout()
fig.subplots_adjust(wspace=0.2, hspace=0.35)
plt.show()
[docs]
def plot_all_weight_matrices(model, abs_values=False, rel_threshold=0.1, use_case=None, save_path=None):
"""
Visualize all weight matrices of the model as heatmaps and their distributions
as histograms, highlighting the percentage of weights that fall within a
certain threshold range.
Parameters
----------
model : nn.module
Model consisting of nn.Modules
abs_values : bool, optional
If the absolute values of the weight matrix entries should be plotted.
The default is False.
rel_threshold : float, optional
Threshold value for the percentage of weights that fall within the range [-threshold, threshold].
The default is 0.1, i.e. 10 % of the maximum absolute value of each weight matrix.
use_case : string, optional
Use case for the channel names. The default is None.
save_path : string, optional
Path to save the plots. The default is None, i.e. the plots are not saved.
"""
if save_path:
base_path = os.path.abspath(BASE_PATH)
use_case_str = use_case if use_case else ''
save_path = os.path.join(base_path, 'Evaluation_plots', use_case_str, save_path)
os.makedirs(save_path, exist_ok=True)
weights = []
for name, param in model.named_parameters():
if 'weight' in name:
weights.append(param.detach().numpy())
if abs_values:
weights = [np.abs(w) for w in weights]
colormap = 'jet'
abs_str = 'absolute values'
else:
colormap = 'coolwarm'
abs_str = 'signed values'
w_len = len(weights)
num_cols = 2 if w_len <= 4 else 3
num_rows = int(np.ceil(w_len / num_cols))
fig_size = (12, 8) if num_cols == 2 else (15, 8)
# plot the values of the weight matrices as imshow plots
fig, axs = plt.subplots(num_rows, num_cols, figsize=fig_size)
rec_start_idx = model.input_channels * model.window_size
win_size = max(model.window_size, model.rnn_window)
for i, (w, ax) in enumerate(zip(weights, axs.flatten()[:w_len])):
out_features, in_features = w.shape
x_step = in_features//4 if in_features % 4 == 0 else in_features//5
if i == 0: # first matrix only
x_tks_feat = list(range(0, rec_start_idx, model.window_size)) + \
list(range(rec_start_idx, in_features+1, model.rnn_window))
else:
x_tks_feat = range(0, in_features+1, x_step)
y_step = out_features//4 if out_features > 4 else 1
y_tks_feat = range(0, out_features+1, y_step)
pos = ax.imshow(w, aspect='auto', cmap=colormap, interpolation='none', extent=[0, in_features, 0, out_features])
fig.colorbar(pos, ax=ax)
threshold = rel_threshold * max(abs(w.min()), abs(w.max()))
ax.imshow(np.ma.masked_outside(w, -threshold, threshold), cmap='Greens', alpha=0.2,
aspect='auto', interpolation='none', extent=[0, in_features, 0, out_features])
ax.set_xticks(x_tks_feat)
if i == 0 and in_features//win_size > 5:
for label in ax.get_xticklabels()[1:model.input_channels+1:2]:
label.set_visible(False)
for label in ax.get_xticklabels()[-model.pred_size:-1]:
label.set_visible(False)
ax.set_yticks(y_tks_feat)
ax.set_xlabel(r'Input features $[-]$', labelpad=5)
ax.set_ylabel(r'Output features $[-]$', labelpad=5)
ax.set_title(f'Weight Matrix {i+1}')
# delete empty subplots
num_del_axes = num_rows*num_cols - w_len
for i in range(1, num_del_axes+1):
fig.delaxes(axs.flatten()[-i])
if save_path:
fig.tight_layout()
fig.subplots_adjust(hspace=0.4, wspace=0.2)
plt.savefig(os.path.join(save_path, 'weight_matrices.pdf'), bbox_inches='tight')
else:
plt.suptitle(f'Weight matrices ({abs_str})', fontsize=13)
fig.tight_layout()
fig.subplots_adjust(hspace=0.4, wspace=0.2)
plt.show()
# create own legend entries for the histograms
# legend_entries = [Patch(facecolor='skyblue', edgecolor='black', label='Inside'),
# Patch(facecolor='mediumaquamarine', edgecolor='black', label='Outside')]
# plot the distributions of the weight matrices as histograms
fig, axs = plt.subplots(num_rows, num_cols, figsize=fig_size)
for i, (w, ax) in enumerate(zip(weights, axs.flatten()[:w_len])):
counts, bins, patches = ax.hist(w.flatten(), bins=40, density=True, color='skyblue', edgecolor='black')
max_value = max(abs(w.min()), abs(w.max()))
threshold = rel_threshold * max_value
# color all those bins that fall in the range [-threshold, threshold] in green
for j in range(len(bins)-1):
if bins[j] >= -threshold and bins[j+1] <= threshold:
patches[j].set_facecolor('mediumaquamarine')
# compute how many percent of the weights fall in the range [-threshold, threshold]
percentage = np.sum(counts[np.logical_and(bins[:-1] >= -threshold, bins[1:] <= threshold)])
print(f'Weight Matrix {i+1}')
print(f'Max. value: {max_value:.3f}, 10% Threshold: +-{threshold:.3f}')
print(f'Percentage of weights within threshold range: {percentage:.1f}%\n')
ax.set_xlabel(r'Weight value $[-]$', labelpad=5)
ax.set_ylabel(r'Density $[\%]$', labelpad=5)
ax.set_title(f'Weight Matrix {i+1}')
# ax.legend(handles=legend_entries, title=rf'$\mathcal{{T}}$ $\in$ [{-threshold:.3f}, {threshold:.3f}]', loc='upper left')
# delete empty subplots
num_del_axes = num_rows*num_cols - w_len
for i in range(1, num_del_axes+1):
fig.delaxes(axs.flatten()[-i])
if save_path:
fig.tight_layout()
fig.subplots_adjust(hspace=0.4, wspace=0.3)
plt.savefig(os.path.join(save_path, 'weight_distributions.pdf'), bbox_inches='tight')
else:
plt.suptitle('Weight distributions of the matrices', fontsize=13)
fig.tight_layout()
fig.subplots_adjust(hspace=0.4, wspace=0.3)
plt.show()
'''
def plot_quantiles(model, track, t_start, t_end, output=0, fs=10, show_legend=False, title=None, title_info=None, fig_path=None, is_duffing=True, plot_all=False, show=True):
""" Runs the QR prediction on track and plots the quantiles
Parameters
----------
model: QuantileNARX
QR model to use for prediction
track: torch.Dataloader
Single track
t_start: int, optional
Start of plotting window. The default is None.
t_end: int, optional
End of plotting window. The default is None.
output: int, optional
Output sensor to plot. The default is 0.
fs: int, optional
Sampling rate of dataset to rescale x axis. The default is 10.
show_legend : bool, optional
Show legend in the plot. The default is False
title: string, optional
Title for the plot. The default is None
title_info: dict[string,string], optional
Must contain keys ['dataset', 'model', 'track', 'sensor']. The default is None.
fig_path : string, optional
Path to save fig. The default is None
is_duffing: bool, optional
Whether it is position prediction (like Duffing dataset) or acceleration prediction. The default is True.
plot_all: bool, optional
Whether to plot all quantiles or only the 67.5% and 95% PI. The default is False.
show: bool, optional
If the plot should be displayed. The default is True.
Returns
-------
None.
"""
sns.set(style="white")
pred = model.prediction(track)
predicted_quantiles = [x[output] for x in pred]
ground_truth = torch.tensor([data[1][0][output][0] for data in track])
median = predicted_quantiles[0]
mse = nn.MSELoss()(median, ground_truth)
print(f"MSE: {mse}")
window = slice(t_start, t_end)
length = len(median[window])
x = np.arange(0, length) + (0 if not t_start else t_start)
x = x/fs
sns.lineplot(x=x, y=ground_truth[window], color=sns.color_palette()[1], label="Ground Truth", legend=show_legend)
sns.lineplot(x=x, y=median[window], color=sns.color_palette()[0], label="Prediction ยต", legend=show_legend)
if plot_all:
for lb, ub in zip(predicted_quantiles[1::2], predicted_quantiles[2::2]):
plt.fill_between(x, lb[window], ub[window], color=sns.color_palette()[0], alpha=0.2)
else:
# Only plot the 67.5% and 95% PI (similar to first and second std of Gaussian)
plt.fill_between(x, predicted_quantiles[1][window], predicted_quantiles[2][window], color=sns.color_palette()[0], alpha=0.36666)
plt.fill_between(x, predicted_quantiles[13][window], predicted_quantiles[14][window], color=sns.color_palette()[0], alpha=0.5)
export_plot(show_legend, title, title_info, fig_path, is_duffing, show)
'''