LSTMs from Scratch

Part 2: Pure PyTorch Implementation

Introduction

In Part 1, we explored the mathematics of gated recurrence—the forget gate, input gate, output gate, and the additive cell state update that creates a gradient highway through time. Today, we take that math and translate it into a pure PyTorch implementation. We will build from individual cells to complete sequence models.

The LSTM Cell

Our LSTMCell class implements the four key computations. Each gate is a linear projection of the concatenated input $x_t$ and previous hidden state $h_{t-1}$, followed by a nonlinearity:

lstm_cell.py
# Compute gates (b_* are the learned bias vectors)
i = torch.sigmoid(W_xi @ x + W_hi @ h_prev + b_i)  # Input gate
f = torch.sigmoid(W_xf @ x + W_hf @ h_prev + b_f)  # Forget gate
g = torch.tanh(W_xc @ x + W_hc @ h_prev + b_c)     # Cell candidate
o = torch.sigmoid(W_xo @ x + W_ho @ h_prev + b_o)  # Output gate

# Cell state update (the gradient highway)
c_new = f * c_prev + i * g

# Hidden state update
h_new = o * torch.tanh(c_new)
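Wiring these equations into a complete module might look like the following sketch. It uses a single linear layer per gate acting on the concatenated $[x_t; h_{t-1}]$, which is mathematically equivalent to the separate $W_x$ and $W_h$ projections above (the class and attribute names here are illustrative, not necessarily the repo's):

```python
import torch
import torch.nn as nn

class LSTMCell(nn.Module):
    """A minimal LSTM cell: one step of gated recurrence."""
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        # One linear projection per gate, acting on [x_t; h_{t-1}]
        self.Wi = nn.Linear(input_size + hidden_size, hidden_size)
        self.Wf = nn.Linear(input_size + hidden_size, hidden_size)
        self.Wc = nn.Linear(input_size + hidden_size, hidden_size)
        self.Wo = nn.Linear(input_size + hidden_size, hidden_size)

    def forward(self, x, state):
        h_prev, c_prev = state
        xh = torch.cat([x, h_prev], dim=-1)
        i = torch.sigmoid(self.Wi(xh))   # input gate
        f = torch.sigmoid(self.Wf(xh))   # forget gate
        g = torch.tanh(self.Wc(xh))      # cell candidate
        o = torch.sigmoid(self.Wo(xh))   # output gate
        c_new = f * c_prev + i * g       # additive cell update (the highway)
        h_new = o * torch.tanh(c_new)
        return h_new, c_new
```

Note that the hidden state is always bounded in $(-1, 1)$ because it is a gated tanh of the cell state, while the cell state itself is unbounded and can accumulate over time.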

Weight Initialization

We use Xavier initialization for all weights, with one critical tweak: the forget gate bias is initialized to 1.0. This biases the network toward preserving information at the start of training rather than destroying it, which stabilizes early optimization. The trick dates back to Gers et al. (2000) and was empirically shown by Jozefowicz et al. (2015) to be one of the most important factors for reliable LSTM convergence.

lstm_cell.py
# Critical: the combined forget gate bias should start at 1.0.
# Setting both layer biases to 1.0 would give an effective bias of 2.0,
# so we put the 1.0 on one projection and zero the other.
nn.init.ones_(self.Wxf.bias)
nn.init.zeros_(self.Whf.bias)
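Put together, the full initialization routine might look like this sketch. It assumes a cell that exposes one nn.Linear per gate under the illustrative names Wi, Wf, Wc, Wo (adapt the attribute names to your own cell):

```python
import torch.nn as nn

def init_gates(cell):
    """Xavier weights everywhere; forget gate bias 1.0, all others 0."""
    for lin in (cell.Wi, cell.Wf, cell.Wc, cell.Wo):
        nn.init.xavier_uniform_(lin.weight)
        nn.init.zeros_(lin.bias)
    # Start by remembering: forget gate saturates toward "keep"
    nn.init.ones_(cell.Wf.bias)
```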

Multi-Layer LSTMs

Stacking LSTM cells creates hierarchical temporal processing:

Each layer's hidden state $h_t^{(l)}$ becomes the input to layer $l+1$. We apply dropout only between layers, never at the final output, matching the convention of PyTorch's built-in LSTM.
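A minimal sketch of that stacking logic follows. For brevity it uses PyTorch's built-in nn.LSTMCell as a stand-in for our own cell; the class name and shapes are illustrative:

```python
import torch
import torch.nn as nn

class StackedLSTM(nn.Module):
    """Stack of LSTM cells with dropout between layers only."""
    def __init__(self, input_size, hidden_size, num_layers, dropout=0.0):
        super().__init__()
        self.cells = nn.ModuleList(
            nn.LSTMCell(input_size if l == 0 else hidden_size, hidden_size)
            for l in range(num_layers)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: (batch, seq_len, input_size)
        batch, seq_len, _ = x.shape
        states = [(torch.zeros(batch, c.hidden_size, device=x.device),
                   torch.zeros(batch, c.hidden_size, device=x.device))
                  for c in self.cells]
        outputs = []
        for t in range(seq_len):
            inp = x[:, t]
            for l, cell in enumerate(self.cells):
                h, c = cell(inp, states[l])
                states[l] = (h, c)
                # Dropout between layers, never on the top layer's output
                inp = self.dropout(h) if l < len(self.cells) - 1 else h
            outputs.append(inp)
        return torch.stack(outputs, dim=1), states
```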

Architecture Variants

LSTM Classifier (Many-to-One)

For sequence classification tasks:
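The many-to-one pattern reads the whole sequence and classifies from the final hidden state. A sketch using PyTorch's built-in nn.LSTM (class and parameter names are illustrative):

```python
import torch
import torch.nn as nn

class LSTMClassifier(nn.Module):
    """Many-to-one: encode the sequence, classify from the last hidden state."""
    def __init__(self, vocab_size, embed_dim, hidden_size, num_classes):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, num_classes)

    def forward(self, tokens):
        # tokens: (batch, seq_len) integer ids
        x = self.embed(tokens)
        _, (h_n, _) = self.lstm(x)   # h_n: (num_layers, batch, hidden)
        return self.head(h_n[-1])    # logits: (batch, num_classes)
```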

LSTM Tagger (Many-to-Many)

For sequence labeling (POS tagging, NER):
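The many-to-many variant differs only in the readout: we keep the hidden state at every time step and project each one to tag logits. A sketch (names illustrative):

```python
import torch
import torch.nn as nn

class LSTMTagger(nn.Module):
    """Many-to-many: one label per time step."""
    def __init__(self, vocab_size, embed_dim, hidden_size, num_tags):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, num_tags)

    def forward(self, tokens):
        out, _ = self.lstm(self.embed(tokens))  # (batch, seq_len, hidden)
        return self.head(out)                   # (batch, seq_len, num_tags)
```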

CharLSTM

Character-level language model:

char_lstm.py
@torch.no_grad()
def generate(self, start_token, seq_len, temperature=1.0):
    # Assumes forward takes (tokens, state) and returns (logits, state)
    tokens, state = [start_token], None
    for _ in range(seq_len):
        logits, state = self(torch.tensor([[tokens[-1]]]), state)
        # Sample the next character from the temperature-scaled distribution
        next_logits = logits[:, -1, :] / temperature
        probs = torch.softmax(next_logits, dim=-1)
        tokens.append(torch.multinomial(probs, 1).item())
    return tokens

Seq2SeqLSTM (Encoder-Decoder)

The classic encoder-decoder architecture:
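In the encoder-decoder setup, the encoder's final (h, c) pair seeds the decoder, which is then unrolled over the (shifted) target sequence with teacher forcing. A sketch with illustrative names:

```python
import torch
import torch.nn as nn

class Seq2SeqLSTM(nn.Module):
    """Encoder-decoder: compress the source into (h, c), then decode."""
    def __init__(self, src_vocab, tgt_vocab, embed_dim, hidden_size):
        super().__init__()
        self.src_embed = nn.Embedding(src_vocab, embed_dim)
        self.tgt_embed = nn.Embedding(tgt_vocab, embed_dim)
        self.encoder = nn.LSTM(embed_dim, hidden_size, batch_first=True)
        self.decoder = nn.LSTM(embed_dim, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, tgt_vocab)

    def forward(self, src, tgt):
        # The encoder's final state initializes the decoder (teacher forcing)
        _, state = self.encoder(self.src_embed(src))
        out, _ = self.decoder(self.tgt_embed(tgt), state)
        return self.head(out)  # (batch, tgt_len, tgt_vocab)
```

The entire source sequence must be squeezed through the fixed-size (h, c) bottleneck, which is exactly the limitation attention mechanisms were later designed to remove.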

Bidirectional Processing

For tasks requiring full context (e.g., sentiment analysis, NER), we implement bidirectional LSTMs:
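The idea is to run one LSTM left-to-right and another right-to-left, then concatenate the two hidden states at each position so every time step sees both past and future context. With PyTorch's built-in module this is one flag (names illustrative):

```python
import torch
import torch.nn as nn

class BiLSTMEncoder(nn.Module):
    """Forward and backward passes over the sequence, concatenated."""
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size,
                            batch_first=True, bidirectional=True)

    def forward(self, x):
        out, _ = self.lstm(x)
        # Both directions are concatenated on the feature axis
        return out  # (batch, seq_len, 2 * hidden_size)
```

Note the doubled feature dimension: any downstream head must expect `2 * hidden_size` inputs.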

Training Strategy

We train on a challenging long-range dependency task specifically designed to expose the vanishing gradient problem:
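As a hypothetical stand-in for such a task (the post's actual dataset may differ), here is a minimal first-token recall generator: the label is the sequence's first symbol, so the model must carry a single piece of information across the entire sequence to score above chance:

```python
import torch

def make_recall_batch(batch_size, seq_len, num_symbols=8):
    """Long-range recall: the label is the FIRST token of each sequence,
    which must survive seq_len steps of recurrence to be predicted."""
    seqs = torch.randint(0, num_symbols, (batch_size, seq_len))
    labels = seqs[:, 0]  # only position 0 carries the answer
    return seqs, labels
```

Increasing `seq_len` directly lengthens the dependency, making this a clean probe for vanishing gradients.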

Gradient clipping prevents the complementary exploding gradient problem:

train.py
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
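Placement matters: clipping must happen after `backward()` has populated the gradients but before `optimizer.step()` consumes them. A sketch of one training step showing that ordering (function and argument names are illustrative):

```python
import torch

def train_step(model, optimizer, loss_fn, x, y, max_norm=1.0):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    # Clip AFTER backward, BEFORE the optimizer step
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
    optimizer.step()
    return loss.item()
```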

Next Steps: Visualizing Gate Dynamics

With our LSTM implemented and architecture variants in place, we can now train and analyze what the gates actually learn. In Part 3, we will visualize forget gate, input gate, and output gate activations across time—watching the network learn to selectively remember and forget.

Full code is on the GitHub repo. Stay tuned for the gate visualization drop!