Arcana Models: Decoders
Submodules
arcana.models.decoders.additive_decoder module
Additive decoder module for the ARCANA project.
- class arcana.models.decoders.additive_decoder.AdditiveDecoder(*args: Any, **kwargs: Any)
Bases:
BaseDecoder
Additive decoder module
- forward(x_tensor, hidden_state, cell_state, encoder_outputs)
Forward pass for the additive decoder module. The forward pass is implemented as follows (see the sketch after the Returns block):
1. get the attention scores
2. create the context vector
3. concatenate the context vector with the input tensor
4. pass the concatenated tensor through the LSTM layer
5. pass the LSTM output through the fc layer
6. pass the fc layer output through the leaky ReLU layer
7. pass the leaky ReLU output through the output dropout layer
8. pass the output dropout layer output through the fc layers for quantiles 1, 2 and 3
9. concatenate the quantile 1, 2 and 3 predictions
- Parameters:
x_tensor (torch.Tensor) – input tensor
hidden_state (torch.Tensor) – hidden state
cell_state (torch.Tensor) – cell state
encoder_outputs (torch.Tensor) – encoder outputs
- Returns:
predictions (tuple) – tuple of quantile predictions
hidden_out (torch.Tensor) – hidden state
cell_out (torch.Tensor) – cell state
attention_scores (torch.Tensor) – attention scores
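For orientation, here is a minimal PyTorch sketch of the step sequence above. The constructor arguments, the layer names (`attn`, `v`, `fc_q1`..`fc_q3`), and the single-step input shape are illustrative assumptions, not ARCANA's exact implementation.

```python
import torch
import torch.nn as nn

class AdditiveDecoder(nn.Module):
    """Minimal sketch of one additive-decoder step; names and sizes are assumptions."""

    def __init__(self, input_size, hidden_size, output_size, num_layers=1, dropout=0.2):
        super().__init__()
        # additive attention parameters (see the Attention class below)
        self.attn = nn.Linear(hidden_size, hidden_size)
        self.v = nn.Linear(hidden_size, 1, bias=False)
        self.lstm = nn.LSTM(input_size + hidden_size, hidden_size,
                            num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, hidden_size)
        self.leaky_relu = nn.LeakyReLU()
        self.output_dropout = nn.Dropout(dropout)
        # one fully connected head per quantile
        self.fc_q1 = nn.Linear(hidden_size, output_size)
        self.fc_q2 = nn.Linear(hidden_size, output_size)
        self.fc_q3 = nn.Linear(hidden_size, output_size)

    def forward(self, x_tensor, hidden_state, cell_state, encoder_outputs):
        # 1. attention scores from the last-layer hidden state: (batch, seq_length)
        energy = torch.tanh(self.attn(hidden_state[-1]).unsqueeze(1) + encoder_outputs)
        attention_scores = torch.softmax(self.v(energy).squeeze(-1), dim=1)
        # 2. context vector: attention-weighted sum of encoder outputs, (batch, 1, hidden)
        context = torch.bmm(attention_scores.unsqueeze(1), encoder_outputs)
        # 3-4. concatenate context with the input step (seq_length 1 assumed) and run the LSTM
        lstm_out, (hidden_out, cell_out) = self.lstm(
            torch.cat((x_tensor, context), dim=2), (hidden_state, cell_state))
        # 5-7. fc -> leaky ReLU -> dropout
        out = self.output_dropout(self.leaky_relu(self.fc(lstm_out)))
        # 8-9. one head per quantile, collected into a tuple
        predictions = (self.fc_q1(out), self.fc_q2(out), self.fc_q3(out))
        return predictions, hidden_out, cell_out, attention_scores
```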
- class arcana.models.decoders.additive_decoder.Attention(*args: Any, **kwargs: Any)
Bases:
Module
Additive attention module
- forward(hidden, encoder_outputs)
Forward pass for the additive attention module. The forward pass is implemented as follows (see the sketch after the Returns block):
1. transform the hidden state of the decoder to the same size as the hidden state of the encoder
2. add the transformed decoder hidden state to the encoder hidden states
3. apply a tanh activation to the sum
4. transform the tanh output to a scalar score per timestep
5. apply a softmax over the scores
- Parameters:
hidden (torch.Tensor) – hidden state of the decoder. The shape is (batch_size, hidden_size)
encoder_outputs (torch.Tensor) – hidden state of the encoder. The shape is (batch_size, seq_length, hidden_size)
- Returns:
attention_scores (torch.Tensor) – attention scores. The shape is (batch_size, seq_length)
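A minimal sketch of this additive (Bahdanau-style) scoring, with shape comments matching the Parameters and Returns above; the layer names `attn` and `v` are illustrative assumptions.

```python
import torch
import torch.nn as nn

class Attention(nn.Module):
    """Minimal sketch of additive (Bahdanau-style) attention."""

    def __init__(self, hidden_size):
        super().__init__()
        self.attn = nn.Linear(hidden_size, hidden_size)  # transforms the decoder hidden state
        self.v = nn.Linear(hidden_size, 1, bias=False)   # maps each timestep to a scalar score

    def forward(self, hidden, encoder_outputs):
        # transform the decoder state and broadcast it over the sequence:
        # (batch_size, hidden_size) -> (batch_size, 1, hidden_size)
        decoder_term = self.attn(hidden).unsqueeze(1)
        # add to the encoder states and apply tanh: (batch_size, seq_length, hidden_size)
        energy = torch.tanh(decoder_term + encoder_outputs)
        # reduce each timestep to a scalar score: (batch_size, seq_length)
        scores = self.v(energy).squeeze(-1)
        # softmax over the sequence dimension yields the attention weights
        return torch.softmax(scores, dim=1)
```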
arcana.models.decoders.base_decoder module
Base decoder class
- class arcana.models.decoders.base_decoder.BaseDecoder(*args: Any, **kwargs: Any)
Bases:
Module
Base decoder module
- forward(x_tensor, hidden_state, cell_state, encoder_outputs)
Forward pass to be implemented by subclass
- arcana.models.decoders.base_decoder.initialize_weights(layer)
Initialize weights for the layer
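For illustration, a hedged sketch of what such a helper typically looks like; the Xavier/zero scheme below is an assumption, not necessarily what ARCANA uses.

```python
import torch.nn as nn

def initialize_weights(layer):
    """Apply a standard initialization to linear and LSTM layers (assumed scheme)."""
    if isinstance(layer, nn.Linear):
        nn.init.xavier_uniform_(layer.weight)
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)
    elif isinstance(layer, nn.LSTM):
        for name, param in layer.named_parameters():
            if "weight" in name:
                nn.init.xavier_uniform_(param)
            elif "bias" in name:
                nn.init.zeros_(param)

# typical usage: decoder.apply(initialize_weights)
```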
arcana.models.decoders.multihead_decoder module
Multihead decoder module for the ARCANA project.
- class arcana.models.decoders.multihead_decoder.MultiheadDecoder(*args: Any, **kwargs: Any)
Bases:
BaseDecoder
Multihead decoder module
- forward(x_tensor, hidden_state, cell_state, encoder_outputs)
Forward pass for the multihead decoder module. The forward pass is implemented as follows (see the sketch after the Returns block):
1. get the attention scores
2. concatenate the attention scores with the input tensor
3. pass the concatenated tensor through the LSTM layer
4. pass the LSTM output through the fc layer
5. pass the fc layer output through the leaky ReLU layer
6. pass the leaky ReLU output through the output dropout layer
7. pass the output dropout layer output through the fc layers for quantiles 1, 2 and 3
8. concatenate the quantile 1, 2 and 3 predictions
- Parameters:
x_tensor (torch.Tensor) – input tensor (batch_size, seq_length, input_size)
hidden_state (torch.Tensor) – hidden state (num_layers, batch_size, hidden_size)
cell_state (torch.Tensor) – cell state (num_layers, batch_size, hidden_size)
encoder_outputs (torch.Tensor) – encoder outputs (batch_size, seq_length, hidden_size)
- Returns:
predictions (tuple) – tuple of quantile predictions
hidden_out (torch.Tensor) – hidden state
cell_out (torch.Tensor) – cell state
attention_output_weights_probs_decoder (torch.Tensor) – attention scores
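As referenced in the step list above, here is a minimal sketch of the multihead variant. It uses torch.nn.MultiheadAttention as a stand-in for the project's attention block and concatenates the attention output (rather than the raw weights) with the input step, one common reading of the step list; the layer names, head count, and single-step input shape are illustrative assumptions.

```python
import torch
import torch.nn as nn

class MultiheadDecoder(nn.Module):
    """Minimal sketch of one multihead-decoder step; names and sizes are assumptions."""

    def __init__(self, input_size, hidden_size, output_size,
                 num_layers=1, num_heads=4, dropout=0.2):
        super().__init__()
        # hidden_size must be divisible by num_heads
        self.multihead_attention = nn.MultiheadAttention(hidden_size, num_heads,
                                                         batch_first=True)
        self.lstm = nn.LSTM(input_size + hidden_size, hidden_size,
                            num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, hidden_size)
        self.leaky_relu = nn.LeakyReLU()
        self.output_dropout = nn.Dropout(dropout)
        # one fully connected head per quantile
        self.fc_q1 = nn.Linear(hidden_size, output_size)
        self.fc_q2 = nn.Linear(hidden_size, output_size)
        self.fc_q3 = nn.Linear(hidden_size, output_size)

    def forward(self, x_tensor, hidden_state, cell_state, encoder_outputs):
        # 1. the last-layer hidden state queries the encoder outputs
        query = hidden_state[-1].unsqueeze(1)  # (batch, 1, hidden)
        attn_output, attn_weights = self.multihead_attention(
            query, encoder_outputs, encoder_outputs)
        # 2-3. concatenate the attention output with the input step (seq_length 1
        # assumed) and run the LSTM
        lstm_out, (hidden_out, cell_out) = self.lstm(
            torch.cat((x_tensor, attn_output), dim=2), (hidden_state, cell_state))
        # 4-6. fc -> leaky ReLU -> dropout
        out = self.output_dropout(self.leaky_relu(self.fc(lstm_out)))
        # 7-8. one head per quantile, collected into a tuple
        predictions = (self.fc_q1(out), self.fc_q2(out), self.fc_q3(out))
        return predictions, hidden_out, cell_out, attn_weights
```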