# Trying to gain a deep understanding of how `torch.nn.LSTM` works by visualizing it


feral shard

I am trying to visualize PyTorch's LSTM module correctly, but my knowledge is imperfect, so I need help.
This is my current visualization. Is it correct? If not, please help me make it perfect.
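One way to check a diagram against what `torch.nn.LSTM` actually computes is to write a single layer's time step out by hand. Below is a minimal sketch. It assumes one layer with small, arbitrary sizes (`INPUT_SIZE = 3`, `HIDDEN_SIZE = 4`) and uses a hypothetical helper, `manual_lstm_step`. It follows the gate equations and the input/forget/cell/output stacking order documented for `torch.nn.LSTM`, then compares the result against the module itself:

```python
import torch

torch.manual_seed(0)

INPUT_SIZE, HIDDEN_SIZE = 3, 4
lstm = torch.nn.LSTM(INPUT_SIZE, HIDDEN_SIZE, num_layers=1, batch_first=True)


def manual_lstm_step(x_t, h, c, lstm):
    """Hypothetical helper: one time step of layer 0, written out gate by gate."""
    # PyTorch stacks the four gate weights along dim 0 as [input, forget, cell (g), output].
    gates = (x_t @ lstm.weight_ih_l0.T + lstm.bias_ih_l0
             + h @ lstm.weight_hh_l0.T + lstm.bias_hh_l0)
    i, f, g, o = gates.chunk(4, dim=-1)
    i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
    g = torch.tanh(g)
    c_next = f * c + i * g           # cell state: keep part of the old one, add new content
    h_next = o * torch.tanh(c_next)  # hidden state: gated view of the cell state
    return h_next, c_next


x = torch.rand(1, 5, INPUT_SIZE)  # (N=1, L=5, Hin=3)
h = torch.zeros(1, HIDDEN_SIZE)   # h0 (nn.LSTM also defaults to zeros)
c = torch.zeros(1, HIDDEN_SIZE)   # c0

with torch.no_grad():
    manual_outputs = []
    for t in range(x.size(1)):
        h, c = manual_lstm_step(x[:, t, :], h, c, lstm)
        manual_outputs.append(h)
    manual_out = torch.stack(manual_outputs, dim=1)  # (1, 5, 4)
    ref_out, (ref_h, ref_c) = lstm(x)

print(torch.allclose(manual_out, ref_out, atol=1e-5))  # True
```

With `num_layers > 1`, layer `k` repeats the same step using `weight_ih_lk`, `weight_hh_lk`, and its own biases, with the hidden state of layer `k - 1` as its input instead of `x_t`.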

spring pike

Maybe you can begin by making it from a univariate perspective first.
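In the univariate case, each time step feeds a single number into the LSTM, which makes the weight layout easy to read. This is a sketch, with `hidden_size=2` chosen arbitrarily to keep the shapes small:

```python
import torch

# Univariate case: one feature per time step (input_size=1).
lstm = torch.nn.LSTM(input_size=1, hidden_size=2, num_layers=1, batch_first=True)

# The four gates (input, forget, cell, output) are stacked along dim 0,
# so every weight matrix and bias has 4 * hidden_size = 8 rows.
for name, param in lstm.named_parameters():
    print(name, tuple(param.shape))
# weight_ih_l0 (8, 1)
# weight_hh_l0 (8, 2)
# bias_ih_l0 (8,)
# bias_hh_l0 (8,)

x = torch.tensor([[[0.1], [0.2], [0.3]]])  # (N=1, L=3, Hin=1): the series 0.1, 0.2, 0.3
out, (hn, cn) = lstm(x)
print(out.shape, hn.shape, cn.shape)
# torch.Size([1, 3, 2]) torch.Size([1, 1, 2]) torch.Size([1, 1, 2])
```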

feral shard

OK, I wrote some sample code, giving each variable of the network a different value, to serve as a reference for what I am trying to visualize.

The code:

```python
import torch


class LSTMModel(torch.nn.Module):
    def __init__(self, input_size: int, hidden_size: int, num_layers: int, output_size: int) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.output_size = output_size
        self.lstm = torch.nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Initial states are (num_layers, N, hidden_size), even with batch_first=True
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device, dtype=x.dtype)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device, dtype=x.dtype)
        # out: (N, L, hidden_size); hn, cn: (num_layers, N, hidden_size)
        out, (hn, cn) = self.lstm(x, (h0, c0))
        # Keep only the last output_size features of each time step's hidden state
        out = out[:, :, -self.output_size:]
        return out


BATCH_SIZE = 2
SEQUENCE_LENGTH = 8
INPUT_SIZE = 9
HIDDEN_SIZE = 16
OUTPUT_SIZE = 5
NUM_LAYERS = 4

input_sample = torch.rand(BATCH_SIZE, SEQUENCE_LENGTH, INPUT_SIZE)
model = LSTMModel(INPUT_SIZE, HIDDEN_SIZE, NUM_LAYERS, OUTPUT_SIZE)
output = model(input_sample)

# input shape must be (N, L, Hin) => (2, 8, 9)
print(input_sample.shape)
# the LSTM's own output is (N, L, D * Hout) = (2, 8, 16), with D = 1 and Hout = HIDDEN_SIZE;
# slicing the last OUTPUT_SIZE features gives (2, 8, 5)
print(output.shape)
```

The output:

```
❯ python src/model.py
torch.Size([2, 8, 9])
torch.Size([2, 8, 5])
```
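For reference, these are the shapes `torch.nn.LSTM` itself returns for the same sizes, before the slicing in `forward`. This small sketch reuses the constants from the sample code and leaves out the model class. It also shows that the last time step of `out` equals the top layer's entry in `hn`, which is one way to connect the final-state nodes in a diagram to the output sequence:

```python
import torch

BATCH_SIZE, SEQUENCE_LENGTH, INPUT_SIZE = 2, 8, 9
HIDDEN_SIZE, NUM_LAYERS, OUTPUT_SIZE = 16, 4, 5

lstm = torch.nn.LSTM(INPUT_SIZE, HIDDEN_SIZE, NUM_LAYERS, batch_first=True)
x = torch.rand(BATCH_SIZE, SEQUENCE_LENGTH, INPUT_SIZE)
out, (hn, cn) = lstm(x)

print(out.shape)  # torch.Size([2, 8, 16]): top layer's hidden state at every time step
print(hn.shape)   # torch.Size([4, 2, 16]): each layer's hidden state at the last time step
print(cn.shape)   # torch.Size([4, 2, 16]): each layer's cell state at the last time step

# The last time step of `out` is the top layer's final hidden state.
print(torch.allclose(out[:, -1, :], hn[-1]))  # True

# The sample model then keeps only the last OUTPUT_SIZE of those 16 features.
print(out[:, :, -OUTPUT_SIZE:].shape)  # torch.Size([2, 8, 5])
```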

Then I made a new visualization based on the code. I also added the initial and final nodes for the hidden and cell states.
Note: the sequence length of 8 is not visible from this view.

Is this the right visualization for the code?
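Because the diagram shows a single time step, one way to make the hidden sequence dimension concrete is to run the same LSTM one step at a time, passing each step's `(h, c)` into the next. This sketch assumes the sizes from the sample code, without the output slicing. The step-by-step result should match a single call on the whole sequence, which reflects that the same weights are reused at all 8 steps:

```python
import torch

torch.manual_seed(0)
BATCH_SIZE, SEQUENCE_LENGTH, INPUT_SIZE = 2, 8, 9
HIDDEN_SIZE, NUM_LAYERS = 16, 4

lstm = torch.nn.LSTM(INPUT_SIZE, HIDDEN_SIZE, NUM_LAYERS, batch_first=True)
x = torch.rand(BATCH_SIZE, SEQUENCE_LENGTH, INPUT_SIZE)
h = torch.zeros(NUM_LAYERS, BATCH_SIZE, HIDDEN_SIZE)  # h0: one slice per layer
c = torch.zeros(NUM_LAYERS, BATCH_SIZE, HIDDEN_SIZE)  # c0: one slice per layer

with torch.no_grad():
    full_out, (full_h, full_c) = lstm(x, (h, c))  # whole sequence in one call

    step_outs = []
    for t in range(SEQUENCE_LENGTH):
        # Feed one time step (L=1); (h, c) from step t become the initial states for step t + 1.
        out_t, (h, c) = lstm(x[:, t:t + 1, :], (h, c))
        step_outs.append(out_t)
    step_out = torch.cat(step_outs, dim=1)  # (2, 8, 16)

print(torch.allclose(step_out, full_out, atol=1e-5))  # True
print(torch.allclose(h, full_h, atol=1e-5), torch.allclose(c, full_c, atol=1e-5))  # True True
```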

spring pike

Hmm, maybe I am focusing on the wrong thing, but what I meant was to begin with something like this:


So go from an easy example to a complex one: start with only 1 feature, in an autoregressive way.
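A minimal univariate, autoregressive sketch might look like the following. It assumes a hypothetical `UnivariateLSTM` class with `input_size=1` and an `nn.Linear` head that maps each hidden state to the next value (this replaces the feature slicing in the earlier code). `generate` is also a hypothetical helper. The model is untrained, so the generated values are meaningless, and only the data flow matters here:

```python
import torch


class UnivariateLSTM(torch.nn.Module):  # hypothetical name, for illustration only
    def __init__(self, hidden_size: int = 8) -> None:
        super().__init__()
        self.lstm = torch.nn.LSTM(input_size=1, hidden_size=hidden_size, batch_first=True)
        self.head = torch.nn.Linear(hidden_size, 1)  # hidden state -> predicted next value

    def forward(self, x: torch.Tensor, state=None):
        out, state = self.lstm(x, state)
        return self.head(out), state  # predictions: (N, L, 1), one per time step


@torch.no_grad()
def generate(model: UnivariateLSTM, history: torch.Tensor, steps: int) -> torch.Tensor:
    """Hypothetical helper: read `history`, then feed each prediction back as the next input."""
    pred, state = model(history)   # run over the whole observed series
    next_value = pred[:, -1:, :]   # prediction for the step right after the history
    generated = [next_value]
    for _ in range(steps - 1):
        next_value, state = model(next_value, state)  # one step, carrying (h, c)
        generated.append(next_value)
    return torch.cat(generated, dim=1)  # (N, steps, 1)


torch.manual_seed(0)
model = UnivariateLSTM()
history = torch.sin(torch.linspace(0, 3, 10)).reshape(1, 10, 1)  # 1 series, 10 steps, 1 feature
future = generate(model, history, steps=5)
print(future.shape)  # torch.Size([1, 5, 1])
```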