Arcana Training

Submodules

arcana.training.train_model module

This module contains the training functionality for the sequence-to-sequence model

class arcana.training.train_model.Seq2SeqTrainer(seq2seq, criterion, optimizer, device, scheduler, config)

Bases: object

Class for training the sequence-to-sequence model

calculate_overall_combined_loss(target, output, process_type, temp_losses_individual)

Calculate the loss for each dimension and the overall loss:

  • calculate the loss for the endogenous variables

  • mask out the zero values

  • calculate the losses and store them in the appropriate dictionary

Parameters:
  • target (tensor) – target tensor

  • output (tensor) – predicted tensor

  • process_type (str) – train or val

  • temp_losses_individual (dict) – dictionary for storing the losses for each dimension

Returns:
  • loss – overall loss

  • temp_losses_individual – loss for each dimension
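
The masking step described above can be sketched as follows. This is a simplified illustration, not the actual implementation: the tensor shapes, the choice of MSE as criterion, and the dictionary keyed by dimension index are assumptions.

```python
import torch

def masked_dimension_losses(target, output, criterion=None):
    """Per-dimension losses with zero-valued targets masked out (sketch).

    Assumes tensors of shape (batch, seq_len, n_dims); the real trainer
    stores results in a dictionary keyed by process type as well.
    """
    if criterion is None:
        criterion = torch.nn.MSELoss()
    losses = {}
    for d in range(target.shape[-1]):
        t, o = target[..., d], output[..., d]
        mask = t != 0  # mask out zero (e.g. padded) values
        if mask.any():
            losses[d] = criterion(o[mask], t[mask])
        else:
            losses[d] = torch.tensor(0.0)
    overall = torch.stack(list(losses.values())).mean()
    return overall, losses
```

Masking on the target keeps padded positions from contributing to the gradient, which matters for variable-length sequences.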

calculate_overall_loss(target, output, dim_weights)

Calculate the loss for each dimension and the weighted overall loss

Parameters:
  • target (tensor) – target tensor

  • output (tensor) – predicted tensor

  • dim_weights (list) – weights for each dimension

Returns:
  • loss – overall loss

  • temp_losses_individual – loss for each dimension
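
A dimension-weighted overall loss of this shape might look like the sketch below; the criterion and the exact weighting scheme are assumptions, not the library's implementation.

```python
import torch

def weighted_overall_loss(target, output, dim_weights, criterion=None):
    """Weight each dimension's loss by dim_weights and sum (sketch)."""
    if criterion is None:
        criterion = torch.nn.MSELoss()
    temp_losses_individual = {}
    overall = torch.tensor(0.0)
    for d, w in enumerate(dim_weights):
        loss_d = criterion(output[..., d], target[..., d])
        temp_losses_individual[d] = loss_d
        overall = overall + w * loss_d
    return overall, temp_losses_individual
```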

compute_loss(output, target, process_type)

Compute the loss for each batch:

  • calculate the loss for each dimension and the overall loss

  • accumulate the losses for later per-epoch averaging

Parameters:
  • output (tensor) – predicted tensor

  • target (tensor) – target tensor

  • process_type (str) – train or val

Returns:

loss – overall loss

count_parameters()

Count the number of trainable parameters
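
Counting trainable parameters is typically done with the standard PyTorch idiom below (shown here as a free function; in the trainer it operates on the wrapped seq2seq model).

```python
import torch.nn as nn

def count_parameters(model: nn.Module) -> int:
    # Sum element counts over all parameters that require gradients.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
```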

early_stopping_check(train_loss, val_loss, epoch)

Check if early stopping should be applied

Parameters:
  • train_loss (float) – training loss

  • val_loss (float) – validation loss

  • epoch (int) – current epoch

Returns:

should_stop (bool) – True if early stopping should be applied
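
A minimal patience-based early-stopping check can be sketched as below. The patience and min-delta mechanism is a common convention and an assumption here; the actual criterion (which also receives the training loss) may differ.

```python
class EarlyStopping:
    """Stop when validation loss has not improved for `patience` epochs."""

    def __init__(self, patience: int = 10, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_val = float("inf")
        self.counter = 0

    def should_stop(self, val_loss: float) -> bool:
        if val_loss < self.best_val - self.min_delta:
            # Improvement: record new best and reset the counter.
            self.best_val = val_loss
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience
```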

initialize_training()

Initialize the parameters for training the model:

  • create directories for saving the model and parameters

  • move the model to the device

  • initialize early stopping and teacher forcing

plot_training_params()

Plot the learning rate and losses

save_training_results_and_plots(epoch=None)

Save the model and parameters and plot the results

Parameters:

epoch (int) – current epoch

train_epoch(epoch, train_loader, available_seq)

Train the model for one epoch by looping over the batches:

  • zero the hidden states and gradients

  • forward pass

  • compute the loss and backpropagate

  • update the parameters and learning rate

  • calculate the loss for each batch and add it to the loss_trace

Parameters:
  • epoch (int) – current epoch

  • train_loader (torch.utils.data.DataLoader) – training data loader

  • available_seq (int) – available sequence length
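
The steps above follow the standard PyTorch training loop. A generic sketch (omitting the trainer's hidden-state resets, teacher forcing, scheduler, and loss_trace bookkeeping, which are specific to Seq2SeqTrainer):

```python
import torch

def train_one_epoch(model, criterion, optimizer, train_loader, device="cpu"):
    """Generic one-epoch training loop (illustrative sketch)."""
    model.train()
    epoch_loss = 0.0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()               # zero the gradients
        outputs = model(inputs)             # forward pass
        loss = criterion(outputs, targets)  # compute the loss
        loss.backward()                     # backpropagation
        optimizer.step()                    # update parameters
        epoch_loss += loss.item()
    return epoch_loss / max(len(train_loader), 1)
```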

train_model(train_loader, val_loader, val_lengths, trial=None)

The main function that controls the training process. It does the following:

  • initialize the training

  • train the model

  • validate the model

  • calculate the loss for each epoch and add it to the loss_trace

  • print the most recent losses and scores every 50 epochs

  • check for early stopping

  • update the training parameters

  • save the training results and plots

Parameters:
  • train_loader (torch.utils.data.DataLoader) – training data loader

  • val_loader (torch.utils.data.DataLoader) – validation data loader

  • val_lengths (list) – list of lengths of the validation data

  • trial (optuna.trial) – optuna trial

update_training_params(epoch)

Update the learning rate and add it to the learning_rate_dict

Parameters:

epoch (int) – current epoch
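
Updating the learning rate and recording it per epoch can be sketched as follows; the scheduler type and the shape of learning_rate_dict are assumptions for illustration.

```python
import torch

def update_training_params(optimizer, scheduler, learning_rate_dict, epoch):
    """Step the scheduler and record the current learning rate (sketch)."""
    scheduler.step()
    # param_groups[0] holds the LR for the (single) parameter group.
    learning_rate_dict[epoch] = optimizer.param_groups[0]["lr"]
    return learning_rate_dict
```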

validation_epoch(val_loader, val_lengths, available_seq)

Validate the model for one epoch by looping over the batches:

  • turn off the gradients

  • zero the hidden states

  • forward pass and compute the loss

  • calculate the loss for each batch and add it to the loss_trace

Parameters:
  • val_loader (torch.utils.data.DataLoader) – validation data loader

  • val_lengths (list) – list of lengths of the validation data

  • available_seq (int) – available sequence length
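
A gradient-free evaluation loop mirroring these steps might look like the sketch below (omitting the trainer's hidden-state handling and per-length bookkeeping for val_lengths, which are assumptions here):

```python
import torch

def validate_one_epoch(model, criterion, val_loader, device="cpu"):
    """Generic one-epoch validation loop (illustrative sketch)."""
    model.eval()
    epoch_loss = 0.0
    with torch.no_grad():  # turn off the gradients
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)  # forward pass
            epoch_loss += criterion(outputs, targets).item()
    return epoch_loss / max(len(val_loader), 1)
```

`model.eval()` and `torch.no_grad()` serve different purposes: the former switches layers such as dropout to inference behavior, the latter disables gradient tracking.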

Module contents