# -*- coding: utf-8 -*-
"""
Created on Fri Mar 11 17:26:16 2022
@author: WET2RNG
"""
import math
import inspect
from enum import Enum
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import (grad, Variable)
from torchinfo import summary
# from captum.attr import LRP
from scipy.stats.qmc import LatinHypercube
from softsensor.losses import GaussianNLLLoss
from softsensor.model import (CNN, Feed_ForwardNN, Freq_Att_CNN,
_filter_parameters)
from softsensor.train_model import train_model
class SensitivityMethods(Enum):
"""
Enumeration class representing all currently available sensitivity methods.
Applicable methods are: 'gradient', 'smooth_grad', 'integrated_gradient' and 'perturbation'.
"""
GRADIENT = 'gradient'
SMOOTH_GRAD = 'smooth_grad'
INTEGRATED_GRADIENT = 'integrated_gradient'
PERTURBATION = 'perturbation'
class _Autoregressive_Model(nn.Module):
"""
Parent class for all Autoregressive models
Parameters
----------
input_channels : int
Number of input channels
pred_size : int
Number of predicted values
window_size : int
window size of the input. Number of Datapoints in the windowed
external excitation signal
rnn_window : int, optional
Window Size of the Recurrent Connection
Returns
-------
None.
"""
def __init__(self, input_channels, pred_size, window_size, rnn_window,
forecast):
super().__init__()
self.input_channels = input_channels
self.pred_size = pred_size
self.window_size = window_size
self.rnn_window = rnn_window
self.forecast = forecast
self.Type = 'AR'
self.Pred_Type = 'Point'
self.Ensemble = False
def prediction(self, dataloader, device='cpu', sens_params=None):
"""
Prediction of a whole Time Series
Parameters
----------
dataloader : Dataloader
Dataloader to predict output
device : str, optional
device to compute on. The default is 'cpu'.
sens_params : dict, optional
Dictionary that contains the parameters for the sensitivity analysis.
Key 'method' defines the method for sensitivity analysis: 'gradient' or 'perturbation'.
Key 'comp' defines whether gradients are computed for sensitivity analysis.
Key 'plot' defines whether the results of the sensitivity analysis are visualized.
Key 'verbose' defines whether the information about the sensitivity analysis is printed.
Key 'sens_length' defines the number of randomly sampled subset of timesteps for the analysis.
The default is None, i.e. no sensitivity analysis is computed.
Returns
-------
if sens_params is None or sens_params['comp'] is False:
torch.Tensor
Torch Tensor of same length as input
if sens_params['comp'] is True:
(torch.Tensor, dict)
tuple of Torch Tensor of same length as input and sensitivity dict
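Example
-------
Illustrative sketch of a typical ``sens_params`` dict (the values are example
choices, not library defaults); 'method' follows the SensitivityMethods
enumeration. Assumes a fitted autoregressive model ``m`` and a dataloader
``loader``:
>>> sens_params = {'method': 'gradient', 'comp': True, 'plot': False,
'verbose': False, 'sens_length': 500}
>>> pred, sens_dict = m.prediction(loader, sens_params=sens_params)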
"""
return _predict_ARNN(self, dataloader, device, sens_params)
class ARNN(_Autoregressive_Model):
"""
Autoregressive Neural Network with linear layers
.. math:: window_size = rnn_window = tau
.. math:: forecast = 1
Parameters
----------
input_channels : int
Number of input channels
pred_size : int
Number of predicted values
window_size : int
Size of the sliding window applied to the time series
rnn_window : int
Window Size of the Recurrent Connection before the DNN.
hidden_size : list of int or None, optional
List gives the size of hidden units. The default is None.
activation : str, optional
Activation function to activate the feature space.
The default is 'relu'.
bias : bool, optional
If True, bias weights are used. The default is True.
dropout : float [0,1], optional
Adds dropout Layers after each Linear Layer. The default is None.
forecast : int, optional
Size of the forecast. The default is 1
concrete_dropout : bool, optional
Whether to use normal or concrete dropout Layers if dropout is not None. The default is False.
bn : bool, optional
If True, batch normalization is applied in the underlying Feed_ForwardNN. The default is False.
Returns
-------
None.
Examples
--------
>>> import softsensor.autoreg_models
>>> import torch
>>> m = softsensor.autoreg_models.ARNN(2, 1, 10, 10, [16, 8])
>>> input = torch.randn(32, 2, 10)
>>> rec_input = torch.randn(32, 1, 10)
>>> output = m(input, rec_input)
>>> print(output.shape)
torch.Size([32, 1, 1])
>>> import softsensor.meas_handling as ms
>>> import numpy as np
>>> import pandas as pd
>>> t = np.linspace(0, 1.0, 101)
>>> d = {'inp1': np.random.randn(101),
'inp2': np.random.randn(101),
'out': np.random.randn(101)}
>>> handler = ms.Meas_handling([pd.DataFrame(d, index=t)], ['train'],
['inp1', 'inp2'], ['out'], fs=100)
>>> loader = handler.give_list(window_size=10, keyword='training',
rnn_window=10, batch_size=1)
>>> pred = m.prediction(loader[0])
>>> print(pred.shape)
torch.Size([1, 101])
"""
def __init__(self, input_channels, pred_size, window_size, rnn_window,
hidden_size=None, activation='relu', bias=True, dropout=None,
forecast=1, concrete_dropout=False, bn=False):
_Autoregressive_Model.__init__(self, input_channels, pred_size,
window_size, rnn_window, forecast)
self.params = _filter_parameters(locals().copy())
self.activation = activation
flatten_size = window_size*input_channels+rnn_window*pred_size
# Define Linear Network
self.DNN = Feed_ForwardNN(flatten_size, pred_size*forecast,
hidden_size, activation=activation,
bias=bias, dropout=dropout, concrete_dropout=concrete_dropout,
bn=bn)
def forward(self, inp, x_rec):
"""
Forward function to propagate through the network
Parameters
----------
inp : torch.tensor dtype=torch.float
Input tensor for forward propagation,
shape=[batch size, external channels, window_size]
x_rec : torch.tensor, dtype=torch.float
Recurrent Input for forward Propagation.
shape=[batch size, pred_size, rnn_window]
Returns
-------
output: torch.tensor dtype=torch.float()
shape=[batch size, pred_size, forecast]
"""
inp = torch.flatten(inp, start_dim=1)
x_rec = torch.flatten(x_rec, start_dim=1)
pred = self.DNN(torch.cat((inp, x_rec), dim=1))
pred = pred.reshape(-1, self.pred_size, self.forecast)
return pred
def forward_sens(self, inp):
"""
Forward function to propagate through the network, but only with one input tensor
that is already concatenated to allow for gradient-based sensitivity analysis
Parameters
----------
inp : torch.tensor dtype=torch.float
Input tensor for forward propagation that is already concatenated,
shape=[batch size, external channels*window_size + pred_size*rnn_window]
Returns
-------
output: torch.tensor dtype=torch.float()
shape=[batch size, pred_size, forecast]
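Example
-------
Based on the example in the introduction (illustrative sketch): the flattened
input concatenates the windowed excitation and the recurrent signal, i.e.
2*10 + 1*10 = 30 features per sample
>>> flat_inp = torch.randn(32, 30)
>>> output = m.forward_sens(flat_inp)
>>> print(output.shape)
torch.Size([32, 1, 1])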
"""
pred = self.DNN(inp)
pred = pred.reshape(-1, self.pred_size, self.forecast)
return pred
def get_recurrent_weights(self):
"""
Function that returns the weights that affect the recurrent input of the
network
Returns
-------
recurrent_weights : list of weight Tensors
List of the weights that affect the recurrent input of the network.
Example
-------
Based on the example in the introduction
>>> rec_w = m.get_recurrent_weights()
>>> print(rec_w[0].shape)
torch.Size([16, 10])
>>> print(rec_w[1].shape)
torch.Size([8, 16])
>>> print(rec_w[2].shape)
torch.Size([1, 8])
"""
input_pred_slice = slice(-self.rnn_window*self.pred_size, None)
return _get_recurrent_weights(self.DNN.named_parameters(), input_pred_slice)
class DensityEstimationARNN(ARNN):
"""
ARNN with two outputs to predict mean and variance (aleatoric uncertainty)
Parameters
----------
input_channels : int
Number of input channels
pred_size : int
Number of predicted values
window_size : int
Size of the sliding window applied to the time series
rnn_window : int
Window Size of the Recurrent Connection before the DNN.
hidden_size : list of int or None, optional
List gives the size of hidden units. The default is None.
activation : str, optional
Activation function to activate the feature space.
The default is 'relu'.
bias : bool, optional
If True, bias weights are used. The default is True.
dropout : float [0,1], optional
Adds dropout Layers after each Linear Layer. The default is None
forecast : int, optional
Size of the forecast. The default is 1
concrete_dropout : bool, optional
Whether to use normal or concrete dropout Layers if dropout is not None. The default is False
Returns
-------
None.
Examples
--------
>>> import softsensor.autoreg_models
>>> import torch
>>> params = {'input_channels': 2,
'pred_size': 1,
'window_size': 10,
'rnn_window': 10}
>>> m = softsensor.autoreg_models.DensityEstimationARNN(**params, hidden_size=[16, 8])
>>> input = torch.randn(32, 2, 10)
>>> rec_input = torch.randn(32, 1, 10)
>>> output = m(input, rec_input)
>>> print(output[0].shape) #Mean Prediction
torch.Size([32, 1, 1])
>>> print(output[1].shape) #Var Prediction
torch.Size([32, 1, 1])
"""
def __init__(self, input_channels, pred_size, window_size, rnn_window,
hidden_size=None, activation='relu', bias=True, dropout=None, forecast=1, concrete_dropout=False,
bn=False):
# Define Linear Network with twice the forecast (mean and var of Gaussian)
ARNN.__init__(self, input_channels, pred_size, window_size, rnn_window,
hidden_size, activation, bias, dropout, 2*forecast, concrete_dropout, bn)
self.params = _filter_parameters(locals().copy())
self.forecast = forecast
self.Pred_Type = 'Mean_Var'
def forward(self, inp, x_rec):
"""
Forward function to propagate through the network
Parameters
----------
inp : torch.tensor dtype=torch.float
Input tensor for forward propagation,
shape=[batch size, external channels, window_size]
x_rec : torch.tensor, dtype=torch.float
Recurrent Input for forward Propagation.
shape=[batch size, pred_size, rnn_window]
Returns
-------
mean: torch.tensor dtype=torch.float()
shape=[batch size, pred_size, forecast]
var: torch.tensor dtype=torch.float(), non-negative
shape=[batch size, pred_size, forecast]
"""
inp = torch.flatten(inp, start_dim=1)
x_rec = torch.flatten(x_rec, start_dim=1)
pred = self.DNN(torch.cat((inp, x_rec), dim=1))
pred = pred.reshape(-1, self.pred_size, self.forecast, 2)
mean, hidden_std = pred[:,:,:,0], pred[:,:,:,1]
var = F.softplus(hidden_std)
return mean, var
def forward_sens(self, inp):
"""
Forward function to propagate through the network, but only with one input tensor
that is already concatenated to allow for gradient-based sensitivity analysis
Parameters
----------
inp : torch.tensor dtype=torch.float
Input tensor for forward propagation,
shape=[batch size, external channels*window_size + pred_size*rnn_window]
Returns
-------
mean: torch.tensor dtype=torch.float()
shape=[batch size, pred_size, forecast]
var: torch.tensor dtype=torch.float(), non-negative
shape=[batch size, pred_size, forecast]
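Example
-------
Based on the example in the introduction (illustrative sketch): the flattened
input concatenates the windowed excitation and the recurrent signal, i.e.
2*10 + 1*10 = 30 features per sample
>>> flat_inp = torch.randn(32, 30)
>>> mean, var = m.forward_sens(flat_inp)
>>> print(mean.shape)
torch.Size([32, 1, 1])
>>> print(var.shape)
torch.Size([32, 1, 1])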
"""
pred = self.DNN(inp)
pred = pred.reshape(-1, self.pred_size, self.forecast, 2)
mean, hidden_std = pred[:,:,:,0], pred[:,:,:,1]
var = F.softplus(hidden_std)
return mean, var
def estimate_uncertainty(self, inp, x_rec):
"""
Wrapper of forward pass
Parameters
----------
inp : torch.tensor dtype=torch.float
Input tensor for forward propagation,
shape=[batch size, external channels, window_size]
x_rec : torch.tensor, dtype=torch.float
Recurrent Input for forward Propagation.
shape=[batch size, pred_size, rnn_window]
Returns
-------
(mean, var)
mean: torch.tensor dtype=torch.float()
shape=[batch size, pred_size]
var: torch.tensor dtype=torch.float(), non-negative
shape=[batch size, pred_size]
"""
return self(inp, x_rec)
def estimate_uncertainty_mean_std(self, inp, x_rec):
return self(inp, x_rec)
def prediction(self, dataloader, device='cpu', sens_params=None):
"""
Prediction of a whole Time Series
Parameters
----------
dataloader : Dataloader
Dataloader to predict output
device : str, optional
device to compute on. The default is 'cpu'.
sens_params : dict, optional
Dictionary that contains the parameters for the sensitivity analysis.
Key 'method' defines the method for sensitivity analysis: 'gradient' or 'perturbation'.
Key 'comp' defines whether gradients are computed for sensitivity analysis.
Key 'plot' defines whether the results of the sensitivity analysis are visualized.
Key 'verbose' defines whether the information about the sensitivity analysis is printed.
Key 'sens_length' defines the number of randomly sampled subset of timesteps for the analysis.
The default is None, i.e. no sensitivity analysis is computed.
Returns
-------
if sens_params is None or sens_params['comp'] is False:
(torch.Tensor, torch.Tensor)
tuple of mean and var predictions of same length as input
if sens_params['comp'] is True:
additionally a sensitivity dict is returned, where the key is the
prediction type and the value is the sensitivity tensor
"""
return _predict_arnn_uncertainty(self, dataloader, device, sens_params)
def get_recurrent_weights(self):
"""
Function that returns the weights that affect the recurrent input of the
network (mean network)
Returns
-------
recurrent_weights : list of weight Tensors
List of the weights that affect the recurrent input of the network.
"""
input_pred_slice = slice(-self.rnn_window*self.pred_size, None)
return _get_recurrent_weights(self.DNN.named_parameters(), input_pred_slice, True,
self.pred_size, self.forecast)
class SeparateMVEARNN(ARNN):
"""
ARNN with two independent subnetworks to predict mean and variance (aleatoric uncertainty)
.. image:: C:/Users/wet2rng/Desktop/Coding/SoftSensor/doc/img/Separate_MVE.png
Parameters
----------
input_channels : int
Number of input channels
pred_size : int
Number of predicted values
window_size : int
Size of the sliding window applied to the time series
rnn_window : int
Window Size of the Recurrent Connection before the DNN.
mean_model : torch.Module
Model for point prediction
var_hidden_size : list[int] or None, optional
List gives the size of hidden variance network units. The default is None.
activation : str, optional
Activation function to activate the feature space.
The default is 'relu'.
bias : bool, optional
If True, bias weights are used. The default is True.
dropout : float [0,1], optional
Adds dropout Layers after each Linear Layer. The default is None
forecast : int, optional
Size of the forecast. The default is 1
concrete_dropout : bool, optional
Whether to use normal or concrete dropout Layers if dropout is not None. The default is False
Returns
-------
None.
Note
-------
See "Optimal Training of Mean Variance Estimation Neural Networks"
[Sluijterman et al. 2023 https://arxiv.org/abs/2302.08875]
Examples
--------
>>> import softsensor.autoreg_models
>>> import torch
>>> params = {'input_channels': 2,
'pred_size': 1,
'window_size': 10,
'rnn_window': 10}
>>> mean_model = softsensor.autoreg_models.ARNN(**params, hidden_size=[16, 8])
>>> m = softsensor.autoreg_models.SeparateMVEARNN(**params,mean_model=mean_model,
var_hidden_size=[16, 8])
>>> input = torch.randn(32, 2, 10)
>>> rec_input = torch.randn(32, 1, 10)
>>> output = m(input, rec_input)
>>> print(output[0].shape) #Mean Prediction
torch.Size([32, 1, 1])
>>> print(output[1].shape) #VarPrediction
torch.Size([32, 1, 1])
"""
def __init__(self, input_channels, pred_size, window_size, rnn_window, mean_model, var_hidden_size=None,
activation='relu', bias=True, dropout=None, forecast=1, concrete_dropout=False, bn=False):
# Network for variance prediction
ARNN.__init__(self, input_channels, pred_size, window_size, rnn_window,
var_hidden_size, activation, bias, dropout, forecast, concrete_dropout, bn)
self.params = _filter_parameters(locals().copy())
# Network for mean prediction
self.mean_model = mean_model
self.Pred_Type = 'Mean_Var'
def forward(self, inp, x_rec):
"""
Forward function to propagate through the MVE network
Parameters
----------
inp : torch.tensor dtype=torch.float
Input tensor for forward propagation,
shape=[batch size, external channels, window_size]
x_rec : torch.tensor, dtype=torch.float
Recurrent Input for forward Propagation.
shape=[batch size, pred_size, rnn_window]
Returns
-------
mean: torch.tensor dtype=torch.float()
shape=[batch size, pred_size, forecast]
var: torch.tensor dtype=torch.float(), non-negative
shape=[batch size, pred_size, forecast]
"""
inp = torch.flatten(inp, start_dim=1)
x_rec = torch.flatten(x_rec, start_dim=1)
mean = self.mean_model(inp, x_rec)
hidden_std = self.DNN(torch.cat((inp, x_rec), dim=1))
hidden_std = hidden_std.reshape(-1, self.pred_size, self.forecast)
var = F.softplus(hidden_std)
return mean, var
def forward_sens(self, inp):
"""
Forward function to propagate through the network, but only with one input tensor
that is already concatenated to allow for gradient-based sensitivity analysis
Parameters
----------
inp : torch.tensor dtype=torch.float
Input tensor for forward propagation,
shape=[batch size, external channels*window_size + pred_size*rnn_window]
Returns
-------
mean: torch.tensor dtype=torch.float()
shape=[batch size, pred_size, forecast]
var: torch.tensor dtype=torch.float(), non-negative
shape=[batch size, pred_size, forecast]
"""
mean = self.mean_model.forward_sens(inp)
hidden_std = self.DNN(inp)
hidden_std = hidden_std.reshape(-1, self.pred_size, self.forecast)
var = F.softplus(hidden_std)
return mean, var
def estimate_uncertainty(self, inp, x_rec):
"""
Wrapper of forward pass
Parameters
----------
inp : torch.tensor dtype=torch.float
Input tensor for forward propagation,
shape=[batch size, external channels, window_size]
x_rec : torch.tensor, dtype=torch.float
Recurrent Input for forward Propagation.
shape=[batch size, pred_size, rnn_window]
Returns
-------
(mean, var)
mean: torch.tensor dtype=torch.float()
shape=[batch size, pred_size]
var: torch.tensor dtype=torch.float(), non-negative
shape=[batch size, pred_size]
"""
return self(inp, x_rec)
def estimate_uncertainty_mean_std(self, inp, x_rec):
"""
Wrapper of forward pass
Parameters
----------
inp : torch.tensor dtype=torch.float
Input tensor for forward propagation,
shape=[batch size, external channels, window_size]
x_rec : torch.tensor, dtype=torch.float
Recurrent Input for forward Propagation.
shape=[batch size, pred_size, rnn_window]
Returns
-------
(mean, var)
mean: torch.tensor dtype=torch.float()
shape=[batch size, pred_size]
var: torch.tensor dtype=torch.float(), non-negative
shape=[batch size, pred_size]
"""
return self(inp, x_rec)
def prediction(self, dataloader, device='cpu', sens_params=None):
"""
Prediction of a whole Time Series
Parameters
----------
dataloader : Dataloader
Dataloader to predict output
device : str, optional
device to compute on. The default is 'cpu'.
sens_params : dict, optional
Dictionary that contains the parameters for the sensitivity analysis.
Key 'method' defines the method for sensitivity analysis: 'gradient' or 'perturbation'.
Key 'comp' defines whether gradients are computed for sensitivity analysis.
Key 'plot' defines whether the results of the sensitivity analysis are visualized.
Key 'verbose' defines whether the information about the sensitivity analysis is printed.
Key 'sens_length' defines the number of randomly sampled subset of timesteps for the analysis.
The default is None, i.e. no sensitivity analysis is computed.
Returns
-------
if sens_params is None or sens_params['comp'] is False:
(torch.Tensor, torch.Tensor)
tuple of mean and var predictions of same length as input
if sens_params['comp'] is True:
additionally a sensitivity dict is returned, where the key is the
prediction type and the value is the sensitivity tensor
"""
return _predict_arnn_uncertainty(self, dataloader, device, sens_params)
def get_recurrent_weights(self):
"""
Function that returns the weights that affect the recurrent input of the
network (mean network)
Returns
-------
recurrent_weights : list of weight Tensors
List of the weights that affect the recurrent input of the network.
"""
return self.mean_model.get_recurrent_weights()
class QuantileARNN(ARNN):
"""
ARNN with multiple outputs to predict quantiles
Parameters
----------
input_channels : int
Number of input channels
pred_size : int
Number of predicted values
window_size : int
Size of the sliding window applied to the time series
rnn_window : int
Window Size of the Recurrent Connection before the DNN.
hidden_size : list of int or None, optional
List gives the size of hidden units. The default is None.
activation : str, optional
Activation function to activate the feature space.
The default is 'relu'.
bias : bool, optional
If True, bias weights are used. The default is True.
dropout : float [0,1], optional
Adds dropout Layers after each Linear Layer. The default is None
forecast : int, optional
Size of the forecast. The default is 1
concrete_dropout : bool, optional
Whether to use normal or concrete dropout Layers if dropout is not None. The default is False
n_quantiles: int, optional
Number of quantiles to predict. The default is 39 (median and 19 PIs between 0 and 1)
mean_model : torch.Module, optional
Model for point prediction. The default is None
Returns
-------
None.
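Examples
--------
Illustrative sketch (output shapes follow from the constructor arguments and
the default of 39 quantiles):
>>> import softsensor.autoreg_models
>>> import torch
>>> m = softsensor.autoreg_models.QuantileARNN(2, 1, 10, 10, hidden_size=[16, 8])
>>> input = torch.randn(32, 2, 10)
>>> rec_input = torch.randn(32, 1, 10)
>>> output = m(input, rec_input)
>>> print(output.shape)
torch.Size([32, 1, 1, 39])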
"""
def __init__(self, input_channels, pred_size, window_size, rnn_window,
hidden_size=None, activation='relu', bias=True, dropout=None, forecast=1, concrete_dropout=False, n_quantiles=39, mean_model=None, bn=False):
ARNN.__init__(self, input_channels, pred_size, window_size, rnn_window,
hidden_size, activation, bias, dropout, n_quantiles*forecast, concrete_dropout, bn)
self.params = _filter_parameters(locals().copy())
self.n_quantiles = n_quantiles
self.forecast = forecast
self.n_layers = len(hidden_size) if hidden_size else 0
self.mean_model = mean_model
self.Pred_Type = 'Quantile'
def forward(self, inp, x_rec):
"""
Forward function to propagate through the quantile network
If mean_model is set to a point prediction model, that model is used for the point prediction.
This is useful to keep the point prediction frozen during training without teacher forcing.
Parameters
----------
inp : torch.tensor dtype=torch.float
Input tensor for forward propagation,
shape=[batch size, external channels, window_size]
x_rec : torch.tensor, dtype=torch.float
Recurrent Input for forward Propagation.
shape=[batch size, pred_size, rnn_window]
Returns
-------
pred: torch.tensor dtype=torch.float()
shape=[batch size, pred_size, n_quantiles]
"""
inp = torch.flatten(inp, start_dim=1)
x_rec = torch.flatten(x_rec, start_dim=1)
pred = self.DNN(torch.cat((inp, x_rec), dim=1))
pred = pred.reshape(-1, self.pred_size, self.forecast, self.n_quantiles)
if self.mean_model:
pred[...,:,0] = self.mean_model(inp, x_rec)
return pred
def prediction(self, dataloader, device='cpu'):
"""
Prediction of a whole Time Series
Parameters
----------
dataloader : Dataloader
Dataloader to predict output
device : str, optional
device to compute on. The default is 'cpu'.
Returns
-------
quantiles : list[torch.Tensor]
list of n_quantiles tensors of same length as input
"""
return _prediction(self, dataloader, device)
def get_recurrent_weights(self):
"""
Function that returns the weights that affect the recurrent input of the
network (mean network)
Returns
-------
recurrent_weights : list of weight Tensors
List of the weights that affect the recurrent input of the network.
"""
input_pred_slice = slice(-self.rnn_window*self.pred_size, None)
return _get_recurrent_weights(self.DNN.named_parameters(), input_pred_slice,
distribution_layer=True)
'''
helpers
'''
# Helper: collect the weight tensors that act on the recurrent part of the input.
# The first layer's weight matrix is sliced to its recurrent columns (input_pred_slice);
# for distributional output layers only the first pred_size*forecast output rows are kept.
def _get_recurrent_weights(parameters, input_pred_slice=None, distribution_layer=False, pred_size=1, forecast=1):
Layer = 0
recurrent_weights = []
for name, W in parameters:
if 'weight' in name:
if input_pred_slice and Layer == 0:
temp_weights = W[:, input_pred_slice]
recurrent_weights.append(temp_weights)
else:
recurrent_weights.append(W)
Layer += 1
if distribution_layer:
#recurrent_weights[-1] = recurrent_weights[-1][0, :]
recurrent_weights[-1] = recurrent_weights[-1][:pred_size*forecast, :]
return recurrent_weights
# Helper: forward pass for convolutional AR models that own a ConvNet feature
# extractor, an optional batch-norm layer and a DNN head.
def _forward_AR(model, inp, x_rec):
inp = torch.cat([inp, x_rec], dim=1)
inp = model.ConvNet(inp)
if model.bn:
inp = model.BNLayer(inp)
inp = torch.flatten(inp, start_dim=1)
inp = model.DNN(inp)
inp = inp.reshape(-1, model.pred_size, model.forecast)
return inp
def _predict_ARNN(model, dataloader, device='cpu', sens_params=None):
"""
Predict function for forward ARNN models
Parameters
----------
model : Model consisting of nn.Modules
dataloader : Dataloader
Dataloader to predict output
device : str, optional
device to compute on. The default is 'cpu'.
sens_params : dict, optional
Dictionary that contains the parameters for the sensitivity analysis.
Key 'method' defines the method for sensitivity analysis: 'gradient' or 'perturbation'.
Key 'comp' defines whether gradients are computed for sensitivity analysis.
Key 'plot' defines whether the results of the sensitivity analysis are visualized.
Key 'verbose' defines whether the information about the sensitivity analysis is printed.
Key 'sens_length' defines the number of randomly sampled subset of timesteps for the analysis.
(If not a multiple of the model's forecast, the number will be rounded up to the next multiple.)
The default is None, i.e. no sensitivity analysis is computed.
Returns
-------
if comp_sens is False:
torch.Tensor : Tensor of same length as input, containing the predictions.
if comp_sens is True:
(torch.Tensor, dict) : Tuple of Tensor of same length as input and sensitivity dict.
Key is the prediction type, value is the sensitivity tensor.
"""
fc = model.forecast
# The Model starts with zeros as recurrent System state
size = max(model.window_size, model.rnn_window)
prediction = torch.zeros((model.pred_size, size))
original_out = torch.zeros((model.pred_size, size))
x_rec = torch.zeros(1, model.pred_size, model.rnn_window)
# Tensors to device
x_rec = x_rec.to(device)
prediction = prediction.to(device)
original_out = original_out.to(device)
model.to(device)
checks = _check_sens_params_pred(sens_params)
if sens_params:
method, comp_sens, verbose, sens_length, num_samples, std_dev, correlated = checks[:-2]
else:
comp_sens, verbose = checks # False
# Initialise 3D tensor for sensitivity analysis
if comp_sens:
flatten_size = model.window_size*model.input_channels + model.rnn_window*model.pred_size
num_timesteps = len(dataloader)*fc
loader_length = num_timesteps - dataloader.dataset.subclass.add_zero
sens_indices = np.arange(len(dataloader))
if sens_length:
num_timesteps, sens_indices = _random_subset_sens_indices(sens_length, fc, model.Type, dataloader)
sensitivity = torch.zeros((num_timesteps, model.pred_size, flatten_size))
if verbose:
print(f'Shape of sensitivity tensor: {sensitivity.shape}')
print(f'Start {method.upper()}-based Sensitivity Analysis...\n')
# Iterate over dataloader
idx = 0
for i, data in enumerate(tqdm(dataloader) if verbose else dataloader):
inputs, output = data
inputs, output = inputs[0].to(device), output.to(device)
# Prepare input for model to allow autograd computing gradients of outputs w.r.t. inputs
inputs = Variable(torch.flatten(inputs, start_dim=1), requires_grad=True)
x_rec = Variable(torch.flatten(x_rec, start_dim=1), requires_grad=True)
inp = torch.cat((inputs, x_rec), dim=1)
pred = model.forward_sens(inp) if comp_sens else model(inputs, x_rec)
if comp_sens and i in sens_indices:
sensitivity[idx:idx+fc] = _comp_sensitivity(method, model, inp, pred, num_samples, std_dev, correlated)
idx += fc
prediction = torch.cat((prediction,
pred.detach().reshape(model.pred_size, -1)), dim=1)
# Recurrent Input that is used for the next prediction -> autoregressive feedback!
x_rec = torch.unsqueeze(prediction[:, -model.rnn_window:], dim=0)
original_out = torch.cat((original_out,
output.reshape(model.pred_size, -1)), dim=1)
# cut zeros from initialisation
prediction = prediction[:, size:] # shape = [pred_size, len(dataloader)]
original_out = original_out[:, size:]
cut_zeros = dataloader.dataset.subclass.add_zero
if cut_zeros != 0:
prediction = prediction[:, :-cut_zeros]
original_out = original_out[:, :-cut_zeros]
prediction = prediction.cpu()
model.to('cpu')
if comp_sens:
sensitivity = sensitivity[:loader_length]
sensitivity_dict = {f'{model.Pred_Type}': sensitivity.cpu()}
if verbose:
print(f'{method.upper()}-based Sensitivity Analysis completed!\n')
return prediction, sensitivity_dict
else:
return prediction
def _comp_grad_sens(inputs, pred, pred_type, ensemble=False, random_samples=0, amplification=1):
"""
Compute the gradient-based sensitivity of the output w.r.t. the inputs in each timestep.
Parameters
----------
inputs : torch.Tensor
Input tensor with already concatenated external excitation and recurrent state signals,
shape=[batch_size, input_channels*window_size + pred_size*rnn_window]
pred : torch.Tensor
Output tensor that contains the predictions,
shape=[batch_size, pred_size, forecast]
pred_type : str
The model's prediction type out of ('Point', 'Mean_Var'), which defines
the number of outputs that the sensitivity analysis is performed on.
ensemble : bool, optional
If True, sensitivity is computed for an ensemble of models. The default is False.
random_samples : int, optional
Number of random samples, drawn from a standard normal distribution, to approximate the
expected sensitivity range/distribution across the aleatoric uncertainty of MVE models.
The default is 0 samples, i.e. no sampling is performed.
amplification : float, optional
Amplification factor for the uncertainty quantification of the sensitivity analysis.
Only applicable for MVE models and only used if random_samples > 0. The default is 1.
Returns
-------
sens_temp : torch.Tensor
Sensitivity tensor as Jacobian that contains the gradients of the output w.r.t. the inputs,
shape=[batch_size*forecast, pred_size, input_channels*window_size + pred_size*rnn_window]
sens_temp_mean : torch.Tensor, optional
Sensitivity tensor that contains the gradients of the mean output w.r.t. the inputs when using MVE models.
sens_temp_var : torch.Tensor, optional
Sensitivity tensor that contains the gradients of the aleatoric variance output w.r.t. the inputs when using MVE models.
"""
batch_size, pred_size, forecast = pred[0].shape if (pred_type == 'Mean_Var' or ensemble) else pred.shape
flatten_size = inputs.shape[1]
def jacobian(inputs, pred):
"""
Compute the Jacobian of the output w.r.t. the inputs for one timestep.
"""
# initialize Jacobian tensor with NaNs
jac = torch.full((batch_size*forecast, pred_size, flatten_size), float('nan'))
for k in range(forecast): # loop over forecasting horizon
for j in range(pred_size): # loop over output channels
# shape [batch_size, input_channels*window_size + pred_size*rnn_window]
grad_temp = grad(pred[:,j,k], inputs, grad_outputs=
torch.ones_like(pred[:,j,k]), create_graph=True)[0].detach()
if j == 0:
jac_temp = grad_temp.unsqueeze(1)
else:
jac_temp = torch.cat((jac_temp, grad_temp.unsqueeze(1)), dim=1)
jac[k::forecast] = jac_temp
return jac # shape [batch_size*forecast, pred_size, flatten_size]
# return jac * inputs.unsqueeze(1).detach() # gradient * input -> between pure gradient & IG!
# Compute and return the gradients
if pred_type == 'Point':
return jacobian(inputs, pred)
elif pred_type == 'Mean_Var' or ensemble:
mean_pred, var_pred = pred
sens_temp_mean = jacobian(inputs, mean_pred)
sens_temp_var = torch.zeros_like(sens_temp_mean)
if random_samples:
grads = torch.full((random_samples, batch_size*forecast, pred_size, flatten_size), float('nan'))
samples = torch.full((random_samples, pred_size), float('nan'))
for i in range(random_samples): # not direct sampling, but reparametrization trick: mean + eps*std, with eps ~ N(0,1)
eps = torch.randn(1, pred_size, 1) * amplification # used for later up-scaling
sampled_pred = mean_pred + eps * torch.sqrt(var_pred)
grads[i] = jacobian(inputs, sampled_pred)
samples[i] = eps[0,:,0].detach()
sens_temp_var = grads.mean(dim=0)
return sens_temp_mean, sens_temp_var, grads.mean(dim=1), samples
return sens_temp_mean, sens_temp_var
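# Note: the per-timestep Jacobian assembled above is conceptually equivalent to
# torch.autograd.functional.jacobian applied to the flattened forward pass.
# A purely illustrative sketch (assumes a point-prediction model `m` with
# forward_sens and a single flattened sample `x` of shape [flatten_size]):
#
#     from torch.autograd.functional import jacobian
#     jac = jacobian(lambda x_: m.forward_sens(x_.unsqueeze(0)).squeeze(0), x)
#     # jac.shape == (pred_size, forecast, flatten_size)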
def _comp_smooth_grad_sens(model, inputs, pred, pred_type, ensemble=False, num_samples=20, std_dev=0.2):
"""
Compute the SmoothGrad-based sensitivity of the output w.r.t. the inputs in each timestep.
Parameters
----------
model : Model consisting of nn.Modules
inputs : torch.Tensor
Input tensor with already concatenated external excitation and recurrent state signals,
shape=[batch_size, input_channels*window_size + pred_size*rnn_window]
pred : torch.Tensor
Output tensor that contains the predictions,
shape=[batch_size, pred_size, forecast]
pred_type : str
The model's prediction type out of ('Point', 'Mean_Var'), which defines
the number of outputs that the sensitivity analysis is performed on.
ensemble : bool, optional
If True, sensitivity is computed for an ensemble of models. The default is False.
num_samples : int, optional
Number of noisy samples to generate. The default is 20.
std_dev : float, optional
Standard deviation used for sampling the noisy variations of the input. The default is 0.2.
Returns
-------
sens_temp : torch.Tensor
Sensitivity tensor as Jacobian that contains the gradients of the output w.r.t. the inputs,
shape=[batch_size*forecast, pred_size, flatten_size]
sens_temp_mean : torch.Tensor, optional
Sensitivity tensor that contains the gradients of the mean output w.r.t. the inputs when using MVE models.
sens_temp_var : torch.Tensor, optional
Sensitivity tensor that contains the gradients of the aleatoric variance output w.r.t.
the inputs when using MVE models.
Raises
------
AssertionError
If the number of samples is less than 1 or the standard deviation is less than 0.
"""
assert num_samples > 0, 'Number of samples must be greater than 0 for SmoothGrad!'
assert std_dev > 0, 'Standard deviation for Gaussian noise must be greater than 0 for SmoothGrad!'
batch_size, pred_size, forecast = pred[0].shape if (pred_type == 'Mean_Var' or ensemble) else pred.shape
flatten_size = inputs.shape[1]
def smooth_grad():
"""
Compute the SmoothGrad sensitivity of the output w.r.t. the inputs for one timestep.
"""
if pred_type == 'Point':
gradients = torch.full((num_samples, batch_size*forecast, pred_size, flatten_size), float('nan'))
else:
gradients_mean = torch.full((num_samples, batch_size*forecast, pred_size, flatten_size), float('nan'))
gradients_var = torch.full((num_samples, batch_size*forecast, pred_size, flatten_size), float('nan'))
for i in range(num_samples):
noise = torch.randn_like(inputs, requires_grad=True) * std_dev
inputs_noisy = inputs + noise
pred_noisy = model.forward_sens(inputs_noisy)
if i < num_samples-1:
grads = _comp_grad_sens(inputs_noisy, pred_noisy, pred_type, ensemble)
else: # compute gradients for the original input
grads = _comp_grad_sens(inputs, pred, pred_type, ensemble)
if pred_type == 'Point':
gradients[i] = grads
else:
gradients_mean[i], gradients_var[i] = grads
if pred_type == 'Point':
return gradients.mean(dim=0)
elif pred_type == 'Mean_Var' or ensemble:
return gradients_mean.mean(dim=0), gradients_var.mean(dim=0)
return smooth_grad()
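# SmoothGrad reference (Smilkov et al. 2017, https://arxiv.org/abs/1706.03825):
# the sensitivity is the gradient averaged over noisy copies of the input,
#     sens(x) ~ 1/n * sum_k d f(x + eps_k) / d x,   eps_k ~ N(0, std_dev^2 * I),
# where the implementation above takes the last of the n samples at the
# unperturbed input itself.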
def _comp_integrated_grad_sens(model, inputs, pred, pred_type, ensemble=False, num_steps=10):
"""
Compute the integrated gradient-based sensitivity of the output w.r.t. the inputs in each timestep.
Parameters
----------
model : Model consisting of nn.Modules
inputs : torch.Tensor
Input tensor with already concatenated external excitation and recurrent state signals,
shape=[batch_size, input_channels*window_size + pred_size*rnn_window]
pred : torch.Tensor
Output tensor that contains the predictions,
shape=[batch_size, pred_size, forecast]
pred_type : str
The model's prediction type out of ('Point', 'Mean_Var'), which defines
the number of outputs that the sensitivity analysis is performed on.
ensemble : bool, optional
If True, sensitivity is computed for an ensemble of models. The default is False.
num_steps : int, optional
Number of steps along the linearly interpolated path from the baseline to the input.
The default is 10 steps.
Returns
-------
sens_temp : torch.Tensor
Sensitivity tensor as Jacobian that contains the gradients of the output w.r.t. the inputs,
shape=[batch_size*forecast, pred_size, input_channels*window_size + pred_size*rnn_window]
sens_temp_mean : torch.Tensor, optional
Sensitivity tensor that contains the gradients of the mean output w.r.t. the inputs when using MVE models.
sens_temp_var : torch.Tensor, optional
Sensitivity tensor that contains the gradients of the aleatoric variance output w.r.t. the inputs when using MVE models.
Raises
------
AssertionError
If the number of steps is less than 2.
"""
assert num_steps > 1, 'Number of integration steps must be greater than 1 for IG!'
batch_size, pred_size, forecast = pred[0].shape if (pred_type == 'Mean_Var' or ensemble) else pred.shape
flatten_size = inputs.shape[1]
alphas = torch.linspace(0, 1, num_steps+1)[1:-1] # exclude endpoints: the zero baseline contributes nothing, the input gradient is added below
baseline_inp = torch.zeros_like(inputs) # zero baseline
# baseline_inp = torch.mean(inputs, dim=1, keepdim=True).repeat(1, flatten_size) # mean baseline as alternative option
difference_inp = (inputs - baseline_inp).repeat_interleave(forecast, dim=0).unsqueeze(1).detach()
def integrated_grads(alphas, baseline_inp, difference_inp):
"""
Compute the integrated gradients of the output w.r.t. the inputs for one timestep.
"""
gradients_mean = torch.full((num_steps, batch_size*forecast, pred_size, flatten_size), float('nan'))
if pred_type == 'Mean_Var' or ensemble:
gradients_var = torch.full((num_steps, batch_size*forecast, pred_size, flatten_size), float('nan'))
for i, alpha in enumerate(alphas):
inputs_interpol = baseline_inp + alpha * (inputs - baseline_inp) # interpolate between baseline and input
pred_interpol = model.forward_sens(inputs_interpol)
grads = _comp_grad_sens(inputs_interpol, pred_interpol, pred_type, ensemble)
gradients_mean[i] = grads[0] if pred_type == 'Mean_Var' or ensemble else grads
if pred_type == 'Mean_Var' or ensemble:
gradients_var[i] = grads[1]
grads = _comp_grad_sens(inputs, pred, pred_type, ensemble)
gradients_mean[-1] = grads[0] if pred_type == 'Mean_Var' or ensemble else grads
ig_mean = (gradients_mean[:-1] + gradients_mean[1:]) / 2.0 # trapezoidal integration rule
ig_mean = ig_mean.mean(dim=0) * difference_inp
# print(gradients_mean[:,0,0,50]) # check how gradients evolve along the path
if pred_type == 'Mean_Var' or ensemble:
gradients_var[-1] = grads[1]
ig_var = (gradients_var[:-1] + gradients_var[1:]) / 2.0
ig_var = ig_var.mean(dim=0) * difference_inp
return ig_mean, ig_var
return ig_mean # shape [batch_size*forecast, pred_size, flatten_size]
return integrated_grads(alphas, baseline_inp, difference_inp)
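# Integrated Gradients reference (Sundararajan et al. 2017, https://arxiv.org/abs/1703.01365):
#     IG_i(x) = (x_i - x'_i) * integral_0^1 d f(x' + alpha * (x - x')) / d x_i  d alpha,
# approximated above with the trapezoidal rule over num_steps points on the
# straight path from the zero baseline x' to the input x.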
def _reshape_array(model, array, aggregation=None, remove_nans=False, repeat=False, repeat_size=None):
"""
Reshape a post-processed array of the sensitivity tensor for
further analysis, while keeping the information of possibly
different window sizes between input and recurrent signals
Parameters
----------
model : Model consisting of nn.Modules
array : np.ndarray
Post-processed array of the sensitivity tensor
aggregation : str, optional
Specifies the aggregation method for the reshaped array, performed
on its last axis. Choose from: mean, median, sum, rms.
The default is None, i.e. only reshaping and hstacking is performed.
remove_nans : bool, optional
If True, NaN values are removed from the reshaped array. The default is False.
repeat : bool, optional
If True, the array is repeated after the aggregation (only if 1D!)
for each input and recurrent channel. The default is False.
repeat_size : int, optional
Size of the repetition, which is the same for every channel. The default is None.
Returns
-------
np.ndarray
Reshaped and hstacked array of the sensitivity tensor.
Raises
------
ValueError
If an invalid aggregation method is given.
"""
m_type = model.Type
win_size = max(model.window_size, model.rnn_window) if m_type in ['AR', 'AR_RNN'] else model.window_size
ch_size = model.input_channels + model.pred_size if m_type in ['AR', 'AR_RNN'] else model.input_channels
rec_start_idx = model.input_channels*model.window_size
# Reshape the array such that the input and recurrent signals with their corresponding window sizes are separated
if array.ndim == 1:
temp1 = array[:rec_start_idx].reshape(model.input_channels, model.window_size)
if m_type in ['AR', 'AR_RNN']:
temp2 = array[rec_start_idx:].reshape(model.pred_size, model.rnn_window)
elif array.ndim == 2:
temp1 = array[:, :rec_start_idx].reshape(array.shape[0], model.input_channels, model.window_size)
if m_type in ['AR', 'AR_RNN']:
temp2 = array[:, rec_start_idx:].reshape(array.shape[0], model.pred_size, model.rnn_window)
elif array.ndim == 3:
temp1 = array[..., :rec_start_idx].reshape(*array.shape[:2], model.input_channels, model.window_size)
if m_type in ['AR', 'AR_RNN']:
temp2 = array[..., rec_start_idx:].reshape(*array.shape[:2], model.pred_size, model.rnn_window)
# Apply aggregation method if specified, otherwise only stack arrays together
if aggregation is None:
if temp1.ndim == 2:
temp = np.full((ch_size, win_size+1), np.nan)
temp[:model.input_channels, -model.window_size:] = temp1 # right-align all input channels
if m_type in ['AR', 'AR_RNN']:
temp[model.input_channels:, -model.rnn_window:] = temp2 # right-align all recurrent channels
elif temp1.ndim == 3:
temp = np.full((array.shape[0], ch_size, win_size+1), np.nan)
temp[:, :model.input_channels, -model.window_size:] = temp1
if m_type in ['AR', 'AR_RNN']:
temp[:, model.input_channels:, -model.rnn_window:] = temp2
if remove_nans:
temp = [x[~np.isnan(x)] for x in temp]
if temp1.ndim == 3:
inp_ch_lst = [i*model.window_size for i in range(1,model.input_channels)]
if m_type in ['AR', 'AR_RNN']:
rec_ch_lst = [rec_start_idx + i*model.rnn_window for i in range(model.pred_size)]
temp = [np.split(x, inp_ch_lst + rec_ch_lst) for x in temp]
else:
temp = [np.split(x, inp_ch_lst) for x in temp]
return temp
else:
if aggregation == 'mean':
agg1 = np.mean(temp1, axis=-1)
if m_type in ['AR', 'AR_RNN']:
agg2 = np.mean(temp2, axis=-1)
elif aggregation == 'median':
agg1 = np.median(temp1, axis=-1)
if m_type in ['AR', 'AR_RNN']:
agg2 = np.median(temp2, axis=-1)
elif aggregation == 'sum':
agg1 = np.sum(temp1, axis=-1)
if m_type in ['AR', 'AR_RNN']:
agg2 = np.sum(temp2, axis=-1)
elif aggregation == 'rms':
agg1 = np.sqrt(np.mean(np.square(temp1), axis=-1))
if m_type in ['AR', 'AR_RNN']:
agg2 = np.sqrt(np.mean(np.square(temp2), axis=-1))
else:
raise ValueError(f'Invalid aggregation method "{aggregation}" given! Choose from: mean, median, sum, rms.')
if agg1.ndim == 1 and repeat:
if repeat_size is None:
agg1 = agg1.repeat(model.window_size)
if m_type in ['AR', 'AR_RNN']:
agg2 = agg2.repeat(model.rnn_window)
else:
agg1 = agg1.repeat(repeat_size)
if m_type in ['AR', 'AR_RNN']:
agg2 = agg2.repeat(repeat_size)
if m_type in ['AR', 'AR_RNN']:
return np.concatenate((np.append(np.nan, agg1), agg2), axis=-1)
else:
return np.append(np.nan, agg1)
if m_type in ['AR', 'AR_RNN']:
return np.concatenate((agg1, agg2), axis=-1)
else:
return agg1
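# Illustrative sketch (assumes an AR model `m` with input_channels=2, window_size=10,
# pred_size=1, rnn_window=10, i.e. a flattened size of 30):
#
#     arr = np.random.randn(2*10 + 1*10)
#     per_channel = _reshape_array(m, arr, aggregation='rms')
#     # per_channel.shape == (3,): RMS per input channel, then per recurrent channel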
def _postprocess_sens(model, sensitivity):
"""
Postprocess the sensitivity tensor to get information about mean and std of the gradients,
aggregated over the timesteps, output channels and window sizes.
Parameters
----------
model : Model consisting of nn.Modules
sensitivity : torch.Tensor
Sensitivity tensor that contains the gradients of the output with respect to the inputs,
with shape=[len(data_loader), pred_size, input_channels*window_size + pred_size*rnn_window]
Returns
-------
sum_mean_feature : np.ndarray
RMS of the time-averaged sensitivities over all output channels for each input feature
sum_std_feature : np.ndarray
Sum of the std of the sensitivities over all output channels for each input feature
sum_inp_channels : np.ndarray
RMS of the time-avg. sensitivities for each input-output channel combination
std_inp_channels : np.ndarray
Sum of the std of the sensitivities for each input-output channel combination
rms_out_ch_sens : np.ndarray
RMS of the sensitivities over all output channels for each timestep and input feature
mean_out_ch_sens : np.ndarray
Mean of the sensitivities over all output channels for each timestep and input feature
"""
sensitivity = sensitivity.numpy()
# Compute the RMS across/over all output channels for each timestep
rms_out_ch_sens = np.sqrt(np.mean(np.square(sensitivity), axis=1)) # shape [len(dataloader), input_channels*window_size + pred_size*rnn_window]
mean_out_ch_sens = np.mean(sensitivity, axis=1)
# Compute the mean and var of the sensitivity tensor along the time axis
temp_mean_sens = np.mean(sensitivity, axis=0) # shape [pred_size, input_channels*window_size + pred_size*rnn_window]
temp_std_sens = np.std(sensitivity, axis=0)
# Compute mean-squared sensitivity (RMS) and mean-squared std.-dev. across all output channels
sum_mean_feature = np.sqrt(np.mean(temp_mean_sens**2, axis=0)) # shape [input_channels*window_size + pred_size*rnn_window]
sum_std_feature = np.sqrt(np.mean(temp_std_sens**2 + (temp_mean_sens - temp_mean_sens.mean(axis=0, keepdims=True))**2, axis=0))
# Compute the RMS across entire window size for each input/recurrent-output channel combination
sum_inp_channels = _reshape_array(model, temp_mean_sens, aggregation='rms') # shape [pred_size, input_channels+pred_size]
std_inp_channels = _reshape_array(model, temp_std_sens, aggregation='rms')
# Delete sensitivity tensor for less memory usage
del sensitivity
return (sum_mean_feature, sum_std_feature), (sum_inp_channels, std_inp_channels), (rms_out_ch_sens, mean_out_ch_sens)
def _pred_ARNN_batch(model, batch_sw, device='cpu', sens_params=None):
"""
Predict function for forward ARNN models with batched dataset for faster
computation.
Parameters
----------
model : Model consisting of nn.Modules
batch_sw : Batched Sliding Window
Dataset to compute output for
device : str, optional
device to compute on. The default is 'cpu'.
sens_params : dict, optional
Dictionary containing the parameters for sensitivity analysis.
Key 'method' defines the method for sensitivity analysis: 'gradient' or 'perturbation'.
Key 'comp' defines whether gradients are computed for sensitivity analysis.
Key 'plot' defines whether the results of the sensitivity analysis are visualized.
Key 'sens_length' defines the number of randomly sampled subset of timesteps for the analysis.
(If not a multiple of the model's forecast, the number will be rounded up to the next multiple.)
The default is None, i.e. no sensitivity analysis is performed.
Returns
-------
if comp_sens is False:
torch.Tensor : Tensor of same length as input, containing the predictions.
if comp_sens is True:
(torch.Tensor, dict) : Tuple of Tensor of same length as input and sensitivity dict.
Key is the prediction type, value is the sensitivity tensor.
if comp_sens is True and random_samples > 0:
(torch.Tensor, dict, torch.Tensor, torch.Tensor) : Tuple of Tensor of same length as input,
sensitivity dict, uncertainty quantification tensor and random samples tensor.
"""
size = max(model.window_size, model.rnn_window)
forecast = model.forecast
rnn_size = model.rnn_window
w = batch_sw.__widths__()[0]
pred_warmup = torch.zeros((w, model.pred_size, size))
prediction = torch.full((w, model.pred_size, len(batch_sw)*forecast), float('nan'))
prediction = torch.cat((pred_warmup, prediction), dim=2)
if model.Pred_Type == 'Mean_Var':
var_prediction = torch.zeros(prediction.shape)
original_out = torch.full((w, model.pred_size, len(batch_sw)*forecast), float('nan'))
# Tensors to device
prediction = prediction.to(device)
original_out = original_out.to(device)
model.to(device)
checks = _check_sens_params_pred(sens_params)
if sens_params:
method, comp_sens, verbose, sens_length, num_samples, std_dev, correlated, random_samples, amplification = checks
else:
comp_sens, verbose = checks # False
# Initialise 3D tensor for sensitivity analysis
if comp_sens:
flatten_size = model.window_size*model.input_channels + model.rnn_window*model.pred_size
num_timesteps = len(batch_sw)*forecast
sens_indices = np.arange(len(batch_sw))
sens_uq, eps = None, None
if sens_length:
num_timesteps, sens_indices = _random_subset_sens_indices(sens_length, forecast,
model.Type, batch_sw, batched=True)
sens_mean = torch.full((w, num_timesteps, model.pred_size, flatten_size), float('nan'))
if model.Pred_Type == "Mean_Var":
sens_var = sens_mean.clone()
if random_samples:
sens_uq = torch.full((num_timesteps, random_samples, model.pred_size, flatten_size), float('nan'))
eps = torch.full((num_timesteps, random_samples, model.pred_size), float('nan'))
if verbose:
print(f'Start {method.upper()}-based Sensitivity Analysis...\n')
print(f'Shape of sensitivity tensor: {sens_mean.shape}')
# Iterate over dataloader
idx = 0
for i in tqdm(range(len(batch_sw))) if verbose else range(len(batch_sw)):
offset = i*forecast + size
inputs, output = batch_sw[i]
inputs, output = inputs[0].to(device), output.to(device)
original_out[:batch_sw.valid_sws[i], :,
i*forecast:(i+1)*forecast] = output
x_rec = prediction[:batch_sw.valid_sws[i], :,
(offset - rnn_size):offset]
# Prepare input for model to allow autograd computing gradients of outputs w.r.t. inputs
inputs = Variable(torch.flatten(inputs, start_dim=1), requires_grad=True)
x_rec = Variable(torch.flatten(x_rec, start_dim=1), requires_grad=True)
inp = torch.cat((inputs, x_rec), dim=1)
pred = model.forward_sens(inp) if comp_sens else model(inputs, x_rec)
if comp_sens and i in sens_indices:
sens_temp = _comp_sensitivity(method, model, inp, pred, num_samples, std_dev, correlated, random_samples, amplification)
if model.Pred_Type == "Mean_Var":
sens_mean[:batch_sw.valid_sws[i], idx:idx+forecast, :, :] = sens_temp[0].reshape(batch_sw.valid_sws[i], forecast, model.pred_size, flatten_size)
sens_var[:batch_sw.valid_sws[i], idx:idx+forecast, :, :] = sens_temp[1].reshape(batch_sw.valid_sws[i], forecast, model.pred_size, flatten_size)
if random_samples:
sens_uq[idx], eps[idx] = sens_temp[2], sens_temp[3]
else:
sens_mean[:batch_sw.valid_sws[i], idx:idx+forecast, :, :] = sens_temp.reshape(batch_sw.valid_sws[i], forecast, model.pred_size, flatten_size)
idx += forecast
if model.Pred_Type == 'Mean_Var':
prediction[:batch_sw.valid_sws[i], :,
i*forecast+size:(i+1)*forecast+size] = pred[0].detach()
var_prediction[:batch_sw.valid_sws[i], :,
i*forecast+size:(i+1)*forecast+size] = pred[1].detach()
else:
prediction[:batch_sw.valid_sws[i], :,
i*forecast+size:(i+1)*forecast+size] = pred.detach()
# cut zeros from initialisation
prediction = prediction[:, :, size:]
if model.Pred_Type == 'Mean_Var':
var_prediction = var_prediction[:, :, size:]
prediction = (prediction.cpu(), var_prediction.cpu())
if comp_sens:
if model.Pred_Type == 'Mean_Var':
sensitivities = (sens_mean.cpu(), sens_var.cpu())
else:
sensitivities = sens_mean.cpu()
if verbose:
print(f'\n{method.upper()}-based Sensitivity Analysis completed!\n')
if random_samples: # cut all lines that contain NaNs and flatten first two dimensions
sens_uq, eps = sens_uq.cpu(), eps.cpu()
sens_uq = sens_uq[~torch.isnan(sens_uq).any(dim=(1,2,3))].flatten(start_dim=0, end_dim=1)
eps = eps[~torch.isnan(eps).any(dim=(1,2))].flatten(start_dim=0, end_dim=1)
return prediction, sensitivities, sens_uq, eps
# return prediction, sensitivities
else:
return prediction
def _predict_arnn_uncertainty(model, dataloader, device='cpu', sens_params=None):
"""
Predict function for ARNN models that support uncertainty estimation
Parameters
----------
model : Model consisting of nn.Modules
dataloader : Dataloader
Dataloader to predict output
device : str, optional
device to compute on. The default is 'cpu'.
sens_params : dict, optional
Dictionary that contains the parameters for the sensitivity analysis.
Key 'method' defines the method for sensitivity analysis: 'gradient' or 'perturbation'.
Key 'comp' defines whether gradients are computed for sensitivity analysis.
Key 'plot' defines whether the results of the sensitivity analysis are visualized.
Key 'sens_length' defines the number of randomly sampled subset of timesteps for the analysis.
The default is None, i.e. no sensitivity analysis is computed.
Returns
-------
if comp_sens is False:
torch.Tensor : Tensor of same length as input, containing the predictions.
if comp_sens is True:
(torch.Tensor, dict) : Tuple of Tensor of same length as input and sensitivity dict.
Key is the prediction type, value is the sensitivity tensor.
"""
return _prediction(model, dataloader, device, sens_params)
def _predict_arnn_uncertainty_both(model, dataloader, device='cpu'):
"""
Predict function for ARNN models that support uncertainty estimation
and capture heteroscedastic and aleatoric uncertainty
Parameters
----------
model : Model consisting of nn.Modules
dataloader : Dataloader
Dataloader to predict output
device : str, optional
device to compute on. The default is 'cpu'.
Returns
-------
torch.Tensor : Tensor of same length as input, containing the predictions.
"""
return _prediction(model, dataloader, device)
def _async_prediction(model, dataloader, device='cpu', n_samples=1, reduce=True, ensemble_weights=None, sens_params=None):
"""
Prediction of a whole Time Series with a model wrapper.
In case of an MVE ensemble, the predictions are combined following
https://arxiv.org/pdf/1612.01474.pdf
Parameters
----------
dataloader : Dataloader
Dataloader to predict output
device : str, optional
device to compute on. The default is 'cpu'.
n_samples : int
Numbers of samples to take for Monte Carlo estimation. The default is 1.
reduce: bool, optional
Whether the combined uncertainty (True) or both uncertainties should be returned. The default is True.
ensemble_weights: list[dict]
List of torch state dicts containing weights. The default is None
sens_params : dict, optional
Dictionary that contains the parameters for the sensitivity analysis.
Key 'method' defines the method for sensitivity analysis: 'gradient' or 'perturbation'.
Key 'comp' defines whether gradients are computed for sensitivity analysis.
Key 'plot' defines whether the results of the sensitivity analysis are visualized.
Key 'verbose' defines whether the information about the sensitivity analysis is printed.
Key 'sens_length' defines the number of randomly sampled subset of timesteps for the analysis.
The default is None, i.e. no sensitivity analysis is computed.
Returns
-------
if comp_sens is False:
torch.Tensor : Tensor of same length as input, containing the predictions.
if comp_sens is True:
(torch.Tensor, dict) : Tuple of Tensor of same length as input and sensitivity dict.
Key is the prediction type, value is the sensitivity tensor.
"""
model.eval()
prediction_type = model.Pred_Type
if n_samples > 1:
# Enable dropout Layers during test time
for m in model.modules():
if m.__class__.__name__ == 'Dropout':
m.train()
checks = _check_sens_params_pred(sens_params)
comp_sens = checks[1] if sens_params else False # True if sens_params are given, False otherwise
verbose = checks[2] if sens_params else False
if ensemble_weights is None:
if comp_sens:
predictions = [model.prediction(dataloader, device, sens_params=sens_params) for _ in range(n_samples)]
predictions, sens_dicts = zip(*predictions)
predictions, sens_dicts = list(predictions), list(sens_dicts)
else:
predictions = [model.prediction(dataloader, device) for _ in range(n_samples)]
else:
predictions = []
sens_dicts = []
for w in ensemble_weights:
for _ in range(n_samples):
model.load_state_dict(w)
if comp_sens:
pred, sensitivity = model.prediction(dataloader, device, sens_params=sens_params)
predictions.append(pred)
sens_dicts.append(sensitivity)
else:
predictions.append(model.prediction(dataloader, device))
if verbose:
print(f'Shape of single prediction: {predictions[0].shape}')
predictions = torch.stack(predictions)
if verbose:
print(f'Shape of ensemble predictions: {predictions.shape}\n')
if comp_sens:
mean_str = 'Mean' if prediction_type == "Mean_Var" else 'Point'
means_senses = torch.stack([sens_dict[mean_str] for sens_dict in sens_dicts])
mean_sens = torch.mean(means_senses, dim=0)
if prediction_type == "Mean_Var":
var_sens = torch.stack([sens_dict['Aleatoric_UQ'] for sens_dict in sens_dicts]).mean(dim=0)
else:
var_sens = torch.var(means_senses, dim=0)
sens_avg_dict = {'Mean': mean_sens, 'Var_UQ': var_sens}
# epistemic_var = torch.var(predictions[0], dim=0) if ensemble_weights is not None or n_samples > 1 else torch.zeros(mean.shape)
if prediction_type == "Mean_Var":
mean = torch.mean(predictions[:, 0, :, :], dim=0)
total_var = torch.mean(predictions[:, 1, :, :] + predictions[:, 0, :, :].square(), dim=0) - mean.square()
aleatoric_var = torch.mean(predictions[:, 1, :, :], dim=0)
epistemic_var = total_var - aleatoric_var
predictions = (mean, total_var) if reduce else (mean, epistemic_var, aleatoric_var)
else:
mean = torch.mean(predictions, dim=0)
epistemic_var = torch.var(predictions, dim=0) if ensemble_weights is not None or n_samples > 1 else torch.zeros(mean.shape)
predictions = (mean, epistemic_var)
if comp_sens:
return predictions, sens_avg_dict
else:
return predictions
def _prediction(model, dataloader, device='cpu', sens_params=None):
"""
Prediction of a whole Time Series
Parameters
----------
dataloader : Dataloader
Dataloader to predict output
device : str, optional
device to compute on. The default is 'cpu'.
sens_params : dict, optional
Dictionary that contains the parameters for the sensitivity analysis.
Key 'method' defines the method for sensitivity analysis: 'gradient' or 'perturbation'.
Key 'comp' defines whether gradients are computed for sensitivity analysis.
Key 'plot' defines whether the results of the sensitivity analysis are visualized.
Key 'verbose' defines whether the information about the sensitivity analysis is printed.
Key 'sens_length' defines the length of the randomly sampled subset of timesteps used for the analysis.
(If not a multiple of the model's forecast, the number will be rounded up to the next multiple.)
The default is None, i.e. no sensitivity analysis is computed.
Returns
-------
if comp_sens is False:
torch.Tensor : Tensor of same length as input, containing the predictions.
if comp_sens is True:
(torch.Tensor, dict) : Tuple of Tensor of same length as input and sensitivity dict.
Key is the prediction type, value is the sensitivity tensor.
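Examples
--------
A hypothetical call; ``model`` and ``dataloader`` are assumed to exist, hence the doctest is skipped:

>>> pred, sens_dict = _prediction(model, dataloader, device='cpu',
...                               sens_params={'method': 'perturbation', 'comp': True})  # doctest: +SKIP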
"""
fc = model.forecast
prediction_type = model.Pred_Type
if prediction_type == "Quantile":
num_outputs = model.n_quantiles
else:
num_outputs = {
"Point": 1,
"Mean_Var": 2,
}
num_outputs = num_outputs[prediction_type]
if model.Ensemble:
num_outputs = 2
# The Model starts with zeros as recurrent System state
size = max(model.window_size, model.rnn_window)
prediction = torch.zeros((num_outputs, model.pred_size, size))
original_out = torch.zeros((model.pred_size, size))
x_rec = torch.zeros(1, model.pred_size, model.rnn_window)
# Tensors to device
x_rec = x_rec.to(device)
prediction = prediction.to(device)
original_out = original_out.to(device)
model.to(device)
checks = _check_sens_params_pred(sens_params)
if sens_params:
method, comp_sens, verbose, sens_length, num_samples, std_dev, correlated, random_samples, amplification = checks
else:
comp_sens, verbose = checks # both False when no sens_params are given
# Initialise 3D tensor for sensitivity analysis
if comp_sens:
flatten_size = model.window_size*model.input_channels + model.rnn_window*model.pred_size
num_timesteps = len(dataloader)*fc
loader_length = num_timesteps - dataloader.dataset.subclass.add_zero
sens_indices = np.arange(len(dataloader))
if sens_length:
num_timesteps, sens_indices = _random_subset_sens_indices(sens_length, fc, model.Type, dataloader)
sens_mean = torch.full((num_timesteps, model.pred_size, flatten_size), float('nan'))
if prediction_type == "Mean_Var" or model.Ensemble:
sens_var = sens_mean.clone()
if random_samples:
sens_uq = torch.full((num_timesteps, random_samples, model.pred_size, flatten_size), float('nan'))
eps = torch.full((num_timesteps, random_samples, model.pred_size), float('nan'))
if verbose:
print(f'Start {method.upper()}-based Sensitivity Analysis...\n')
print(f'Shape of sensitivity tensor: {sens_mean.shape}')
# Iterate over dataloader
idx = 0
if verbose:
from tqdm import tqdm  # lazy import; the module-level tqdm import appears to be commented out
for i, data in enumerate(tqdm(dataloader) if verbose else dataloader):
inputs, output = data
inputs, output = inputs[0].to(device), output.to(device)
# Prepare input for model to allow autograd computing gradients of outputs w.r.t. inputs
inputs = Variable(torch.flatten(inputs, start_dim=1), requires_grad=True)
x_rec = Variable(torch.flatten(x_rec, start_dim=1), requires_grad=True)
inp = torch.cat((inputs, x_rec), dim=1)
if comp_sens:
pred = model.forward_sens(inp)
elif prediction_type == 'Mean_Var' or model.Ensemble:
pred = model.estimate_uncertainty_mean_std(inputs, x_rec)
else:
pred = model(inputs, x_rec)
if comp_sens and i in sens_indices:
sens_temp = _comp_sensitivity(method, model, inp, pred, num_samples, std_dev, correlated, random_samples, amplification)
if prediction_type == "Mean_Var" or model.Ensemble:
sens_mean[idx:idx+fc], sens_var[idx:idx+fc] = sens_temp[0], sens_temp[1]
if random_samples:
sens_uq[idx], eps[idx] = sens_temp[2], sens_temp[3]
else:
sens_mean[idx:idx+fc] = sens_temp
idx += fc
if isinstance(pred, tuple):
pred = torch.vstack(pred)
if prediction_type == "Quantile":
# QR uses a different output shape [1,pred_size,forecast,num_outputs] than other models [num_outputs,pred_size,forecast]
# We can adjust for this by squeezing the first and swapping the remaining axes
pred = pred.squeeze(0)
pred = torch.transpose(pred, 0, 2)
pred = torch.transpose(pred, 1, 2)
prediction = torch.cat((prediction,
pred.detach().reshape(num_outputs, model.pred_size, -1)), -1)
x_rec = torch.unsqueeze(prediction[0, :, -model.rnn_window:], dim=0) # only feed back the mean values!
original_out = torch.cat((original_out,
output.reshape(model.pred_size, -1)), 1)
# cut zeros from initialisation
# shape = [num_outputs, pred_size, len(dataset)] = [num_outputs, pred_size, forecast*len(dataloader)]
prediction = prediction[..., size:]
original_out = original_out[:, size:]
cut_zeros = dataloader.dataset.subclass.add_zero
if cut_zeros != 0:
prediction = prediction[..., :-cut_zeros]
original_out = original_out[:, :-cut_zeros]
prediction = prediction.cpu()  # cpu() is not in-place, so reassign
model.to('cpu')
if comp_sens:
# cut to the same length as the prediction
sens_mean = sens_mean[:loader_length]
if prediction_type == "Mean_Var" or model.Ensemble:
sens_var = sens_var[:loader_length]
sensitivity_dict = {'Mean': sens_mean.cpu(), 'Aleatoric_UQ': sens_var.cpu()}
else:
sensitivity_dict = {'Point': sens_mean.cpu()}
if verbose:
print(f'{method.upper()}-based Sensitivity Analysis completed!\n')
if random_samples: # cut all lines that contain NaNs and flatten first two dimensions
if verbose:
mean, std = prediction[0].mean(dim=-1), torch.sqrt(prediction[1].mean(dim=-1))
print(f'Mean of predictions: {torch.round(mean, decimals=3)}, Std of predictions: {torch.round(std, decimals=4)}')
sens_uq, eps = sens_uq.cpu(), eps.cpu()
sens_uq = sens_uq[~torch.isnan(sens_uq).any(dim=(1,2,3))].flatten(start_dim=0, end_dim=1)
eps = eps[~torch.isnan(eps).any(dim=(1,2))].flatten(start_dim=0, end_dim=1)
return prediction, sensitivity_dict, sens_uq, eps
return prediction, sensitivity_dict
else:
return prediction
def _comp_perturb_sens(model, inputs, pred, perturb_size=10, std_dev=0.2, correlated=True, random_samples=0, amplification=1):
"""
Compute the perturbation-based sensitivity of the output w.r.t. the inputs, also known as
Permutation Feature Importance (PFI). The method is based on the paper by Altmann et al. (2010).
Parameters
----------
model : Model consisting of nn.Modules
inputs : torch.Tensor
Input tensor with already concatenated external excitation and recurrent state signals,
shape=[batch_size, input_channels*window_size + pred_size*rnn_window]
pred : torch.Tensor
Output tensor that contains the predictions,
shape=[batch_size, pred_size, forecast]
perturb_size : int, optional
Number of perturbations per input feature. The default is 10.
std_dev : float, optional
Standard deviation of the Gaussian noise used as perturbation.
The default is 0.2.
correlated : bool, optional
If True, the relative perturbations are decayed in a local area around the feature
currently being perturbed, based on the strongest region of the signal's auto-correlation.
The default is True, i.e. the perturbations are decayed.
random_samples : int, optional
Number of random samples to take for "neural" Monte Carlo uncertainty estimation for MVE models.
The default is 0, i.e. no random samples are taken.
amplification : float, optional
Factor to amplify the sampled noise when re-sampling predictions for MVE models.
The default is 1, i.e. no amplification.
Returns
-------
sens_temp : torch.Tensor
Sensitivity tensor containing the mean perturbation-induced output differences per input feature,
shape=[batch_size*forecast, pred_size, input_channels*window_size + pred_size*rnn_window]
Raises
------
AssertionError
If perturb_size or std_dev is not greater than 0.
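Examples
--------
A hypothetical call from within the prediction loop; ``model``, ``inp`` and ``pred`` are assumed
to exist, hence the doctest is skipped:

>>> sens = _comp_perturb_sens(model, inp, pred, perturb_size=10, std_dev=0.2)  # doctest: +SKIP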
"""
assert perturb_size > 0, 'Permutation size must be greater than 0 for PFI!'
assert std_dev > 0, 'Standard deviation for Gaussian noise must be greater than 0 for PFI!'
if not random_samples:
torch.manual_seed(42)
steps = torch.normal(mean=0, std=std_dev, size=(perturb_size, 1))
while (steps > 0).sum() != perturb_size//2: # ensure that positive and negative perturbations are equal
steps = torch.normal(mean=0, std=std_dev, size=(perturb_size, 1))
batch_size, pred_size, forecast = pred[0].shape if (model.Pred_Type == 'Mean_Var' or model.Ensemble) else pred.shape
flatten_size = inputs.shape[1]
input_channels, window_size = model.input_channels, model.window_size
rec_start_idx = input_channels * window_size
rnn_window = model.rnn_window if model.Type in ['AR', 'AR_RNN'] else 0
ch_size = (input_channels + pred_size) if model.Type in ['AR', 'AR_RNN'] else input_channels
def auto_correlation(signal):
"""Compute the auto-correlation of a channel signal"""
signal = signal.detach().numpy()
signal = signal - np.mean(signal) # zero-mean to avoid bias and make intervals more comparable
auto_cov = np.correlate(signal, signal, mode='full')
return auto_cov / (np.max(auto_cov) + 1e-6)
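# np.correlate(x, x, mode='full') returns 2*len(x)-1 lags with lag 0 at the centre index len(x)-1,
# which is where the (normalised) auto-correlation attains its maximum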
def get_autocorrelations(inputs):
"""Get the auto-correlations of input and recurrent signals."""
input_ = inputs[:,:rec_start_idx].reshape(input_channels, window_size)
if model.Type in ['AR', 'AR_RNN']:
x_rec = inputs[:,rec_start_idx:].reshape(pred_size, rnn_window)
auto_corrs = []
for ch_idx in range(ch_size):
if ch_idx < input_channels:
auto_cor = auto_correlation(input_[ch_idx])
else:
auto_cor = auto_correlation(x_rec[ch_idx-input_channels])
auto_corrs.append(torch.tensor(auto_cor))
return auto_corrs
def get_decay(auto_corrs, idx):
"""
Get the decay values from the auto-correlations, including the indices from
the slicing operation at the current position within the sliding window.
"""
if model.Type in ['AR', 'AR_RNN']:
ch = idx//window_size if idx < rec_start_idx else (idx-rec_start_idx)//rnn_window + input_channels
pos = idx % window_size if idx < rec_start_idx else (idx-rec_start_idx) % rnn_window
else:
ch, pos = idx // window_size, idx % window_size
win_size = window_size if ch < input_channels else rnn_window
if correlated:
auto_cor = auto_corrs[ch][win_size-1-pos : 2*win_size-1-pos]
zero_intersecs = np.where(np.diff(np.sign(auto_cor)) != 0)[0]
p = np.min(np.abs(zero_intersecs - pos)) if len(zero_intersecs) > 0 else win_size-pos  # fall back to the remaining window length (use pos, not the flattened idx)
start, stop = max(0, pos-p), min(pos+p+1, win_size)
decay = auto_cor[start:stop]
else:
decay, (start, stop) = torch.ones(1), (pos, pos+1)
return decay, (start, stop), ch, win_size
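# e.g. (hypothetical sizes) input_channels=2, window_size=5, rnn_window=3, pred_size=1:
# idx=7  -> input channel 1, position 2;  idx=11 -> recurrent channel 0 (ch=2), position 1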
def perturb_input(batched_input, decay, win_size, ch, start, stop):
"""
Perturb the input signal with Gaussian noise, over the length of the
auto-correlation period of the signal (until first intersection with x-axis).
"""
if ch < input_channels: # within input channels
input_ = batched_input[:,:rec_start_idx].reshape(-1, input_channels, win_size)
input_[:, ch, start:stop] *= decay
return torch.cat((input_.flatten(start_dim=1), batched_input[:,rec_start_idx:]), dim=1)
else: # within recurrent channels
x_rec = batched_input[:,rec_start_idx:].reshape(-1, pred_size, win_size)
x_rec[:, ch-input_channels, start:stop] *= decay
return torch.cat((batched_input[:,:rec_start_idx], x_rec.flatten(start_dim=1)), dim=1)
def perturbation(model, inputs, pred, steps):
"""
Compute the RMS score of differences, coming from all perturbations
of the inputs w.r.t. the reference output for one timestep.
"""
sens_temp = torch.full((batch_size*forecast, pred_size, flatten_size), float('nan'))
for i in np.arange(batch_size):
batched_inp = inputs[i:i+1].repeat(perturb_size, 1) # shape = [perturb_size, flatten_size]
auto_corrs = get_autocorrelations(inputs[i:i+1]) if correlated else torch.ones(ch_size)
for j in np.arange(flatten_size):
decay, start_stop, ch, win_size = get_decay(auto_corrs, j)
decay = 1 + steps * decay
perturbed_inp = batched_inp.clone()
perturbed_inp = perturb_input(perturbed_inp, decay, win_size, ch, *start_stop)
perturbed_pred = model.forward_sens(perturbed_inp) # shape = [perturb_size, pred_size, forecast]
del perturbed_inp
# compute a sensitivity metric of all perturbations in each feature variation
if model.Pred_Type == 'Mean_Var' or model.Ensemble:
differences = (perturbed_pred[0] - pred[i:i+1,...]).detach()
else:
differences = (perturbed_pred - pred[i:i+1,...]).detach()
differences = torch.mean(differences, dim=0)
sens_temp[i*forecast:(i+1)*forecast,:,j] = differences.T # shape = [forecast, pred_size]
return sens_temp
# Compute and return the differences from the perturbations
if model.Pred_Type == 'Point':
return perturbation(model, inputs, pred, steps)
elif model.Pred_Type == 'Mean_Var' or model.Ensemble:
mean_pred, var_pred = pred
sens_temp_mean = perturbation(model, inputs, mean_pred, steps)
sens_temp_var = torch.zeros_like(sens_temp_mean)
if random_samples:
permutations = torch.full((random_samples, batch_size*forecast, pred_size, flatten_size), float('nan'))
samples = torch.full((random_samples, pred_size), float('nan'))
for i in range(random_samples): # not direct sampling, but reparametrization trick: mean + eps*std, with eps ~ N(0,1)
eps = torch.randn(1, pred_size, 1) * amplification # used for later up-scaling
sampled_pred = mean_pred + eps * torch.sqrt(var_pred)
permutations[i] = perturbation(model, inputs, sampled_pred[0], steps)
samples[i] = eps[0,:,0].detach()
sens_temp_var = permutations.mean(dim=0)
return sens_temp_mean, sens_temp_var, permutations.mean(dim=1), samples
return sens_temp_mean, sens_temp_var
def _comp_sensitivity(method, model, inp, pred, num_samples=10, std_dev=0.2, correlated=True, random_samples=0, amplification=1):
"""
Abstracted method that computes the sensitivity analysis for each step in the dataloader
based on the given method.
Parameters
----------
method : The method to use for sensitivity analysis.
model : Model consisting of nn.Modules
inp : torch.Tensor
Input tensor with already concatenated external excitation and recurrent state signals,
shape=[batch_size, input_channels*window_size + pred_size*rnn_window]
pred : torch.Tensor
Output tensor that contains the predictions,
shape=[batch_size, pred_size, forecast]
num_samples : int, optional
Number of samples for the smoothed/integrated gradient methods, or number of perturbations
per input feature for the perturbation method. The default is 10.
std_dev : float, optional
Standard deviation of the Gaussian noise as form of permutation.
The default is 0.2.
correlated : bool, optional
If True, the perturbations are based on the auto-correlation of the signals.
The default is True.
random_samples : int, optional
Number of random samples, drawn from a standard normal distribution, to approximate the
expected sensitivity range across the aleatoric uncertainty of MVE models.
The default is 0 samples, i.e. no sampling is performed.
amplification : float, optional
Factor to amplify the sensitivity gradients when performing the sensitivity analysis
under re-sampling for MVE models. The default is 1, i.e. no amplification.
Returns
-------
The computed sensitivity analysis result. In case of an MVE model, the result is a tuple,
containing the mean and variance of the sensitivity analysis.
Raises
------
ValueError: If an invalid method is given for the sensitivity analysis.
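Examples
--------
A hypothetical dispatch; ``model``, ``inp`` and ``pred`` come from the prediction loop, hence the
doctest is skipped:

>>> sens = _comp_sensitivity('gradient', model, inp, pred)  # doctest: +SKIP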
"""
method = method.lower()
if method == SensitivityMethods.GRADIENT.value:
return _comp_grad_sens(inp, pred, model.Pred_Type, model.Ensemble, random_samples, amplification)
elif method == SensitivityMethods.SMOOTH_GRAD.value:
return _comp_smooth_grad_sens(model, inp, pred, model.Pred_Type, model.Ensemble, num_samples, std_dev)
elif method == SensitivityMethods.INTEGRATED_GRADIENT.value:
return _comp_integrated_grad_sens(model, inp, pred, model.Pred_Type, model.Ensemble, num_samples)
elif method == SensitivityMethods.PERTURBATION.value:
return _comp_perturb_sens(model, inp, pred, num_samples, std_dev, correlated, random_samples, amplification)
else:
raise ValueError((f"Given method '{method}' is not implemented! Choose from: "
f"{[x.lower() for x in list(SensitivityMethods.__members__)]}"))
def _check_sens_params_pred(sens_params):
"""
Check the sensitivity parameters for all AR and RNN models.
Parameters
----------
sens_params : dict
Dictionary that contains the parameters for the sensitivity analysis.
Returns
-------
if sens_params is given / not None:
tuple : method, comp_sens, verbose, sens_length, num_samples, std_dev,
correlated, random_samples, amplification
elif sens_params is None:
bool : comp_sens
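Examples
--------
>>> _check_sens_params_pred({'method': 'gradient', 'comp': True})
('gradient', True, False, None, 10, 0.2, True, 0, 1)
>>> _check_sens_params_pred(None)
(False, False)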
"""
if sens_params:
method = sens_params.get('method', '')
comp_sens = sens_params.get('comp', False)
verbose = sens_params.get('verbose', False)
sens_length = sens_params.get('sens_length', None)
num_samples = sens_params.get('num_samples', 10)
std_dev = sens_params.get('std_dev', 0.2)
correlated = sens_params.get('correlated', True)
random_samples = sens_params.get('random_samples', 0)
amplification = sens_params.get('amplification', 1)
return method, comp_sens, verbose, sens_length, num_samples, std_dev, correlated, random_samples, amplification
else:
comp_sens = False
verbose = False
return comp_sens, verbose
def _random_subset_sens_indices(sens_length, forecast, m_type, dataloader, batched=False):
"""
Create random indices for sensitivity analysis that
are a subset of the dataloader indices, allowing for faster prediction.
Parameters
----------
sens_length : int
Desired size for the reduced points that the sensitivity analysis is computed on
forecast : int
The model's forecast length
m_type : str
The model type, specified as object attribute in "model.Type"
dataloader : Dataloader
Dataloader for the test dataset to predict the output for.
batched : bool, optional
If True, the dataloader is batched, i.e. only used for "_pred_ARNN_batch" function.
The default is False.
Returns
-------
num_timesteps : int
Number of timesteps for sensitivity analysis
sens_indices : list
List of indices that are used for the sensitivity analysis
Raises
------
AssertionError
If the given sensitivity length is smaller than forecast length times batch size.
AssertionError
If the given sensitivity length exceeds the maximum length of the dataloader.
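Examples
--------
A hypothetical call for an AR model with forecast length 2; ``dataloader`` is assumed to exist,
hence the doctest is skipped:

>>> n_steps, indices = _random_subset_sens_indices(200, 2, 'AR', dataloader)  # doctest: +SKIP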
"""
sens_length = int(sens_length)
batch_size = next(iter(dataloader))[0].shape[0] if m_type == 'RNN' else 1
fc, bs = forecast, batch_size
assert sens_length >= fc*bs, f'Given sensitivity length of {sens_length} must be at least of size {fc*bs}!'
if not batched:
num_timesteps = len(dataloader.dataset)*fc
loader_length = num_timesteps - dataloader.dataset.subclass.add_zero
assert sens_length <= loader_length, f'Given sensitivity length of {sens_length} exceeds maximum dataloader length of {loader_length}!'
else:
num_timesteps = len(dataloader)*forecast
assert sens_length <= num_timesteps, f'Given sensitivity length of {sens_length} exceeds maximum dataloader length of {num_timesteps}!'
add_zeros = [sw.add_zero for sw in dataloader.sws]
min_length = min([le*forecast-zeros for le, zeros in zip(dataloader.__lengths__(), add_zeros)])
# round sens_length to the closest multiple of forecast*batch_size if needed
if sens_length % (fc*bs) != 0:
sens_length = np.around(sens_length/(fc*bs)) if m_type == 'RNN' else np.ceil(sens_length/(fc*bs))
sens_length = int(sens_length * fc*bs)
print(f'INFO: Given sensitivity length was rounded to {sens_length} as closest multiple of batch_size={bs} * forecast={fc}.')
num_timesteps = sens_length
sampler = LatinHypercube(d=1)
samples = sampler.random(n=sens_length//(fc*bs)).flatten()
# fill up sens_indices in case of duplications until desired sens_length is reached
min_length = len(dataloader) if not batched else min_length
sens_indices = np.floor(samples * min_length).astype(int)
unique_indices = set(sens_indices)
while len(unique_indices) < (sens_length//(fc*bs)):
additional_samples = sampler.random(n=(sens_length//(fc*bs)) - len(unique_indices)).flatten()
additional_indices = np.floor(additional_samples * min_length).astype(int)
unique_indices.update(additional_indices)
sens_indices = sorted(unique_indices)
return num_timesteps, sens_indices
def _compress_model(old_model, state_dict_path, new_window_size=None, new_rnn_window=None, retrain=False, retrain_params=None):
"""
Compress the first layer of an ARNN model by reducing the input and recurrent window sizes.
Additional option to fine-tune the compressed model after compression.
INFO: Currently NOT working for ensemble and RNN models!
Parameters
----------
old_model : nn.Module
The old NN model to compress
state_dict_path : str
Path to the saved state dict of the old model
new_window_size : int, optional
New window size for the input features. The default is None, i.e. the old window size is kept.
new_rnn_window : int, optional
New window size for the recurrent features. The default is None, i.e. the old rnn window size is kept.
retrain : bool, optional
If True, fine-tune the new model after compression. The default is False.
retrain_params : dict, optional
Dictionary with the fine-tuning settings: keys 'lr', 'patience', 'max_epochs', 'stabilizer' and
'data_handle' (an object from the Meas_handling module used to load and prepare the training data).
Only needed if retrain=True. The default is None, i.e. not provided.
Returns
-------
compressed_model : nn.Module
The compressed model with reduced input and recurrent window sizes, returned together with the
extracted init parameters (and, if retrain=True, the fine-tuning results; if a wrapped sub-model
is present, also its init parameters).
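Examples
--------
A hypothetical compression of a plain ARNN without a wrapped sub-model; ``old_model`` and the saved
state dict are assumed to exist, hence the doctest is skipped:

>>> compressed, init_params = _compress_model(old_model, 'arnn_state.pt',
...                                           new_window_size=25, new_rnn_window=25)  # doctest: +SKIP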
"""
# Get the name of all first weight tensors
weight_names = []
for name, _ in old_model.named_parameters():
if '0.weight' in name:
weight_names.append(name)
# Load the saved weights
state_dict = torch.load(state_dict_path)
# Extract and reduce the first weight matrix
for w_name in weight_names:
w_mat = state_dict[w_name]
rec_start_idx = old_model.input_channels * old_model.window_size
if new_window_size:
assert new_window_size <= old_model.window_size, 'New window_size must be smaller than or equal to the old one!'
temp_inp = w_mat[:, :rec_start_idx].view(w_mat.shape[0], -1, old_model.window_size)
temp_inp = temp_inp[...,-new_window_size:].flatten(start_dim=1) # take only the _last_ new_window_size columns
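# e.g. (hypothetical sizes) input_channels=2, window_size=50, new_window_size=25: the block w_mat[:, :100]
# is reshaped to [out, 2, 50], cut to the last 25 columns per channel and flattened back to [out, 50]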
else:
new_window_size = old_model.window_size
temp_inp = w_mat[:, :rec_start_idx]
if new_rnn_window:
assert new_rnn_window <= old_model.rnn_window, 'New rnn_window must be smaller than or equal to the old one!'
temp_rec = w_mat[:, rec_start_idx:].view(w_mat.shape[0], -1, old_model.rnn_window)
temp_rec = temp_rec[...,-new_rnn_window:].flatten(start_dim=1)
else:
new_rnn_window = old_model.rnn_window
temp_rec = w_mat[:, rec_start_idx:]
state_dict[w_name] = torch.cat((temp_inp, temp_rec), dim=1)
new_params = {'window_size': new_window_size, 'rnn_window': new_rnn_window}
# Extract __init__ parameters and create a new model with modified window sizes
cls = old_model.__class__
parameters = inspect.signature(cls.__init__).parameters
init_params = {param: getattr(old_model, param) for param in parameters if param != 'self' and hasattr(old_model, param)}
model_temp = old_model.DNN if hasattr(old_model, 'DNN') else old_model
parameters = inspect.signature(model_temp.__class__.__init__).parameters
wrong_keys = ['input_size', 'output_size']
init_params_dnn = {param: getattr(model_temp, param) for param in parameters if param != 'self'
and param not in wrong_keys and hasattr(model_temp, param)}
init_params.update(init_params_dnn)
if cls.__name__ == 'SeparateMVEARNN':
init_params['var_hidden_size'] = init_params.pop('hidden_size')
wrapped_model = None  # stays None if no wrapped nn.Module is found among the init parameters
for name, value in init_params.items():
wrapped_model = value if isinstance(value, nn.Module) else None
if wrapped_model:
wrapped_params = inspect.signature(wrapped_model.__init__).parameters
wrapped_init_params = {param: getattr(wrapped_model, param) for param in wrapped_params \
if param != 'self' and hasattr(wrapped_model, param)}
wrapped_init_params.update({k: v for k, v in new_params.items() if k in wrapped_init_params})
model_temp = wrapped_model.DNN if hasattr(wrapped_model, 'DNN') else wrapped_model
parameters = inspect.signature(model_temp.__class__.__init__).parameters
init_params_dnn = {param: getattr(model_temp, param) for param in parameters if param != 'self'
and param not in wrong_keys and hasattr(model_temp, param)}
wrapped_init_params.update(init_params_dnn)
new_wrapped_model = wrapped_model.__class__(**wrapped_init_params)
init_params[name] = new_wrapped_model
# print('wrapped_init_params:', wrapped_init_params, '\n')
break
if 'window_size' in init_params: # indicates that the outer model is of type nn.Module
init_params.update({k: v for k, v in new_params.items() if k in init_params})
# print('init_params:', init_params, '\n')
# Instantiate a new model with updated parameters and load the modified state dict
compressed_model = cls(**init_params)
compressed_model.load_state_dict(state_dict)
params_old = summary(old_model).total_params
params_new = summary(compressed_model).total_params
print((f'Reduction by {params_old - params_new} parameters in the Input layer, resulting '
f'in a total model compression ratio of {(1-params_new/params_old):.1%}\n'))
if retrain and retrain_params:
lr = retrain_params.get('lr', 1e-4) / 2 # reduce learning rate for fine-tuning
patience = retrain_params.get('patience', 5)
max_epochs = retrain_params.get('max_epochs', 100)
stab = retrain_params.get('stabilizer', 5e-3)
data_handle = retrain_params.get('data_handle', None)
if data_handle is None:
raise ValueError('No data_handle object provided for re-training the model!')
## V1: fine-tuning the compressed model
train_loader, val_loader = data_handle.give_torch_loader(window_size=new_window_size,
rnn_window=new_rnn_window, keyword='short')
opt = torch.optim.Adam(compressed_model.parameters(), lr=lr)
crit = nn.MSELoss() if compressed_model.Pred_Type == 'Point' else GaussianNLLLoss()
print((f'Start fine-tuning the compressed model with lr={lr:.2e}, patience={patience}, '
f'max_epochs={max_epochs} and stab={stab:.2e} ...'))
res_df = train_model(model=compressed_model, train_loader=train_loader, max_epochs=max_epochs, optimizer=opt, device='cpu',
criterion=crit, stabelizer=stab, val_loader=val_loader, patience=patience, print_results=True)
print('Fine-tuning finished!\n')
if wrapped_model:
return compressed_model, res_df, init_params, wrapped_init_params
else:
return compressed_model, res_df, init_params
if wrapped_model:
return compressed_model, init_params, wrapped_init_params
else:
return compressed_model, init_params