Arcana Plots

Submodules

arcana.plots.analysis_plot_utils module

This module contains the functions to plot the attention matrix for each head and the sensitivity analysis of the model.

class arcana.plots.analysis_plot_utils.AnalysisPlotUtils(arcana_procedure, sample_number)

Bases: object

AnalysisPlotUtils class to plot the attention matrix for each head and the sensitivity analysis of the model.

plot_additive_attention(attention_probs, arch_pointer)

Plots the additive attention probabilities for each sequence.

Parameters:
  • attention_probs (tensor or np.array) – Attention probabilities

  • arch_pointer (str) – The architecture pointer

plot_all_multihead_attention(attention_probs, arch_pointer='encoder', batch_idx=0)

Plots the attention matrix for each head and the overall performance of the model on each sequence.

Parameters:
  • attention_probs (tensor or np.array) – Attention probabilities

  • arch_pointer (str, optional) – The architecture pointer. Defaults to “encoder”.

  • batch_idx (int, optional) – The index of the batch element to plot. Defaults to 0.
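The per-head plot described above can be approximated with plain matplotlib. The following is a minimal sketch, not the arcana implementation: the helper name `sketch_multihead_attention` is hypothetical, and the array layout `(batch, heads, query_len, key_len)` is an assumption that the documentation does not confirm.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def sketch_multihead_attention(attention_probs, batch_idx=0):
    """Draw one heatmap per attention head (hypothetical helper, not part of arcana)."""
    # Assumed layout: (batch, heads, query_len, key_len).
    probs = np.asarray(attention_probs)[batch_idx]
    num_heads = probs.shape[0]
    fig, axes = plt.subplots(1, num_heads, figsize=(3 * num_heads, 3), squeeze=False)
    for head, ax in enumerate(axes[0]):
        # Shared colour limits make the heads directly comparable.
        image = ax.imshow(probs[head], cmap="viridis", vmin=0.0, vmax=1.0)
        ax.set_title(f"head {head}")
        ax.set_xlabel("key position")
        ax.set_ylabel("query position")
    fig.colorbar(image, ax=axes[0].tolist(), shrink=0.8)
    return fig


# Example data: random scores turned into probabilities with a softmax over keys.
rng = np.random.default_rng(0)
scores = rng.normal(size=(2, 4, 6, 6))
exp_scores = np.exp(scores)
attention = exp_scores / exp_scores.sum(axis=-1, keepdims=True)
fig = sketch_multihead_attention(attention, batch_idx=0)
```

Selecting a single batch element first keeps the figure to one row of heads, which is presumably the purpose of the `batch_idx` argument.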

plot_sensitivity_analyis(sensitivity, future_step, available_sequence, log_scale=False)

Plots the sensitivity analysis of the model.

Parameters:
  • sensitivity (tensor or np.array) – Sensitivity analysis

  • future_step (int) – The future step

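  • available_sequence (int) – the length of the available sequence

  • log_scale (bool, optional) – whether to plot on a logarithmic scale. Defaults to False.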

class arcana.plots.analysis_plot_utils.ScalarFormatterForceFormat(useOffset=None, useMathText=None, useLocale=None)

Bases: ScalarFormatter

ScalarFormatterForceFormat class to format the colorbar labels to scientific notation. Inherits from ScalarFormatter class from matplotlib.ticker.

Parameters:
  • ticker (matplotlib.ticker.ScalarFormatter) – ScalarFormatter class from matplotlib.ticker
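For comparison, scientific notation on a colorbar can also be forced with the stock `matplotlib.ticker.ScalarFormatter`. The sketch below uses only that public class, not `ScalarFormatterForceFormat`, whose overrides are not documented here; the data and variable names are illustrative.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import ScalarFormatter

# Small-magnitude data, where plain tick labels would be hard to read.
values = np.random.default_rng(0).random((5, 5)) * 1e-4

fig, ax = plt.subplots()
image = ax.imshow(values)

formatter = ScalarFormatter(useMathText=True)
formatter.set_powerlimits((0, 0))  # use scientific notation for every magnitude
colorbar = fig.colorbar(image, ax=ax, format=formatter)
fig.canvas.draw()  # triggers tick formatting
```

A subclass becomes useful when the number of decimals in the tick labels must also be fixed, which the stock formatter chooses automatically.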

arcana.plots.plot_utils module

This module contains the functions to plot the learning rate, the train and validation loss, and the prediction of the model.

arcana.plots.plot_utils.plot_model_learning_rate(learning_rate_dict, plot_path)

Plots the learning rate of the model.

Parameters:
  • learning_rate_dict (dict) – a dictionary containing the learning rate and epoch

  • plot_path (str) – the path to save the plot
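A minimal sketch of such a learning-rate plot is shown below. It is a stand-in, not the arcana function: the dictionary keys `"epoch"` and `"learning_rate"`, the output file name, and the treatment of `plot_path` as a directory are all assumptions.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt


def sketch_learning_rate_plot(learning_rate_dict, plot_path):
    """Hypothetical stand-in: plot learning rate against epoch and save a PNG."""
    # Assumed keys; the real dictionary layout is not documented here.
    epochs = learning_rate_dict["epoch"]
    rates = learning_rate_dict["learning_rate"]
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(epochs, rates, marker="o")
    ax.set_xlabel("epoch")
    ax.set_ylabel("learning rate")
    ax.set_yscale("log")  # schedules often span several orders of magnitude
    fig.tight_layout()
    output_file = os.path.join(plot_path, "learning_rate.png")
    fig.savefig(output_file)
    plt.close(fig)
    return output_file


# Exponentially decaying schedule as example data.
schedule = {
    "epoch": list(range(1, 11)),
    "learning_rate": [1e-3 * 0.8 ** step for step in range(10)],
}
saved_file = sketch_learning_rate_plot(schedule, tempfile.mkdtemp())
```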

arcana.plots.plot_utils.plot_sample_prediction(original_sequence, mean_prediction, available_sequence, plot_path, loss_type, scores_prediction, sample_number, upper_prediction, lower_prediction, ylabels=None)

Plots the prediction of the model.

Parameters:
  • original_sequence (np.array) – the original sequence

  • mean_prediction (np.array) – the mean prediction of the model

  • available_sequence (int) – the length of the available sequence

  • plot_path (str) – the path to save the plot

  • loss_type (str) – the type of loss to plot

  • scores_prediction (dict) – a dictionary containing the scores of the prediction

  • sample_number (int) – the index of the plotted sample

  • upper_prediction (np.array) – the upper prediction of the model

  • lower_prediction (np.array) – the lower prediction of the model

  • ylabels (list, optional) – the labels of the y axis. Defaults to None.
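The roles of the mean, upper, and lower predictions can be illustrated with a short matplotlib sketch. This is not the arcana implementation: the helper name is hypothetical, and it assumes one-dimensional sequences in which the predictions cover only the steps after `available_sequence`.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def sketch_sample_prediction(original_sequence, mean_prediction, available_sequence,
                             upper_prediction, lower_prediction, ylabel="value"):
    """Hypothetical stand-in: ground truth, mean forecast, and a shaded band."""
    steps = np.arange(len(original_sequence))
    # Assumption: predictions start where the available sequence ends.
    future_steps = steps[available_sequence:]
    fig, ax = plt.subplots(figsize=(7, 4))
    ax.plot(steps, original_sequence, color="black", label="original")
    ax.plot(future_steps, mean_prediction, color="tab:blue", label="mean prediction")
    ax.fill_between(future_steps, lower_prediction, upper_prediction,
                    color="tab:blue", alpha=0.3, label="prediction band")
    ax.axvline(available_sequence, color="gray", linestyle="--",
               label="end of available data")
    ax.set_xlabel("step")
    ax.set_ylabel(ylabel)
    ax.legend()
    return fig


original = np.sin(np.linspace(0.0, 6.0, 60))
available = 40
mean = original[available:] + 0.05
fig = sketch_sample_prediction(original, mean, available, mean + 0.1, mean - 0.1)
```

Shading the region between the lower and upper predictions with `fill_between` keeps the mean curve readable while still showing the spread.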

arcana.plots.plot_utils.plot_train_val_loss(losses, plot_path, loss_type, train_loss_mode='batch')

Plots the train and validation loss per epoch or per batch.

Parameters:
  • losses (dict) – a dictionary containing the train and validation loss

  • plot_path (str) – the path to save the plot

  • loss_type (str) – the type of loss to plot

  • train_loss_mode (str, optional) – whether the training loss is plotted per batch or per epoch. Defaults to “batch”.
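One way to draw batch-level training loss and epoch-level validation loss on a shared axis is sketched below. The dictionary keys `"train"` and `"val"`, the helper name, and the rescaling of batch indices onto the epoch axis are assumptions made for illustration, not the documented behaviour of `plot_train_val_loss`.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def sketch_train_val_loss(losses, loss_type, train_loss_mode="batch"):
    """Hypothetical stand-in: overlay training and validation loss curves."""
    # Assumed keys; the real dictionary layout is not documented here.
    train = np.asarray(losses["train"], dtype=float)
    val = np.asarray(losses["val"], dtype=float)
    num_epochs = len(val)
    val_x = np.arange(1, num_epochs + 1)
    if train_loss_mode == "batch":
        # Spread batch-level points over the epoch axis so both curves align.
        train_x = np.linspace(0, num_epochs, num=len(train))
    else:
        train_x = np.arange(1, len(train) + 1)
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(train_x, train, alpha=0.6, label=f"train ({train_loss_mode})")
    ax.plot(val_x, val, marker="o", label="validation (epoch)")
    ax.set_xlabel("epoch")
    ax.set_ylabel(f"{loss_type} loss")
    ax.set_yscale("log")
    ax.legend()
    return fig


rng = np.random.default_rng(0)
example_losses = {
    "train": np.exp(-np.linspace(0, 3, 50)) + 0.01 * rng.random(50),  # 50 batches
    "val": np.exp(-np.linspace(0, 3, 5)) + 0.05,                      # 5 epochs
}
fig = sketch_train_val_loss(example_losses, loss_type="mse")
```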

arcana.plots.plot_utils.plot_train_val_loss_individual(losses, plot_path, loss_type, train_loss_mode)

Plots the train and validation loss per epoch or per batch.

Parameters:
  • losses (dict) – a dictionary containing the train and validation loss

  • plot_path (str) – the path to save the plot

  • loss_type (str) – the type of loss to plot

  • train_loss_mode (str) – whether the training loss is plotted per batch or per epoch

arcana.plots.plots module

This module provides general-purpose plotting helpers with scientific formatting.

class arcana.plots.plots.Plots

Bases: object

General class for multipurpose plotting.

plot_line(ax, x_values, y_values, color, legend_name=None)

General function for plotting line plots.

Parameters:
  • ax (matplotlib.axes._subplots.AxesSubplot) – matplotlib axes

  • x_values (np.array) – x values

  • y_values (np.array) – y values

  • color (str) – color

  • legend_name (str) – legend name. Defaults to None.

plot_scatter(ax, x_values, y_values, color, legend_name=None, size=3)

General function for plotting scatter plots.

Parameters:
  • ax (matplotlib.axes._subplots.AxesSubplot) – matplotlib axes

  • x_values (np.array) – x values

  • y_values (np.array) – y values

  • color (str) – color

  • legend_name (str) – legend name. Defaults to None.

  • size (int) – size. Defaults to 3.

set_labels(ax, x_label, y_label)

General function for setting labels.

Parameters:
  • ax (matplotlib.axes._subplots.AxesSubplot) – matplotlib axes

  • x_label (str) – x label

  • y_label (str) – y label
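The three helpers take an existing axes object plus data and styling arguments, so their effect can be approximated with direct matplotlib calls. The sketch below is an assumption about what they wrap, written without importing arcana; the data and labels are illustrative.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

x_values = np.linspace(0.0, 1.0, 20)
y_values = x_values ** 2

fig, ax = plt.subplots()
# Roughly what plot_line(ax, x_values, y_values, color, legend_name) asks for.
ax.plot(x_values, y_values, color="tab:blue", label="model")
# Roughly what plot_scatter(..., size=3) asks for; `s` is the marker size.
ax.scatter(x_values, y_values + 0.02, color="tab:red", s=3, label="samples")
# Roughly what set_labels(ax, x_label, y_label) asks for.
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.legend()
```

Passing the axes in explicitly, as the signatures above do, lets several helpers draw onto the same subplot.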

Module contents