In Part 1, we explored the mathematics of gated recurrence—the forget gate, input gate, output gate, and the additive cell state update that creates a gradient highway through time. Today, we take that math and translate it into a pure PyTorch implementation. We will build from individual cells to complete sequence models.
The LSTM Cell
Our LSTMCell class implements the four key computations. Each gate is a linear projection of the concatenated input $x_t$ and previous hidden state $h_{t-1}$, followed by a nonlinearity:
# Compute gates
i = torch.sigmoid(self.Wxi(x) + self.Whi(h_prev))  # Input gate
f = torch.sigmoid(self.Wxf(x) + self.Whf(h_prev))  # Forget gate
g = torch.tanh(self.Wxc(x) + self.Whc(h_prev))     # Cell candidate
o = torch.sigmoid(self.Wxo(x) + self.Who(h_prev))  # Output gate
# Cell state update (the gradient highway)
c_new = f * c_prev + i * g
# Hidden state update
h_new = o * torch.tanh(c_new)
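Putting the pieces together, a minimal cell module might look like the following sketch. The per-gate `Wx*`/`Wh*` linear-layer layout is one reasonable choice (fused weight matrices are another); treat the exact structure as illustrative rather than the repo's definitive implementation.

```python
import torch
import torch.nn as nn

class LSTMCell(nn.Module):
    """Minimal LSTM cell: one step of gated recurrence (sketch)."""
    def __init__(self, input_size, hidden_size):
        super().__init__()
        # One linear pair per gate: input-to-hidden and hidden-to-hidden
        for gate in ("i", "f", "c", "o"):
            setattr(self, f"Wx{gate}", nn.Linear(input_size, hidden_size))
            setattr(self, f"Wh{gate}", nn.Linear(hidden_size, hidden_size))

    def forward(self, x, state):
        h_prev, c_prev = state
        i = torch.sigmoid(self.Wxi(x) + self.Whi(h_prev))  # input gate
        f = torch.sigmoid(self.Wxf(x) + self.Whf(h_prev))  # forget gate
        g = torch.tanh(self.Wxc(x) + self.Whc(h_prev))     # cell candidate
        o = torch.sigmoid(self.Wxo(x) + self.Who(h_prev))  # output gate
        c_new = f * c_prev + i * g        # additive cell state update
        h_new = o * torch.tanh(c_new)     # gated hidden state
        return h_new, c_new
```

Because the cell takes and returns an explicit `(h, c)` pair, unrolling it over a sequence is just a Python loop over time steps.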
Weight Initialization
We use Xavier initialization for all weights, with one critical modification: forget gate biases are initialized to 1.0. This biases the network toward preserving information rather than discarding it at the start of training, which stabilizes early learning. The trick goes back to Gers et al. (2000) and was empirically validated at scale by Jozefowicz et al. (2015); in practice it makes convergence far more consistent.
# Critical: give the forget gate an effective bias of 1.0
nn.init.ones_(self.Wxf.bias)   # push the gate toward "remember" at init
nn.init.zeros_(self.Whf.bias)  # zero the second bias so the pair doesn't sum to 2.0
Multi-Layer LSTMs
Stacking LSTM cells creates hierarchical temporal processing:
- Lower layers capture short-term patterns (local token interactions)
- Higher layers learn long-term structure (global sentence-level meaning)
- Dropout between layers prevents overfitting
Each layer's hidden state $h_t^{(l)}$ becomes the input to layer $l+1$. We apply dropout only between layers, never at the final output, matching the convention used by PyTorch's built-in LSTM.
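The stacking pattern can be sketched as follows; for brevity this uses PyTorch's built-in `nn.LSTMCell` (our own cell would drop in the same way), and the layout is an illustrative assumption rather than the repo's exact code:

```python
import torch
import torch.nn as nn

class StackedLSTM(nn.Module):
    """Stack of LSTM cells with dropout between layers (sketch)."""
    def __init__(self, input_size, hidden_size, num_layers, dropout=0.3):
        super().__init__()
        self.cells = nn.ModuleList(
            nn.LSTMCell(input_size if l == 0 else hidden_size, hidden_size)
            for l in range(num_layers)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):  # x: (batch, seq_len, input_size)
        batch, seq_len, _ = x.shape
        states = [(torch.zeros(batch, c.hidden_size),
                   torch.zeros(batch, c.hidden_size)) for c in self.cells]
        outputs = []
        for t in range(seq_len):
            inp = x[:, t]
            for l, cell in enumerate(self.cells):
                h, c = cell(inp, states[l])
                states[l] = (h, c)
                # Dropout between layers only, never on the final layer's output
                inp = self.dropout(h) if l < len(self.cells) - 1 else h
            outputs.append(inp)
        return torch.stack(outputs, dim=1), states
```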
Architecture Variants
LSTM Classifier (Many-to-One)
For sequence classification tasks:
- Processes the entire sequence through multi-layer LSTM
- Concatenates final hidden states from all layers
- Projects to class logits via a fully-connected layer
- Optional bidirectional support doubles the hidden representation by adding right-to-left context
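The many-to-one pattern can be sketched like this; for compactness the sketch uses `nn.LSTM` in place of our hand-rolled stack, and the class layout is an assumption:

```python
import torch
import torch.nn as nn

class LSTMClassifier(nn.Module):
    """Many-to-one: whole sequence in, class logits out (sketch)."""
    def __init__(self, input_size, hidden_size, num_layers, num_classes,
                 bidirectional=False):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
                            batch_first=True, bidirectional=bidirectional)
        num_dirs = 2 if bidirectional else 1
        # Concatenate the final hidden states from all layers (and directions)
        self.fc = nn.Linear(hidden_size * num_layers * num_dirs, num_classes)

    def forward(self, x):              # x: (batch, seq_len, input_size)
        _, (h_n, _) = self.lstm(x)     # h_n: (layers * dirs, batch, hidden)
        h_cat = h_n.transpose(0, 1).reshape(x.size(0), -1)
        return self.fc(h_cat)          # (batch, num_classes)
```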
LSTM Tagger (Many-to-Many)
For sequence labeling (POS tagging, NER):
- Outputs a prediction at each time step
- Projects each hidden state $h_t$ to output vocabulary
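The tagger is the simpler sibling: keep every time step's hidden state and project each one to tag logits. A minimal sketch (again using `nn.LSTM` for brevity; names and hyperparameters are placeholders):

```python
import torch
import torch.nn as nn

class LSTMTagger(nn.Module):
    """Many-to-many: one label prediction per time step (sketch)."""
    def __init__(self, vocab_size, embed_dim, hidden_size, num_tags):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_tags)

    def forward(self, tokens):              # tokens: (batch, seq_len)
        out, _ = self.lstm(self.embed(tokens))
        return self.fc(out)                 # (batch, seq_len, num_tags)
```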
CharLSTM
Character-level language model:
- Embedding layer maps character indices to dense vectors
- Multi-layer LSTM processes the embedded sequence
- Autoregressive generation uses temperature-controlled sampling
def generate(self, start_token, seq_len, temperature=1.0):
    tokens, hidden = [start_token], None
    for _ in range(seq_len):
        logits, hidden = self(torch.tensor([[tokens[-1]]]), hidden)
        # Sample the next character from the temperature-scaled distribution
        next_logits = logits[:, -1, :] / temperature
        probs = torch.softmax(next_logits, dim=-1)
        next_token = torch.multinomial(probs, 1)
        tokens.append(next_token.item())
    return tokens
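The sampling code above assumes a forward pass that returns per-step vocabulary logits plus the recurrent state. A minimal CharLSTM skeleton consistent with that interface (hyperparameter defaults are illustrative):

```python
import torch
import torch.nn as nn

class CharLSTM(nn.Module):
    """Character-level LM: embed -> multi-layer LSTM -> vocab logits (sketch)."""
    def __init__(self, vocab_size, embed_dim=32, hidden_size=64, num_layers=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, tokens, hidden=None):   # tokens: (batch, seq_len)
        out, hidden = self.lstm(self.embed(tokens), hidden)
        return self.fc(out), hidden           # logits: (batch, seq_len, vocab)
```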
Seq2SeqLSTM (Encoder-Decoder)
The classic encoder-decoder architecture:
- Encoder LSTM compresses the source sequence into a fixed-size context vector (its final hidden and cell states)
- Decoder LSTM is initialized with the encoder's final states and generates the target sequence autoregressively
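The encoder-decoder handoff can be sketched as below: the encoder's final `(h, c)` pair becomes the decoder's initial state, and during training the decoder consumes the shifted target sequence (teacher forcing). Class and parameter names here are assumptions for illustration:

```python
import torch
import torch.nn as nn

class Seq2SeqLSTM(nn.Module):
    """Encoder-decoder: encoder states initialize the decoder (sketch)."""
    def __init__(self, src_vocab, tgt_vocab, embed_dim=32, hidden_size=64):
        super().__init__()
        self.src_embed = nn.Embedding(src_vocab, embed_dim)
        self.tgt_embed = nn.Embedding(tgt_vocab, embed_dim)
        self.encoder = nn.LSTM(embed_dim, hidden_size, batch_first=True)
        self.decoder = nn.LSTM(embed_dim, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, tgt_vocab)

    def forward(self, src, tgt):
        # Compress the source into the encoder's final (h, c) context
        _, context = self.encoder(self.src_embed(src))
        # Decode the target with the encoder's states as initial state
        out, _ = self.decoder(self.tgt_embed(tgt), context)
        return self.fc(out)            # (batch, tgt_len, tgt_vocab)
```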
Bidirectional Processing
For tasks requiring full context (e.g., sentiment analysis, NER), we implement bidirectional LSTMs:
- Forward LSTM processes left-to-right: $h_t^{\rightarrow}$
- Backward LSTM processes right-to-left: $h_t^{\leftarrow}$
- Concatenate both: $h_t = [h_t^{\rightarrow}; h_t^{\leftarrow}]$
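The flip-run-flip pattern behind that concatenation can be sketched with two unidirectional LSTMs (one bidirectional layer shown; `nn.LSTM(bidirectional=True)` does the equivalent internally):

```python
import torch
import torch.nn as nn

class BiLSTMLayer(nn.Module):
    """One bidirectional layer: left-to-right and right-to-left, concatenated (sketch)."""
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.fwd = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.bwd = nn.LSTM(input_size, hidden_size, batch_first=True)

    def forward(self, x):                 # x: (batch, seq_len, input_size)
        h_fwd, _ = self.fwd(x)            # processes t = 1 .. T
        h_bwd, _ = self.bwd(x.flip(1))    # processes t = T .. 1
        # Re-flip the backward outputs so both align on the same time axis
        return torch.cat([h_fwd, h_bwd.flip(1)], dim=-1)
```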
Training Strategy
We train on a challenging long-range dependency task specifically designed to expose the vanishing gradient problem:
- Input: random sequence of length 30
- Task: classify based on the first + last element only
- The model must remember information across 30 time steps while ignoring irrelevant intermediate values
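One way such a task could be generated is sketched below; the specific labeling rule (class 1 iff the first and last elements sum to a positive value) is an illustrative assumption, not necessarily the repo's exact definition:

```python
import torch

def make_long_range_batch(batch_size, seq_len=30):
    """Synthetic long-range task (sketch): the label depends only on the
    first and last elements; everything in between is distraction."""
    x = torch.randn(batch_size, seq_len, 1)
    # Illustrative rule: class 1 iff first + last element is positive
    y = ((x[:, 0, 0] + x[:, -1, 0]) > 0).long()
    return x, y
```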
Gradient clipping prevents the complementary exploding gradient problem:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
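In context, clipping sits between the backward pass and the optimizer step. A sketch of one training step (the function name and loss choice are assumptions):

```python
import torch
import torch.nn as nn

def train_step(model, optimizer, x, y, max_norm=1.0):
    """One training step with gradient-norm clipping (sketch)."""
    optimizer.zero_grad()
    loss = nn.functional.cross_entropy(model(x), y)
    loss.backward()
    # Rescale gradients so their global L2 norm never exceeds max_norm
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
    optimizer.step()
    return loss.item()
```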
Next Steps: Visualizing Gate Dynamics
With our LSTM implemented and architecture variants in place, we can now train and analyze what the gates actually learn. In Part 3, we will visualize forget gate, input gate, and output gate activations across time—watching the network learn to selectively remember and forget.
Full code is on the GitHub repo. Stay tuned for the gate visualization drop!