Transformers from Scratch

Part 2: Pure PyTorch Implementation

Introduction

Welcome to Part 2 of our series on Transformers. In Part 1, we outlined the mathematical foundations of Self-Attention and Positional Encodings. Now, it is time to turn those equations into running code.

Rather than relying on the high-level nn.Transformer module, we are building the entire sequence-to-sequence model from scratch in PyTorch. This requires careful bookkeeping of tensor dimensions, multi-head reshapes, and causal masks.

Implementing Scaled Dot-Product Attention

The attention mechanism requires projecting our inputs, performing batched matrix multiplications, and applying masking.

attention.py
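import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledDotProductAttention(nn.Module):
    def forward(self, q, k, v, mask=None):
        # q, k: (..., seq_len, d_k); v: (..., seq_len, d_v)
        d_k = q.size(-1)
        # Scale by sqrt(d_k) so the logits do not grow with the head size.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            # Positions where mask == 0 receive a large negative score.
            scores = scores.masked_fill(mask == 0, -1e9)

        attention_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, v)
        return output, attention_weights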

Notice the -1e9 fill value used when applying the mask. After the softmax, the masked positions receive attention weights that are effectively zero, so the model cannot attend to padded or future tokens.
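To see the effect of the fill value in isolation, here is a minimal sketch that runs the same computation on toy tensors. The function form, the shapes, and the `pad_mask` tensor are illustrative assumptions, not part of the series code.

```python
import math

import torch
import torch.nn.functional as F


def scaled_dot_product_attention(q, k, v, mask=None):
    """Functional copy of the module's forward pass (illustrative only)."""
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, v), weights


torch.manual_seed(0)

# Assumed toy shapes: batch=1, seq_len=4, d_k=8.
q = torch.randn(1, 4, 8)
k = torch.randn(1, 4, 8)
v = torch.randn(1, 4, 8)

# Hypothetical padding mask: the last key position is padding (0 = masked).
# Shape (1, 1, 4) broadcasts across all four query positions.
pad_mask = torch.tensor([[[1, 1, 1, 0]]])

output, weights = scaled_dot_product_attention(q, k, v, pad_mask)

# Largest weight any query assigns to the padded key; expected to be ~0.
masked_weight_max = weights[..., -1].max().item()
# Each row of weights should still sum to 1 over the unmasked keys.
row_sums = weights.sum(dim=-1)
```

One caveat worth keeping in mind: -1e9 lies outside the representable range of float16, so a dtype-aware fill such as `torch.finfo(scores.dtype).min` is a common alternative when training in half precision.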

Constructing the Encoder and Decoder Layers

The encoder and the decoder are each built by stacking layers made from a small set of shared building blocks.

The Encoder Layer consists of:

  1. Multi-Head Self-Attention
  2. Residual Connection & Layer Normalization
  3. Position-wise Feed-Forward Network
  4. Residual Connection & Layer Normalization
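The four steps above can be sketched as a single module. This is a minimal sketch under stated assumptions: it uses PyTorch's built-in `nn.MultiheadAttention` as a stand-in for the from-scratch multi-head attention, and the class name `EncoderLayerSketch` and the sizes (`d_model=64`, `num_heads=4`, `d_ff=256`) are hypothetical choices for illustration.

```python
import torch
import torch.nn as nn


class EncoderLayerSketch(nn.Module):
    """Illustrative encoder layer: attention, add & norm, FFN, add & norm."""

    def __init__(self, d_model=64, num_heads=4, d_ff=256, dropout=0.1):
        super().__init__()
        # Stand-in for a hand-written multi-head attention module.
        self.self_attn = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )
        self.norm1 = nn.LayerNorm(d_model)
        # Position-wise feed-forward network, applied to each token independently.
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, key_padding_mask=None):
        # 1. Multi-head self-attention: queries, keys, and values are all x.
        attn_out, _ = self.self_attn(x, x, x, key_padding_mask=key_padding_mask)
        # 2. Residual connection and layer normalization.
        x = self.norm1(x + self.dropout(attn_out))
        # 3. Position-wise feed-forward network.
        ffn_out = self.ffn(x)
        # 4. Residual connection and layer normalization.
        return self.norm2(x + self.dropout(ffn_out))


torch.manual_seed(0)
layer = EncoderLayerSketch().eval()  # eval() disables dropout
x = torch.randn(2, 10, 64)  # (batch, seq_len, d_model)

# Note: nn.MultiheadAttention uses True = "ignore this key", the opposite of
# the 0 = masked convention used by the from-scratch attention in this article.
pad = torch.zeros(2, 10, dtype=torch.bool)
pad[:, -2:] = True  # pretend the last two tokens are padding

with torch.no_grad():
    y = layer(x, key_padding_mask=pad)
```

The layer preserves the input shape, which is what allows several of them to be stacked back to back.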

The Decoder Layer is slightly more complex. It adds a Cross-Attention block between the masked self-attention and the feed-forward network, which acts as the bridge between the encoder and the decoder. In cross-attention, the Queries come from the decoder (the output of its masked self-attention sublayer), while the Keys and Values are provided by the Encoder's final output.
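The decoder layer described above can be sketched in the same style. As before, `nn.MultiheadAttention` is an assumed stand-in for a from-scratch attention module, dropout is omitted for brevity, and `DecoderLayerSketch`, `memory`, and the tensor sizes are hypothetical names and values chosen for illustration.

```python
import torch
import torch.nn as nn


class DecoderLayerSketch(nn.Module):
    """Illustrative decoder layer: masked self-attn, cross-attn, FFN."""

    def __init__(self, d_model=64, num_heads=4, d_ff=256):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, tgt, memory, blocked=None):
        # Masked self-attention over the target sequence.
        sa_out, _ = self.self_attn(tgt, tgt, tgt, attn_mask=blocked)
        x = self.norm1(tgt + sa_out)
        # Cross-attention: queries from the decoder, keys/values from the encoder.
        ca_out, cross_weights = self.cross_attn(x, memory, memory)
        x = self.norm2(x + ca_out)
        # Position-wise feed-forward network.
        x = self.norm3(x + self.ffn(x))
        return x, cross_weights


torch.manual_seed(0)
layer = DecoderLayerSketch().eval()

src_len, tgt_len = 7, 5
memory = torch.randn(2, src_len, 64)  # stands in for the encoder's final output
tgt = torch.randn(2, tgt_len, 64)     # embedded target tokens

# nn.MultiheadAttention expects True = "blocked", so invert the
# lower-triangular "allowed" mask used elsewhere in this article.
blocked = ~torch.tril(torch.ones(tgt_len, tgt_len)).bool()

with torch.no_grad():
    out, cross_weights = layer(tgt, memory, blocked)
# out: (batch, tgt_len, d_model); cross_weights: (batch, tgt_len, src_len)
```

The cross-attention weights have one row per target position and one column per source position, which is the tensor one would inspect to see how the decoder reads from the encoder.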

The Future-Blind Mask

One of the trickiest parts of the implementation is the Look-Ahead Mask (or causal mask) in the decoder. During training with teacher forcing, the decoder receives the entire target sequence at once, so each position must be blind to the tokens that come after it. We achieve this by applying a lower-triangular boolean mask to the decoder's self-attention scores: position i may attend only to positions j <= i.

masking.py
tgt_lookahead_mask = torch.tril(torch.ones((tgt_len, tgt_len))).bool()
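To make the mask concrete, the following sketch builds it for an assumed toy length of 4 and applies it to a matrix of uniform scores. The value of `tgt_len` and the all-zero `scores` tensor are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

tgt_len = 4  # assumed toy target length
tgt_lookahead_mask = torch.tril(torch.ones((tgt_len, tgt_len))).bool()

# Rows are query positions, columns are key positions.
# An entry is True exactly when the key is at or before the query:
#   T F F F
#   T T F F
#   T T T F
#   T T T T

# With uniform scores, the softmax spreads attention evenly over the
# positions each query is allowed to see.
scores = torch.zeros(tgt_len, tgt_len)
masked_scores = scores.masked_fill(tgt_lookahead_mask == 0, -1e9)
weights = F.softmax(masked_scores, dim=-1)

# Total attention mass placed on future positions (strictly above the diagonal).
future_mass = torch.triu(weights, diagonal=1).sum().item()
```

Here the first query can only attend to itself, the second splits its attention evenly across the first two positions, and so on, while the weight placed on future positions stays at zero.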

Up Next

With the code complete, we now have a working Transformer. But how do we know it learns? In Part 3, we will set up a sequence-to-sequence training loop, plot the loss curve, and extract the internal cross-attention tensors to see how the model routes information.