LSTMs from Scratch

Part 2: Pure PyTorch Implementation

Introduction

Part 1 covered the math: forget gate, input gate, output gate, and the additive cell state update. Now we translate those equations into working PyTorch code, building from a single cell up to full sequence models.

The LSTM Cell

Each gate is a learned linear projection of the current input $x_t$ and the previous hidden state $h_{t-1}$ (equivalently, a single projection of their concatenation), followed by a nonlinearity:

lstm_cell.py
# Compute gates (each gate has its own weights and bias vector)
i = torch.sigmoid(W_xi @ x + W_hi @ h_prev + b_i)  # Input gate
f = torch.sigmoid(W_xf @ x + W_hf @ h_prev + b_f)  # Forget gate
g = torch.tanh(W_xc @ x + W_hc @ h_prev + b_c)     # Cell candidate
o = torch.sigmoid(W_xo @ x + W_ho @ h_prev + b_o)  # Output gate

# Cell state update (the gradient highway)
c_new = f * c_prev + i * g

# Hidden state update
h_new = o * torch.tanh(c_new)
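
One way to package these updates is as an nn.Module. Below is a minimal sketch, not necessarily the exact class layout used here: the per-gate nn.Linear layers (named Wxi, Whf, and so on to match the bias-initialization code in the next section) carry the bias terms implicitly.

import torch
import torch.nn as nn

class LSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        # One input-to-hidden and one hidden-to-hidden projection per gate;
        # nn.Linear carries the bias vectors
        self.Wxi, self.Whi = nn.Linear(input_size, hidden_size), nn.Linear(hidden_size, hidden_size)
        self.Wxf, self.Whf = nn.Linear(input_size, hidden_size), nn.Linear(hidden_size, hidden_size)
        self.Wxc, self.Whc = nn.Linear(input_size, hidden_size), nn.Linear(hidden_size, hidden_size)
        self.Wxo, self.Who = nn.Linear(input_size, hidden_size), nn.Linear(hidden_size, hidden_size)

    def forward(self, x, state):
        h_prev, c_prev = state
        i = torch.sigmoid(self.Wxi(x) + self.Whi(h_prev))
        f = torch.sigmoid(self.Wxf(x) + self.Whf(h_prev))
        g = torch.tanh(self.Wxc(x) + self.Whc(h_prev))
        o = torch.sigmoid(self.Wxo(x) + self.Who(h_prev))
        c = f * c_prev + i * g
        h = o * torch.tanh(c)
        return h, c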

Weight Initialization

All weights use Xavier initialization, with one important detail: forget gate biases start at 1.0. Without this, the forget gate sigmoid outputs ~0.5 at initialization, immediately discarding half the cell state. Setting the bias to 1 pushes the initial output toward 1 (preserve everything), which stabilizes early training. Jozefowicz et al. (2015) showed this is critical for consistent convergence.

lstm_cell.py
# Critical: Initialize forget gate bias to 1.0
nn.init.ones_(self.Wxf.bias)
nn.init.ones_(self.Whf.bias)
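
A fuller initialization routine, as a sketch assuming the nn.Linear naming above. Note that because each gate sums an input-to-hidden and a hidden-to-hidden projection, setting both bias vectors to 1.0 gives an effective initial forget bias of 2.0, pushing the gate even further toward "preserve":

def _init_weights(self):
    for lin in [self.Wxi, self.Whi, self.Wxf, self.Whf,
                self.Wxc, self.Whc, self.Wxo, self.Who]:
        nn.init.xavier_uniform_(lin.weight)
        nn.init.zeros_(lin.bias)
    # Forget gate biases last, overriding the zeros above
    nn.init.ones_(self.Wxf.bias)
    nn.init.ones_(self.Whf.bias)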

Multi-Layer LSTMs

Stacking cells gives you hierarchical temporal processing:

Each layer's hidden state $h_t^{(l)}$ feeds into the layer above it. Dropout is applied between layers only, not after the final output -- matching PyTorch's built-in LSTM convention.
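
A sketch of the stacking logic for a single time step, assuming the LSTMCell above (states holds one (h, c) pair per layer):

class StackedLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, dropout=0.0):
        super().__init__()
        self.cells = nn.ModuleList([
            LSTMCell(input_size if l == 0 else hidden_size, hidden_size)
            for l in range(num_layers)
        ])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, states):
        new_states = []
        for l, cell in enumerate(self.cells):
            h, c = cell(x, states[l])
            new_states.append((h, c))
            # Dropout between layers only, never on the top layer's output
            x = self.dropout(h) if l < len(self.cells) - 1 else h
        return x, new_states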

Architecture Variants

LSTM Classifier (Many-to-One)

For sequence classification:
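
A minimal sketch of the shape of such a model (shown with nn.LSTM for brevity; the hand-rolled stack above drops in the same way). The key move is reading only the final hidden state:

class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_size, num_classes):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, num_classes)

    def forward(self, tokens):                  # tokens: (batch, seq_len)
        _, (h_n, _) = self.lstm(self.embed(tokens))
        # Many-to-one: only the final hidden state feeds the classifier
        return self.head(h_n[-1])               # (batch, num_classes)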

LSTM Tagger (Many-to-Many)

For sequence labeling (POS tagging, NER):
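
A sketch along the same lines (again with nn.LSTM standing in for the hand-rolled stack): the only structural change from the classifier is that the projection head is applied at every time step, not just the last.

class LSTMTagger(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_size, num_tags):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, num_tags)

    def forward(self, tokens):                  # tokens: (batch, seq_len)
        out, _ = self.lstm(self.embed(tokens))  # (batch, seq_len, hidden)
        # Many-to-many: one tag distribution per position
        return self.head(out)                   # (batch, seq_len, num_tags)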

CharLSTM

Character-level language model:
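
A sketch of the model body (illustrative; what matters for the generation code below is that forward returns per-position logits of shape (batch, seq_len, vocab) along with the recurrent state):

class CharLSTM(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_size):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, vocab_size)

    def forward(self, tokens, state=None):
        out, state = self.lstm(self.embed(tokens), state)
        return self.head(out), state

Generation samples one character at a time, feeding each sample back in: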

char_lstm.py
@torch.no_grad()
def generate(self, start_token, seq_len, temperature=1.0):
    tokens = [start_token]
    for _ in range(seq_len):
        logits, _ = self(torch.tensor([tokens]))  # re-encode full prefix
        # Sample next character from the temperature-scaled distribution
        next_logits = logits[:, -1, :] / temperature
        probs = torch.softmax(next_logits, dim=-1)
        tokens.append(torch.multinomial(probs, 1).item())
    return tokens
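
Temperature below 1.0 sharpens the distribution toward greedy decoding; above 1.0 it flattens it, trading coherence for diversity.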

Seq2SeqLSTM (Encoder-Decoder)

The encoder-decoder architecture:
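
The encoder compresses the source sequence into its final (h, c) state, which seeds the decoder. A minimal sketch (teacher forcing on the target during training; nn.LSTM again stands in for the hand-rolled stack):

class Seq2SeqLSTM(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, embed_dim, hidden_size):
        super().__init__()
        self.src_embed = nn.Embedding(src_vocab, embed_dim)
        self.tgt_embed = nn.Embedding(tgt_vocab, embed_dim)
        self.encoder = nn.LSTM(embed_dim, hidden_size, batch_first=True)
        self.decoder = nn.LSTM(embed_dim, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, tgt_vocab)

    def forward(self, src, tgt):
        # Encoder's final (h, c) becomes the decoder's initial state
        _, state = self.encoder(self.src_embed(src))
        out, _ = self.decoder(self.tgt_embed(tgt), state)
        return self.head(out)                   # (batch, tgt_len, tgt_vocab)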

Bidirectional Processing

When the model needs full left-right context (sentiment analysis, NER), we run two LSTMs in opposite directions:
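
A sketch that makes the mechanism explicit (two independent nn.LSTMs rather than bidirectional=True, so the direction handling is visible):

class BiLSTM(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.fwd = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.bwd = nn.LSTM(input_size, hidden_size, batch_first=True)

    def forward(self, x):                        # x: (batch, seq, input)
        out_f, _ = self.fwd(x)
        # Reverse time, run the second LSTM, then un-reverse its outputs
        out_b, _ = self.bwd(torch.flip(x, dims=[1]))
        out_b = torch.flip(out_b, dims=[1])
        # Every position now carries both left and right context
        return torch.cat([out_f, out_b], dim=-1) # (batch, seq, 2*hidden)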

Training Strategy

We test on a synthetic long-range dependency task designed to expose vanishing gradients:
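
The exact task definition isn't reproduced here; one common probe in this spirit (a hypothetical stand-in) asks the model to recall the first token of a sequence after many distractor steps, so any learning signal must survive the full sequence length:

def make_recall_batch(batch_size, seq_len, vocab_size):
    # The label is the first token; everything after it is distractor
    # noise, so the cell state must carry it across seq_len - 1 steps
    x = torch.randint(0, vocab_size, (batch_size, seq_len))
    y = x[:, 0]
    return x, y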

Gradient clipping handles the complementary exploding gradient problem:

train.py
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
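
In context, clipping sits between backward() and step(), so the rescaled gradients are what the optimizer actually applies. A sketch of one training step (loader, criterion, and optimizer are assumed):

for x, y in loader:
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    # Rescale gradients in place before the optimizer consumes them
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()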

Next: Visualizing Gate Dynamics

The implementation is done. In Part 3, we train this model and look at what the gates actually learn -- extracting forget, input, and output gate activations across time to see how the network decides what to remember and what to discard.