Arcana Plots
Submodules
arcana.plots.analysis_plot_utils module
This module contains the functions to plot the attention matrix for each head and the sensitivity analysis of the model.
- class arcana.plots.analysis_plot_utils.AnalysisPlotUtils(arcana_procedure, sample_number)
Bases: object
Utility class for plotting the attention matrix of each head and the sensitivity analysis of the model.
- plot_additive_attention(attention_probs, arch_pointer)
Plots the overall performance of the additive attention mechanism on each sequence.
- Parameters:
attention_probs (tensor or np.array) – Attention probabilities
arch_pointer (str) – The architecture pointer
- plot_all_multihead_attention(attention_probs, arch_pointer='encoder', batch_idx=0)
Plots the attention matrix for each head and the overall performance of the model on each sequence.
- Parameters:
attention_probs (tensor or np.array) – Attention probabilities
arch_pointer (str, optional) – The architecture pointer. Defaults to “encoder”.
batch_idx (int, optional) – The index of the batch element to plot. Defaults to 0.
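A minimal sketch of what a per-head attention plot can look like, written with NumPy and matplotlib only; it is not the Arcana implementation. It assumes `attention_probs` is an array of shape `(batch, heads, query_len, key_len)` whose rows sum to one, and the helper name `plot_heads_sketch` is hypothetical.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def plot_heads_sketch(attention_probs, batch_idx=0):
    """Hypothetical helper: one heatmap per attention head for one batch element.

    Assumes attention_probs has shape (batch, heads, query_len, key_len).
    """
    probs = np.asarray(attention_probs)[batch_idx]  # -> (heads, query_len, key_len)
    n_heads = probs.shape[0]
    fig, axes = plt.subplots(1, n_heads, figsize=(3 * n_heads, 3), squeeze=False)
    for head, ax in enumerate(axes[0]):
        # probabilities lie in [0, 1], so fix the color range to compare heads
        image = ax.imshow(probs[head], cmap="viridis", vmin=0.0, vmax=1.0)
        ax.set_title(f"Head {head}")
        ax.set_xlabel("Key position")
        ax.set_ylabel("Query position")
    fig.colorbar(image, ax=axes[0].tolist(), shrink=0.8)  # one shared colorbar
    return fig


# Usage: random scores turned into attention weights with a softmax over keys
rng = np.random.default_rng(0)
scores = rng.normal(size=(2, 4, 6, 6))
attention = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
fig = plot_heads_sketch(attention, batch_idx=1)
```

Fixing `vmin` and `vmax` keeps colors comparable across heads, at the cost of low contrast when all weights are small.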
- plot_sensitivity_analyis(sensitivity, future_step, available_sequence, log_scale=False)
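Plots the sensitivity analysis of the model.
- Parameters:
sensitivity (tensor or np.array) – Sensitivity analysis
future_step (int) – The future step
available_sequence (int) – The length of the available sequence
log_scale (bool, optional) – Whether to use a logarithmic scale. Defaults to False.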
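A minimal sketch of a sensitivity heatmap with an optional logarithmic color scale via matplotlib's `LogNorm`. The array layout (rows are input features, columns are time steps), the choice to show only the last `available_sequence` columns, and the helper name `plot_sensitivity_sketch` are assumptions for illustration, not the Arcana code.

```python
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LogNorm


def plot_sensitivity_sketch(sensitivity, future_step, available_sequence, log_scale=False):
    """Hypothetical helper: heatmap of sensitivity magnitudes for one future step.

    Assumes rows are input features and columns are time steps.
    """
    values = np.abs(np.asarray(sensitivity, dtype=float))[:, -available_sequence:]
    norm = None  # default linear color scale
    if log_scale:
        # LogNorm cannot map zeros, so clip them to the smallest positive value
        positive = values[values > 0]
        floor = positive.min() if positive.size else 1e-12
        values = np.clip(values, floor, None)
        norm = LogNorm(vmin=values.min(), vmax=values.max())
    fig, ax = plt.subplots(figsize=(6, 3))
    image = ax.imshow(values, aspect="auto", cmap="magma", norm=norm)
    ax.set_title(f"Sensitivity of future step {future_step}")
    ax.set_xlabel(f"Input time step (last {available_sequence})")
    ax.set_ylabel("Input feature")
    fig.colorbar(image, ax=ax)
    return fig, image


# Usage: random sensitivities containing one exact zero
rng = np.random.default_rng(1)
sensitivity = rng.random((3, 10))
sensitivity[0, 5] = 0.0
fig, image = plot_sensitivity_sketch(sensitivity, future_step=2, available_sequence=8, log_scale=True)
```

A log scale helps when a few time steps dominate the sensitivity and would otherwise wash out the rest of the map.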
- class arcana.plots.analysis_plot_utils.ScalarFormatterForceFormat(useOffset=None, useMathText=None, useLocale=None)
Bases: ScalarFormatter
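Formats the colorbar labels in scientific notation. Inherits from matplotlib.ticker.ScalarFormatter.
- Parameters:
useOffset (bool or float, optional) – Inherited from matplotlib.ticker.ScalarFormatter. Defaults to None.
useMathText (bool, optional) – Inherited from matplotlib.ticker.ScalarFormatter. Defaults to None.
useLocale (bool, optional) – Inherited from matplotlib.ticker.ScalarFormatter. Defaults to None.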
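Forcing a fixed mantissa format on a `ScalarFormatter` is commonly done by overriding its private `_set_format` hook. The sketch below shows that recipe, assuming the Arcana class works in a similar way; `ForcedFormatSketch` is a hypothetical name, and because `_set_format` is private (it takes no arguments in matplotlib 3.1 and later), the override may break across matplotlib versions.

```python
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import ScalarFormatter


class ForcedFormatSketch(ScalarFormatter):
    """Hypothetical stand-in: scientific notation with a fixed mantissa format.

    Overrides the private ScalarFormatter._set_format hook (no arguments in
    matplotlib >= 3.1), so this is version-dependent.
    """

    def _set_format(self):
        self.format = "%1.1f"  # one decimal place on every tick mantissa


fig, ax = plt.subplots()
image = ax.imshow(np.linspace(0, 3e-4, 100).reshape(10, 10))
formatter = ForcedFormatSketch(useMathText=True)
formatter.set_powerlimits((0, 0))  # always factor out a power of ten
colorbar = fig.colorbar(image, ax=ax, format=formatter)
fig.canvas.draw()  # tick labels and the offset text are computed at draw time
```

With `set_powerlimits((0, 0))` the common factor (here 10^-4) moves to the offset text above the colorbar, and every tick shows one decimal.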
arcana.plots.plot_utils module
This module contains the functions to plot the learning rate, the training and validation loss, and the predictions of the model.
- arcana.plots.plot_utils.plot_model_learning_rate(learning_rate_dict, plot_path)
Plots the learning rate of the model.
- Parameters:
learning_rate_dict (dict) – a dictionary containing the learning rate and epoch
plot_path (str) – the path to save the plot
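A minimal sketch of a learning-rate plot, assuming `learning_rate_dict` holds two equal-length lists under the hypothetical keys `"epoch"` and `"learning_rate"`; the real key names are not documented here, and `plot_learning_rate_sketch` is not the Arcana function.

```python
import os
import tempfile

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt


def plot_learning_rate_sketch(learning_rate_dict, plot_path):
    """Hypothetical helper; assumes equal-length lists under "epoch" and "learning_rate"."""
    fig, ax = plt.subplots(figsize=(5, 3))
    ax.plot(learning_rate_dict["epoch"], learning_rate_dict["learning_rate"], marker="o")
    ax.set_yscale("log")  # decaying schedules span orders of magnitude
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Learning rate")
    fig.tight_layout()
    fig.savefig(plot_path)
    plt.close(fig)  # free memory when plotting many runs
    return plot_path


# Usage: a step schedule that halves the rate every 10 epochs
epochs = list(range(30))
schedule = {"epoch": epochs, "learning_rate": [1e-3 * 0.5 ** (e // 10) for e in epochs]}
out_path = plot_learning_rate_sketch(schedule, os.path.join(tempfile.mkdtemp(), "learning_rate.png"))
```

The logarithmic y axis is a design choice: on a linear axis, later steps of a decaying schedule collapse onto zero.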
- arcana.plots.plot_utils.plot_sample_prediction(original_sequence, mean_prediction, available_sequence, plot_path, loss_type, scores_prediction, sample_number, upper_prediction, lower_prediction, ylabels=None)
Plots the prediction of the model.
- Parameters:
original_sequence (np.array) – the original sequence
mean_prediction (np.array) – the mean prediction of the model
available_sequence (int) – the length of the available sequence
plot_path (str) – the path to save the plot
loss_type (str) – the type of loss to plot
scores_prediction (dict) – a dictionary containing the scores of the prediction
sample_number (int) – the index of the sampled sequence
upper_prediction (np.array) – the upper prediction of the model
lower_prediction (np.array) – the lower prediction of the model
ylabels (list, optional) – the labels of the y axis. Defaults to None.
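A minimal sketch of a forecast plot with an uncertainty band. It assumes 1-D arrays in which `mean_prediction`, `lower_prediction`, and `upper_prediction` cover the steps right after the first `available_sequence` observed steps; `plot_prediction_sketch` and its reduced parameter list are hypothetical, and scoring (`scores_prediction`, `loss_type`) is left out.

```python
import os
import tempfile

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np


def plot_prediction_sketch(original_sequence, mean_prediction, available_sequence,
                           lower_prediction, upper_prediction, plot_path, ylabel="value"):
    """Hypothetical helper: history, forecast mean, and an uncertainty band.

    Assumes 1-D arrays; the three prediction arrays cover the steps that follow
    the first `available_sequence` observed steps.
    """
    steps = np.arange(len(original_sequence))
    future = np.arange(available_sequence, available_sequence + len(mean_prediction))
    fig, ax = plt.subplots(figsize=(6, 3))
    ax.plot(steps, original_sequence, color="black", label="original")
    ax.plot(future, mean_prediction, color="tab:blue", label="mean prediction")
    ax.fill_between(future, lower_prediction, upper_prediction,
                    color="tab:blue", alpha=0.3, label="prediction interval")
    ax.axvline(available_sequence, color="grey", linestyle="--")  # forecast start
    ax.set_xlabel("Time step")
    ax.set_ylabel(ylabel)
    ax.legend()
    fig.tight_layout()
    fig.savefig(plot_path)
    return fig, ax


# Usage: 30 observed steps followed by a 20-step forecast with a band of +/- 0.2
original = np.sin(np.linspace(0, 6, 50))
mean = original[30:] + 0.05
plot_path = os.path.join(tempfile.mkdtemp(), "prediction.png")
fig, ax = plot_prediction_sketch(original, mean, 30, mean - 0.2, mean + 0.2, plot_path)
```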
- arcana.plots.plot_utils.plot_train_val_loss(losses, plot_path, loss_type, train_loss_mode='batch')
Plots the training and validation loss per epoch or per batch.
- Parameters:
losses (dict) – a dictionary containing the train and validation loss
plot_path (str) – the path to save the plot
loss_type (str) – the type of loss to plot
train_loss_mode (str, optional) – whether to plot the training loss per batch or per epoch. Defaults to “batch”.
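A minimal sketch of the two `train_loss_mode` options. The layout of `losses` (per-epoch lists of batch losses under `"train"`, one validation value per epoch under `"val"`) and the helper name `plot_losses_sketch` are assumptions; the documented dictionary keys may differ.

```python
import os
import tempfile

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np


def plot_losses_sketch(losses, plot_path, loss_type, train_loss_mode="batch"):
    """Hypothetical helper; assumes losses = {"train": [[batch losses of epoch 0], ...],
    "val": [validation loss of epoch 0, ...]}."""
    train = [np.asarray(epoch, dtype=float) for epoch in losses["train"]]
    val = np.asarray(losses["val"], dtype=float)
    fig, ax = plt.subplots(figsize=(6, 3))
    if train_loss_mode == "batch":
        # every batch gets its own x position; validation sits at each epoch's last batch
        batch_losses = np.concatenate(train)
        epoch_ends = np.cumsum([len(epoch) for epoch in train]) - 1
        ax.plot(np.arange(len(batch_losses)), batch_losses, alpha=0.6, label="train (per batch)")
        ax.plot(epoch_ends, val, marker="o", label="validation (per epoch)")
        ax.set_xlabel("Batch")
    elif train_loss_mode == "epoch":
        # average the batch losses so both curves share the epoch axis
        ax.plot(np.arange(len(train)), [epoch.mean() for epoch in train], label="train (epoch mean)")
        ax.plot(np.arange(len(val)), val, marker="o", label="validation")
        ax.set_xlabel("Epoch")
    else:
        raise ValueError(f"unknown train_loss_mode: {train_loss_mode!r}")
    ax.set_ylabel(f"{loss_type} loss")
    ax.legend()
    fig.tight_layout()
    fig.savefig(plot_path)
    return fig, ax


# Usage: three epochs of four batches each
losses = {
    "train": [[1.0, 0.9, 0.8, 0.7], [0.6, 0.55, 0.5, 0.45], [0.4, 0.38, 0.36, 0.34]],
    "val": [0.85, 0.52, 0.39],
}
out_dir = tempfile.mkdtemp()
fig_b, ax_b = plot_losses_sketch(losses, os.path.join(out_dir, "loss_batch.png"), "MSE", "batch")
fig_e, ax_e = plot_losses_sketch(losses, os.path.join(out_dir, "loss_epoch.png"), "MSE", "epoch")
```

Batch mode shows the noise within each epoch; epoch mode trades that detail for a direct, same-axis comparison with the validation curve.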
- arcana.plots.plot_utils.plot_train_val_loss_individual(losses, plot_path, loss_type, train_loss_mode)
Plots the training and validation loss per epoch or per batch.
- Parameters:
losses (dict) – a dictionary containing the train and validation loss
plot_path (str) – the path to save the plot
loss_type (str) – the type of loss to plot
train_loss_mode (str) – whether to plot the training loss per batch or per epoch.
arcana.plots.plots module
This module contains general-purpose plotting functions with scientific formatting.
- class arcana.plots.plots.Plots
Bases: object
General class for multipurpose plotting.
- plot_line(ax, x_values, y_values, color, legend_name=None)
General function for plotting line plots.
- Parameters:
ax (matplotlib.axes._subplots.AxesSubplot) – matplotlib axes
x_values (np.array) – x values
y_values (np.array) – y values
color (str) – color
legend_name (str, optional) – legend name. Defaults to None.
- plot_scatter(ax, x_values, y_values, color, legend_name=None, size=3)
General function for plotting scatter plots.
- Parameters:
ax (matplotlib.axes._subplots.AxesSubplot) – matplotlib axes
x_values (np.array) – x values
y_values (np.array) – y values
color (str) – color
legend_name (str, optional) – legend name. Defaults to None.
size (int, optional) – size. Defaults to 3.
- set_labels(ax, x_label, y_label)
General function for setting labels.
- Parameters:
ax (matplotlib.axes._subplots.AxesSubplot) – matplotlib axes
x_label (str) – x label
y_label (str) – y label
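A minimal stand-in for the `Plots` helpers, mirroring the documented signatures on top of plain matplotlib calls. `PlotsSketch` is hypothetical; mapping `size` to the scatter marker area `s` and applying scientific tick labels in `set_labels` are assumptions about how the real class behaves.

```python
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np


class PlotsSketch:
    """Hypothetical stand-in mirroring the documented Plots signatures."""

    def plot_line(self, ax, x_values, y_values, color, legend_name=None):
        ax.plot(x_values, y_values, color=color, label=legend_name)

    def plot_scatter(self, ax, x_values, y_values, color, legend_name=None, size=3):
        # assumption: size is the marker area in points^2 (matplotlib's `s`)
        ax.scatter(x_values, y_values, color=color, label=legend_name, s=size)

    def set_labels(self, ax, x_label, y_label):
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        # assumption: the "scientific format" is applied here, to both axes
        ax.ticklabel_format(style="sci", axis="both", scilimits=(0, 0))


# Usage: a fitted curve drawn over noisy-looking samples
plots = PlotsSketch()
fig, ax = plt.subplots()
x = np.linspace(0, 2e-3, 50)
plots.plot_line(ax, x, 1e4 * x**2, color="tab:red", legend_name="fit")
plots.plot_scatter(ax, x, 1e4 * x**2 + 1e-3, color="tab:blue", legend_name="data", size=5)
plots.set_labels(ax, "time [s]", "signal")
ax.legend()
```

Keeping the helpers as thin wrappers that take an existing `ax` lets callers compose several layers on one figure, as in the usage above.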