In Part 1, we explored the mathematics of recurrence. Now we implement RNNs in PyTorch, building from individual cells to complete sequence models.
The RNN Cell
Our RNNCell class implements the core recurrence:
h_new = torch.tanh(W_xh @ x + W_hh @ h_prev + b)
Key design choices:
- Separate weight matrices for input and hidden projections
- Xavier initialization to stabilize training
- Optional bias terms
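The design choices above can be sketched as a minimal cell; the class and parameter names here are assumptions, not the post's exact code:

```python
import torch
import torch.nn as nn

class RNNCell(nn.Module):
    def __init__(self, input_size, hidden_size, bias=True):
        super().__init__()
        # Separate weight matrices for the input and hidden projections
        self.W_xh = nn.Parameter(torch.empty(hidden_size, input_size))
        self.W_hh = nn.Parameter(torch.empty(hidden_size, hidden_size))
        # Optional bias term
        self.b = nn.Parameter(torch.zeros(hidden_size)) if bias else None
        # Xavier initialization to stabilize training
        nn.init.xavier_uniform_(self.W_xh)
        nn.init.xavier_uniform_(self.W_hh)

    def forward(self, x, h_prev):
        # x: (batch, input_size), h_prev: (batch, hidden_size)
        h_new = x @ self.W_xh.T + h_prev @ self.W_hh.T
        if self.b is not None:
            h_new = h_new + self.b
        return torch.tanh(h_new)
```

Note that the batched form `x @ W_xh.T` is equivalent to the per-sample `W_xh @ x` in the recurrence above.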
Multi-Layer RNNs
Stacking RNN cells creates deeper representations:
- Layer 1 processes the raw input sequence
- Layer 2 processes the hidden states from Layer 1
- Each additional layer learns higher-level temporal patterns
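A sketch of this stacking, built on PyTorch's `nn.RNNCell` for brevity (a custom cell would slot in the same way); the sizes are illustrative:

```python
import torch
import torch.nn as nn

# Two stacked cells: layer 2 consumes layer 1's hidden states
cells = nn.ModuleList([nn.RNNCell(10, 16), nn.RNNCell(16, 16)])
x = torch.randn(4, 7, 10)                 # (batch, seq_len, input_size)
h = [torch.zeros(4, 16) for _ in cells]   # one hidden state per layer
outputs = []
for t in range(x.size(1)):
    inp = x[:, t]                         # layer 1 sees the raw input
    for l, cell in enumerate(cells):
        h[l] = cell(inp, h[l])
        inp = h[l]                        # next layer sees this layer's hidden state
    outputs.append(inp)
out = torch.stack(outputs, dim=1)         # (batch, seq_len, hidden_size)
```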
Architecture Variants
Sequence Classifier
Many-to-one architecture:
Input: (batch, seq_len, input_size)
Output: (batch, num_classes)
Uses the final hidden state for classification.
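A minimal many-to-one sketch using PyTorch's built-in `nn.RNN` (the post's custom cells would work identically); the class name and sizes are assumptions:

```python
import torch
import torch.nn as nn

class SequenceClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, num_classes)

    def forward(self, x):        # x: (batch, seq_len, input_size)
        _, h_n = self.rnn(x)     # h_n: (num_layers, batch, hidden_size)
        return self.head(h_n[-1])  # classify from the final hidden state
```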
Sequence Tagger
Many-to-many architecture:
Input: (batch, seq_len, input_size)
Output: (batch, seq_len, num_tags)
Outputs a prediction at each time step.
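The many-to-many variant differs only in which hidden states feed the output head; a sketch under the same assumptions:

```python
import torch
import torch.nn as nn

class SequenceTagger(nn.Module):
    def __init__(self, input_size, hidden_size, num_tags):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, num_tags)

    def forward(self, x):        # x: (batch, seq_len, input_size)
        out, _ = self.rnn(x)     # hidden state at every time step
        return self.head(out)    # (batch, seq_len, num_tags)
```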
CharRNN
Character-level language model:
- Embedding layer for character tokens
- Multi-layer RNN for sequence processing
- Output projection to vocabulary logits
- Autoregressive generation via sampling
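The four pieces above fit together as sketched below; this uses `nn.RNN` and temperature sampling, and all names and sizes are assumptions rather than the post's exact implementation:

```python
import torch
import torch.nn as nn

class CharRNN(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)   # character tokens
        self.rnn = nn.RNN(embed_size, hidden_size, num_layers, batch_first=True)
        self.proj = nn.Linear(hidden_size, vocab_size)      # vocabulary logits

    def forward(self, tokens, h=None):   # tokens: (batch, seq_len) int64
        out, h = self.rnn(self.embed(tokens), h)
        return self.proj(out), h

    @torch.no_grad()
    def generate(self, start, length, temperature=1.0):
        # Autoregressive sampling: feed each sampled character back in
        tokens = torch.tensor([[start]])
        h, out = None, [start]
        for _ in range(length):
            logits, h = self(tokens, h)
            probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
            tokens = torch.multinomial(probs, 1)
            out.append(tokens.item())
        return out
```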
Bidirectional RNN
Processes sequences in both directions:
- Forward RNN: Past to future
- Backward RNN: Future to past
- Concatenate both representations
This captures context from both directions, crucial for tasks like named entity recognition.
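The two passes can be made explicit with a pair of unidirectional RNNs; this is what `nn.RNN(..., bidirectional=True)` does internally, shown here step by step with illustrative sizes:

```python
import torch
import torch.nn as nn

fwd = nn.RNN(10, 16, batch_first=True)
bwd = nn.RNN(10, 16, batch_first=True)
x = torch.randn(4, 7, 10)                  # (batch, seq_len, input_size)

out_f, _ = fwd(x)                          # past -> future
out_b, _ = bwd(torch.flip(x, dims=[1]))    # future -> past (reversed input)
out_b = torch.flip(out_b, dims=[1])        # re-align with forward time steps
out = torch.cat([out_f, out_b], dim=-1)    # (batch, seq_len, 2 * hidden_size)
```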
Implementation Details
Batch-First vs Time-First
We use batch-first tensors (B, T, D) for compatibility with PyTorch conventions, but internally process time step by time step.
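Concretely, stepping through time with a batch-first tensor means slicing along dimension 1:

```python
import torch

x = torch.randn(4, 7, 10)                    # (B, T, D), batch-first
steps = [x[:, t] for t in range(x.size(1))]  # one (B, D) slice per time step
```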
Dropout Regularization
Applied between RNN layers (not across time) to prevent overfitting:
if dropout > 0:
    self.dropout = nn.Dropout(dropout)
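A sketch of where that dropout is applied: the hidden sequence passes through dropout on its way to the next layer, but the recurrence within a layer is untouched. The layer sizes here are illustrative:

```python
import torch
import torch.nn as nn

drop = nn.Dropout(0.3)
layer1 = nn.RNN(10, 16, batch_first=True)
layer2 = nn.RNN(16, 16, batch_first=True)

x = torch.randn(4, 7, 10)
h1, _ = layer1(x)          # recurrence across time, no dropout inside
h2, _ = layer2(drop(h1))   # dropout only on the inter-layer activations
```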
Hidden State Initialization
Hidden states are initialized to zero at the start of each sequence, so each sequence begins with no memory of the previous batch.
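A one-line sketch of that initialization, with assumed sizes:

```python
import torch

batch_size, hidden_size = 4, 16
h0 = torch.zeros(batch_size, hidden_size)  # fresh zero state per sequence
```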
Training Strategy
We train on synthetic sequence tasks:
- Task: sequence classification, predicting a label from the entire sequence
- Loss: cross-entropy
- Optimization: Adam with learning rate scheduling
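The training setup can be sketched as below; the model, data, and scheduler settings are stand-ins (random synthetic data, `StepLR`), not the post's exact configuration:

```python
import torch
import torch.nn as nn

# Many-to-one classifier: RNN encoder plus a linear head
rnn = nn.RNN(10, 16, batch_first=True)
head = nn.Linear(16, 3)

criterion = nn.CrossEntropyLoss()
params = list(rnn.parameters()) + list(head.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

x = torch.randn(32, 7, 10)       # synthetic sequences
y = torch.randint(0, 3, (32,))   # synthetic labels
for epoch in range(3):
    optimizer.zero_grad()
    _, h_n = rnn(x)
    logits = head(h_n[-1])       # classify from the final hidden state
    loss = criterion(logits, y)
    loss.backward()
    optimizer.step()
    scheduler.step()
```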
Conclusion
With our RNN implemented, we can now train it and visualize how hidden states evolve. Part 3 explores hidden state dynamics and analyzes what the network learns.