Arcana Training
Submodules
arcana.training.train_model module
This module contains the training logic for the sequence-to-sequence model
- class arcana.training.train_model.Seq2SeqTrainer(seq2seq, criterion, optimizer, device, scheduler, config)
Bases: object
Class for training the sequence to sequence model
- calculate_overall_combined_loss(target, output, process_type, temp_losses_individual)
Calculate the loss for each dimension and the overall loss:
- calculate the loss for the endogenous variables
- mask out the zero values
- calculate the losses and store them in the appropriate dictionary
- Parameters:
target (tensor) – target tensor
output (tensor) – predicted tensor
process_type (str) – train or val
temp_losses_individual (dict) – dictionary for storing the losses for each dimension
- Returns:
loss – overall loss
temp_losses_individual – loss for each dimension
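The per-dimension masking described above can be sketched as follows. This is a minimal illustration, not the actual implementation: the helper name `combined_loss` and the default MSE criterion are assumptions, and the real method takes its criterion from the trainer.

```python
import torch
import torch.nn as nn

def combined_loss(target, output, criterion=nn.MSELoss(), temp_losses=None):
    """Sketch: per-dimension loss with zero-valued targets masked out."""
    temp_losses = {} if temp_losses is None else temp_losses
    total = torch.tensor(0.0)
    for dim in range(target.shape[-1]):
        t, o = target[..., dim], output[..., dim]
        mask = t != 0                        # mask out the zero (padded) values
        dim_loss = criterion(o[mask], t[mask]) if mask.any() else torch.tensor(0.0)
        temp_losses.setdefault(dim, []).append(dim_loss.item())  # store per dimension
        total = total + dim_loss             # accumulate the overall loss
    return total, temp_losses
```

Masking before the criterion call keeps padded time steps from diluting the loss of shorter sequences.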
- calculate_overall_loss(target, output, dim_weights)
Calculate the loss for each dimension and the overall loss
- Parameters:
target (tensor) – target tensor
output (tensor) – predicted tensor
dim_weights (list) – weights for each dimension
- Returns:
loss – overall loss
temp_losses_individual – loss for each dimension
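A dimension-weighted version can be sketched like this; the function name `overall_loss` and the MSE criterion are assumptions for illustration:

```python
import torch
import torch.nn as nn

def overall_loss(target, output, dim_weights, criterion=nn.MSELoss()):
    """Sketch: weighted sum of per-dimension losses."""
    temp_losses_individual = {}
    loss = torch.tensor(0.0)
    for dim, w in enumerate(dim_weights):
        dim_loss = criterion(output[..., dim], target[..., dim])
        temp_losses_individual[dim] = dim_loss.item()  # loss for this dimension
        loss = loss + w * dim_loss                     # weighted overall loss
    return loss, temp_losses_individual
```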
- compute_loss(output, target, process_type)
Compute the loss for each batch:
- calculate the loss for each dimension and the overall loss
- add the losses for later averaging over each epoch
- Parameters:
output (tensor) – predicted tensor
target (tensor) – target tensor
process_type (str) – train or val
- Returns:
loss – overall loss
- count_parameters()
Count the number of trainable parameters
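Counting trainable parameters is typically a one-liner in PyTorch; a sketch under the assumption that the trainer holds the model as an attribute:

```python
import torch.nn as nn

def count_parameters(model: nn.Module) -> int:
    """Count parameters that will be updated by the optimizer."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
```

For example, `nn.Linear(10, 5)` has a 10x5 weight matrix plus 5 biases, i.e. 55 trainable parameters.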
- early_stopping_check(train_loss, val_loss, epoch)
Check if early stopping should be applied
- Parameters:
train_loss (float) – training loss
val_loss (float) – validation loss
epoch (int) – current epoch
- Returns:
should_stop (bool) – True if early stopping should be applied
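A common form of this check is patience-based: stop once the validation loss has not improved for a fixed number of epochs. The sketch below is a simplification (it watches only the validation loss, while the actual method also receives the training loss); the `patience` and `min_delta` knobs are assumptions:

```python
class EarlyStopping:
    """Sketch: patience-based early stopping on validation loss."""

    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_val_loss = float("inf")
        self.counter = 0

    def should_stop(self, val_loss):
        if val_loss < self.best_val_loss - self.min_delta:
            self.best_val_loss = val_loss  # improvement: remember it, reset counter
            self.counter = 0
        else:
            self.counter += 1              # no improvement this epoch
        return self.counter >= self.patience
```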
- initialize_training()
Initialize the parameters for training the model:
- create directories for saving the model and parameters
- move the model to the device
- initialize early stopping and teacher forcing
- plot_training_params()
Plot the learning rate and losses
- save_training_results_and_plots(epoch=None)
Save the model and parameters and plot the results
- Parameters:
epoch (int) – current epoch
- train_epoch(epoch, train_loader, available_seq)
Train the model for one epoch by looping over the batches:
- zero the hidden states and gradients
- forward pass
- compute the loss and backpropagate
- update the parameters and learning rate
- calculate the loss for each batch and add it to the loss_trace
- Parameters:
epoch (int) – current epoch
train_loader (torch.utils.data.DataLoader) – training data loader
available_seq (int) – available sequence length
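The batch loop above follows the standard PyTorch training pattern. A minimal sketch, assuming a generic model and loader (the real method additionally handles hidden states, teacher forcing, and the available sequence length):

```python
import torch

def train_epoch(model, criterion, optimizer, train_loader, device="cpu"):
    """Sketch of one training epoch: forward, loss, backward, step."""
    model.train()
    epoch_loss = 0.0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()               # zero gradients from the last batch
        outputs = model(inputs)             # forward pass
        loss = criterion(outputs, targets)  # compute the loss
        loss.backward()                     # backpropagation
        optimizer.step()                    # update the parameters
        epoch_loss += loss.item()           # add the batch loss to the trace
    return epoch_loss / max(len(train_loader), 1)
```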
- train_model(train_loader, val_loader, val_lengths, trial=None)
The main function controlling the training process:
- initialize the training
- train the model
- validate the model
- calculate the loss for each epoch and add it to the loss_trace
- print the latest losses and scores every 50 epochs
- apply early stopping
- update the training parameters
- save the training results and plots
- Parameters:
train_loader (torch.utils.data.DataLoader) – training data loader
val_loader (torch.utils.data.DataLoader) – validation data loader
val_lengths (list) – list of lengths of the validation data
trial (optuna.trial) – Optuna trial for hyperparameter optimization (optional)
- update_training_params(epoch)
Update the learning rate and add it to the learning_rate_dict
- Parameters:
epoch (int) – current epoch
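Stepping the scheduler and recording the learning rate per epoch can be sketched as follows; the specific scheduler type and the `learning_rate_dict` shape are assumptions for illustration:

```python
import torch

def update_training_params(optimizer, scheduler, learning_rate_dict, epoch):
    """Sketch: step the LR scheduler and record the current learning rate."""
    scheduler.step()
    # param_groups[0]["lr"] reflects the rate the next epoch will train with
    learning_rate_dict[epoch] = optimizer.param_groups[0]["lr"]
    return learning_rate_dict
```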
- validation_epoch(val_loader, val_lengths, available_seq)
Validate the model for one epoch by looping over the batches:
- turn off the gradients
- zero the hidden states
- forward pass and compute the loss
- calculate the loss for each batch and add it to the loss_trace
- Parameters:
val_loader (torch.utils.data.DataLoader) – validation data loader
val_lengths (list) – list of lengths of the validation data
available_seq (int) – available sequence length
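The validation loop mirrors the training loop with gradients disabled. A minimal sketch, again assuming a generic model and loader (the real method also manages hidden states and the available sequence length):

```python
import torch

def validation_epoch(model, criterion, val_loader, device="cpu"):
    """Sketch of one validation epoch: gradients off, forward + loss only."""
    model.eval()
    epoch_loss = 0.0
    with torch.no_grad():                   # turn off the gradients
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)         # forward pass
            epoch_loss += criterion(outputs, targets).item()  # batch loss
    return epoch_loss / max(len(val_loader), 1)
```

`model.eval()` plus `torch.no_grad()` matters here: the former switches layers such as dropout to inference behavior, the latter skips building the autograd graph.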