Building on our previous exploration of GPT-1, let’s now modify its architecture to recreate the GPT-2 [1] small model, which contains 124 million parameters. Although the original paper refers to it as 117M, OpenAI later clarified the actual count.

A key advantage of GPT-2 is that OpenAI released its pre-trained weights, which we can directly load into our implementation. This not only serves as a sanity check for our model but also provides a strong foundation for fine-tuning.

The official model code is available here: model.py

Architecture

Let’s start with our imports:

# Imports
import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Using", device)

Before implementing the model, let’s take a look at GPT-2’s pre-trained weights from Hugging Face’s transformers library. This will help us understand the naming conventions used in the architecture:

from transformers import GPT2LMHeadModel

# Load pre-trained GPT-2 (small) and retrieve its state dictionary
model_hf = GPT2LMHeadModel.from_pretrained("gpt2").to(device) # 124M param model
model_hf_params = model_hf.state_dict()

for name, param in model_hf_params.items():
    print(name, param.shape)
transformer.wte.weight torch.Size([50257, 768])
transformer.wpe.weight torch.Size([1024, 768])
transformer.h.0.ln_1.weight torch.Size([768])
transformer.h.0.ln_1.bias torch.Size([768])
transformer.h.0.attn.c_attn.weight torch.Size([768, 2304])
transformer.h.0.attn.c_attn.bias torch.Size([2304])
transformer.h.0.attn.c_proj.weight torch.Size([768, 768])
transformer.h.0.attn.c_proj.bias torch.Size([768])
transformer.h.0.ln_2.weight torch.Size([768])
transformer.h.0.ln_2.bias torch.Size([768])
transformer.h.0.mlp.c_fc.weight torch.Size([768, 3072])
transformer.h.0.mlp.c_fc.bias torch.Size([3072])
transformer.h.0.mlp.c_proj.weight torch.Size([3072, 768])
transformer.h.0.mlp.c_proj.bias torch.Size([768])

...

transformer.ln_f.weight torch.Size([768])
transformer.ln_f.bias torch.Size([768])
lm_head.weight torch.Size([50257, 768])

The output lists the layer names along with their corresponding tensor shapes, giving us insights into GPT-2’s layer structure:

  • The model is encapsulated under the namespace transformer.
  • wte represents the word token embedding weights.
  • wpe represents the word position embedding weights.
  • h is a list of hidden decoder blocks.
  • LayerNorms are denoted as ln_1 and ln_2 within each decoder block, with the final LayerNorm represented as ln_f. The scale and shift parameters are renamed to weight and bias, respectively.
  • Our W_QKV weight matrix is renamed to c_attn, while W_O is renamed to c_proj (the c_ prefix comes from the Conv1D layers used in OpenAI’s original implementation, which we will revisit when loading the weights). These also include bias terms.
  • lm_head represents the final linear layer (language modeling head) and does not include a bias term.
  • The mask variable, which we defined as a register_buffer in our implementation, does not appear here (even though it is present in our model’s state dict) because it is not a trainable parameter and is not stored in the checkpoint.
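As a quick sanity check, we can also count the Hugging Face model’s parameters. Note that Hugging Face ties the token embedding and LM head weights (a technique we will come back to shortly), so the shared matrix is counted only once:

# Count the unique parameters of the Hugging Face model (tied weights are counted once)
print(sum(p.numel() for p in model_hf.parameters()) / 1e6, 'M parameters')  # ≈ 124M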

To ensure compatibility when loading OpenAI’s pre-trained weights, we will follow this naming convention and re-write our building blocks:

class MultiHeadSelfAttention(nn.Module):
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        context_len: int | None = None,
        dropout: float = 0.0,
        causal: bool = True,
    ):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # GPT-2 naming
        self.c_attn = nn.Linear(d_model, 3 * d_model)  # Combined Q, K, V projection (our W_QKV)
        self.c_proj = nn.Linear(d_model, d_model)      # Output projection (our W_O)
        self.attn_dropout = nn.Dropout(dropout)

        if causal and context_len is not None:
            self.register_buffer("mask", torch.tril(torch.ones(context_len, context_len)))
        else:
            self.mask = None
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, _ = x.shape

        # Compute Q, K, V in one projection: [B, N, 3 * d_model]
        QKV = self.c_attn(x)
        Q, K, V = QKV.chunk(3, dim=-1)
        
        # Split into H heads: [B, N, H, d_h] and then transpose to [B, H, N, d_h]
        Q = Q.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        
        # Apply scaled dot-product attention (with causal mask) on each head
        attn_scores = (Q @ K.transpose(-2, -1)) * (self.head_dim ** -0.5)
        
        if self.mask is not None:
            attn_scores = attn_scores.masked_fill(self.mask[:N, :N] == 0, float('-inf'))

        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)
        out = attn_weights @ V

        # Concatenate: transpose back to [B, N, H, d_h], then combine heads [B, N, d_model]
        out = out.transpose(1, 2).contiguous().view(B, N, -1)
        out = self.c_proj(out)

        return out


class LayerNorm(nn.Module):
    def __init__(self, emb_dim: int):
        super().__init__()
        self.eps = 1e-5
        # GPT-2 naming
        self.weight = nn.Parameter(torch.ones(emb_dim))  
        self.bias = nn.Parameter(torch.zeros(emb_dim))   

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight * norm_x + self.bias


class GELU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Tanh approximation of GELU, as used in OpenAI's original GPT-2 implementation
        return 0.5 * x * (1 + torch.tanh(
            torch.sqrt(torch.tensor(2.0 / torch.pi)) *
            (x + 0.044715 * torch.pow(x, 3))
        ))


class MLP(nn.Module):
    def __init__(self, emb_dim: int):
        super().__init__()
        # GPT-2 naming
        self.c_fc = nn.Linear(emb_dim, 4 * emb_dim)
        self.c_proj = nn.Linear(4 * emb_dim, emb_dim)
        self.gelu = GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.gelu(self.c_fc(x))
        x = self.c_proj(x)
        return x

Note that the only change in this code is the renaming of variables to match the GPT-2 state dictionary. The GPT-2 decoder block largely mirrors the GPT-1 model we implemented earlier, with one key modification.

Pre-Norm Transformer

The original Transformer (and GPT-1) used Post-Norm residual connections, where LayerNorm is applied after both the sublayer and the residual addition:

$$ \text{Output} = \text{LayerNorm} (x + \text{Sublayer}(x)) $$

Pre-Norm vs Post-Norm Transformer Layers

In this setup, the layer normalization sits inside the residual path. Because of this, the model cannot easily learn the identity mapping: the residual branch is no longer a pure skip connection but instead passes through a normalization operation. This makes gradient propagation difficult, especially in deep networks.

The Pre-Norm formulation, analyzed in detail by Xiong et al. [3], addresses this by applying LayerNorm before the sublayer:

$$ \text{Output} = x + (\text{Sublayer}(\text{LayerNorm} (x))) $$

This simple reordering keeps the residual stream as a clean identity path, allowing gradients to propagate more directly and ensuring that each sublayer receives well-scaled inputs.

Most modern Large Language Models, including GPT-2, GPT-3 and beyond, have adopted this Pre-Norm approach. The architectural changes are:

  1. Layer Normalization is moved to the input of each sub-block, ensuring that each sublayer receives normalized inputs.

  2. An additional Layer Normalization is added after the final decoder block, further improving training stability.

With these changes, the GPT-2 architecture now looks like this:

GPT-2 architecture

class DecoderBlock(nn.Module):
    def __init__(self, d_model: int, context_len: int, num_heads: int, dropout: float):
        super().__init__()
        self.ln_1 = LayerNorm(d_model)
        self.attn = MultiHeadSelfAttention(
            d_model=d_model,
            num_heads=num_heads,
            context_len=context_len,
            dropout=dropout
        )
        self.ln_2 = LayerNorm(d_model)
        self.mlp = MLP(emb_dim=d_model)
        self.resid_dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.resid_dropout(self.attn(self.ln_1(x)))
        x = x + self.resid_dropout(self.mlp(self.ln_2(x)))
        return x

Defining the GPT-2 model

And our GPT-2 model will be coded as:

class GPT2(nn.Module):
    def __init__(self, cfg: dict):
        super().__init__()
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(cfg['vocab_size'], cfg['emb_dim']),      # Token embeddings
            wpe = nn.Embedding(cfg['context_length'], cfg['emb_dim']),  # Positional embeddings
            embd_dropout = nn.Dropout(cfg['dropout']),                  # Embedding dropout
            h = nn.Sequential(*[DecoderBlock(
                d_model=cfg['emb_dim'],
                context_len=cfg['context_length'],
                num_heads=cfg['n_heads'],
                dropout=cfg['dropout']
            ) for _ in range(cfg['n_layers'])]),
            ln_f = LayerNorm(cfg['emb_dim']),                           # Final LayerNorm
        ))
        self.lm_head = nn.Linear(cfg['emb_dim'], cfg['vocab_size'], bias=False)  # LM head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N = x.size()
        device = x.device
        token_emb = self.transformer.wte(x)                               # [B, N, D]
        pos_emb = self.transformer.wpe(torch.arange(N, device=device))    # [N, D]
        x = token_emb + pos_emb                                           # [B, N, D]
        x = self.transformer.embd_dropout(x)
        x = self.transformer.h(x)                                         
        x = self.transformer.ln_f(x)                                      
        logits = self.lm_head(x)                                          # [B, N, vocab_size]
        return logits

The hyperparameters of our model are defined in a Python dictionary, cfg, which is passed when instantiating the model.

# Define configuration dictionary
GPT_CONFIG_124M = {
    "vocab_size": 50257,    # 50,000 BPE merges + 256 byte tokens + 1 <|endoftext|> token
    "context_length": 1024, # Maximum sequence length
    "emb_dim": 768,         # Embedding dimension
    "n_heads": 12,          # Number of attention heads
    "n_layers": 12,         # Number of transformer blocks
    "dropout": 0.1,         # Dropout probability
}

model = GPT2(GPT_CONFIG_124M).to(device)

# Print the number of parameters in the model
print(sum(p.numel() for p in model.parameters()) / 1e6, 'M parameters')
163.037184 M parameters

Oops! Our model has 163M parameters, even though we aimed to replicate the 124M parameter version. What’s going on?

Actually, nothing is wrong! The discrepancy arises due to the weight tying scheme used in the GPT-2 model. Let’s dive into it in more detail.
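Before we do, a quick arithmetic check (using the parameter counts printed above) shows that the gap is exactly one extra embedding-sized matrix:

# Without weight tying, lm_head stores its own [vocab_size, emb_dim] matrix
print(50257 * 768)               # 38,597,376 extra parameters
print(163_037_184 - 38_597_376)  # 124,439,808 -> the 124M model we were aiming for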

Weight tying

The first layer of our architecture is the token embedding table, which maps each token in our vocabulary to an embedding vector. The last layer is the language modeling head, which reverses this process by mapping the embeddings back to vocabulary space.

If we inspect our model’s state dictionary, we find that both layers have the same shape:

gpt2_model_params = model.state_dict()

print(gpt2_model_params["transformer.wte.weight"].shape)
print(gpt2_model_params["lm_head.weight"].shape)
torch.Size([50257, 768])
torch.Size([50257, 768])

This occurs because of how nn.Embedding and nn.Linear store their weights internally:

  • nn.Embedding(dim_in, dim_out) stores its lookup table as torch.Size([dim_in, dim_out]).
  • nn.Linear(features_in, features_out) stores its weight matrix as torch.Size([features_out, features_in]) to perform the y = x @ W.T operation.
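A tiny example (with arbitrary dimensions) makes this storage convention concrete:

# nn.Embedding stores [num_embeddings, embedding_dim]; nn.Linear stores [out_features, in_features]
emb = nn.Embedding(10, 4)   # 10 tokens, embedding dimension 4
lin = nn.Linear(4, 10)      # maps 4 input features to 10 outputs

print(emb.weight.shape)     # torch.Size([10, 4])
print(lin.weight.shape)     # torch.Size([10, 4])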

This observation led to the practice of weight tying, where the token embedding and the language modeling head share the same set of weights. This approach was also used in the original Transformer paper and is the primary reason why there is no bias term in the language modeling head.

By enforcing weight sharing, we inject an inductive bias into the model — the same embedding matrix is used both to interpret tokens and to predict them. This:

  • Reduces parameters, since we no longer store two large matrices of size [vocab_size, emb_dim].
  • Improves generalization, because the model learns a unified semantic space for both input and output tokens.
  • Has been empirically shown to improve performance.

We can enable weight tying by directly assigning the same weight tensor:

# Weight tying
model.transformer.wte.weight = model.lm_head.weight

# Print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')
124.439808 M parameters

Hurray! We now have our own GPT-2 124M model.
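As a minimal sanity check, we can confirm that both names now refer to the same tensor:

# After tying, both attributes reference the same Parameter object
print(model.lm_head.weight is model.transformer.wte.weight)                        # True
print(model.lm_head.weight.data_ptr() == model.transformer.wte.weight.data_ptr())  # True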

Loading Weights

To load weights from HuggingFace, we simply copy the tensor values over to our model. For this, the tensor shapes need to match exactly between the HuggingFace model and our custom model.

Let’s compare the state dictionaries of both models side by side to identify any mismatches in tensor shapes:

for name, param in model_hf_params.items():
    if param.shape != gpt2_model_params[name].shape:
        print(f"Shape mismatch for parameter {name}: {param.shape} vs {gpt2_model_params[name].shape}")
Shape mismatch for parameter transformer.h.0.attn.c_attn.weight: torch.Size([768, 2304]) vs torch.Size([2304, 768])
Shape mismatch for parameter transformer.h.0.mlp.c_fc.weight: torch.Size([768, 3072]) vs torch.Size([3072, 768])
Shape mismatch for parameter transformer.h.0.mlp.c_proj.weight: torch.Size([3072, 768]) vs torch.Size([768, 3072])
...

Oops, why don’t the shapes match?

OpenAI’s original GPT-2 implementation (and the Hugging Face port) uses a Conv1D module instead of nn.Linear for these layers. Conv1D stores its weight matrix as [in_features, out_features], which is why these tensors appear transposed relative to our nn.Linear weights.

As a result, the layers c_attn, c_fc, and both c_proj layers (in the attention block and the MLP) need special handling: we transpose their weights before copying so that they match the expected shapes.

Here’s how to load the weights:

# Loading weights into our model
transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']

with torch.no_grad():
    for name, param in model_hf_params.items():
        # Check that the parameter name exists in our model
        if name in gpt2_model_params:

            # Parameters stored transposed (Conv1D layers)
            if name.endswith(tuple(transposed)):
                gpt2_model_params[name].copy_(param.t())  # Transpose the weights, then copy

            # Parameters whose shape matches directly
            elif param.shape == gpt2_model_params[name].shape:
                gpt2_model_params[name].copy_(param)      # Copy the weights over
            else:
                print(f"Shape mismatch for parameter {name}: {param.shape} vs {gpt2_model_params[name].shape}")
        else:
            print(f"Parameter {name} not found in your model")

print("Weights are loaded successfully!")

This code will load the pre-trained HuggingFace weights into our model, handling special treatment for the linear layers to ensure that everything is correctly aligned.
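As a quick spot check (a small sketch using two arbitrarily chosen tensors), we can verify that a directly copied weight now matches the Hugging Face version exactly, and that a transposed one matches after transposing back:

# A tensor copied directly should now match exactly
assert torch.equal(model.state_dict()["transformer.ln_f.weight"],
                   model_hf_params["transformer.ln_f.weight"])

# A tensor we transposed should match the Hugging Face version after transposing back
assert torch.equal(model.state_dict()["transformer.h.0.mlp.c_fc.weight"],
                   model_hf_params["transformer.h.0.mlp.c_fc.weight"].t())

print("Spot checks passed!")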

Generating text

Now that we’ve successfully loaded the weights, it’s time to test our model by generating some text. We’ll use the tiktoken tokenizer to encode the initial prompt, “Hello, I’m a language model,” and pass it into the model to generate a continuation.

Here’s a basic text generation loop:

# Generate text from the trained model
context_length = 1024
max_new_tokens = 20
model.eval()

import tiktoken
enc = tiktoken.get_encoding("gpt2")
tokens = enc.encode("Hello, I'm a language model,")

# Adds the batch dimension: [B=1, N]
context = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)    

for _ in range(max_new_tokens):
    with torch.no_grad():
        # Trim to context length supported by the model
        idx_cond = context[:, -context_length:]

        logits = model(idx_cond)                            # [B, N, vocab_size]
        logits = logits[:, -1, :]                           # last time step → [B, vocab_size]
        probs = torch.softmax(logits, dim=-1)

        # Select the most likely next token (B=1, N) - greedy approach
        next_token = torch.argmax(probs, dim=-1, keepdim=True) 

        context = torch.cat((context, next_token), dim=1)     # [B=1, N+1]
        
print(enc.decode(context[0].tolist()))
Hello, I'm a language model, not a programming language. I'm a language model. I'm a language model. I'm a

The model gets stuck in a loop, generating repeated phrases like “I’m a language model.” This is because we are always selecting the token with the highest probability at each step, which limits the model’s creativity and causes repetition. To resolve this, we can explore probabilistic sampling methods.

Probabilistic Sampling

To introduce more variety and creativity in the decoding process, we can replace the argmax function with multinomial. This method uses the probability distribution output by the model to sample the next token proportionally to its probability score:

logits = model(idx_cond)                              # [B, N, vocab_size]
logits = logits[:, -1, :]                             # last time step → [B, vocab_size]
probs = torch.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)  # Sample from the distribution

By using probabilistic sampling, we can explore a range of potential next tokens, leading to more diverse and interesting text.

While this is a good way to sample text, there are other decoding strategies that allow us to control the distribution and selection process to generate more original text.

Temperature Scaling

Let’s understand temperature scaling through an example:

print(torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=-1))
print(torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])/0.001, dim=-1))
tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])
tensor([0., 0., 0., 0., 1.])

When the magnitudes of the logits are large, the softmax output saturates and converges to a one-hot encoding (dividing by 0.001 above made the logits huge). Temperature scaling exploits exactly this effect: we divide the logits by a positive number, the temperature, before applying softmax:

logits = model(idx_cond)
logits = logits[:, -1, :] / temperature                         
probs = torch.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)  

Temperature scaling allows us to control the randomness of the output:

  • Temperature < 1: Produces sharper (more confident) distributions; as the temperature approaches zero, sampling approaches greedy decoding, almost always picking the most likely token.
  • Temperature > 1: Produces a flatter, more uniform distribution of token probabilities, so less likely tokens are selected more often. This adds variety but may also produce nonsensical text.
  • Temperature = 1: Equivalent to not using any temperature scaling.
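To see both regimes on the same toy logits (a minimal sketch with arbitrarily chosen temperature values):

logits = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])

print(torch.softmax(logits / 0.5, dim=-1))  # T < 1: sharper, more mass on the largest logit
print(torch.softmax(logits / 2.0, dim=-1))  # T > 1: flatter, closer to uniform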

Top-k sampling

In top-k sampling, we restrict the sampling process to the top-k most likely tokens and exclude the rest by masking their probabilities. This ensures that we avoid sampling very rare tokens while still providing some diversity in the output.

We achieve this by setting the logits of non-selected tokens to negative infinity, so that their softmax probabilities become zero, and the remaining probabilities sum to 1. The implementation is as follows:

def top_k_logits(logits: torch.Tensor, k: int) -> torch.Tensor:
    if k == 0:
        return logits                   # No truncation

    values, _ = torch.topk(logits, k=k) # Top-k values per row, in descending order
    min_value = values[:, -1:]          # Smallest value in the top-k, kept as [B, 1] for broadcasting
    return torch.where(logits < min_value, torch.full_like(logits, float('-inf')), logits)

torch.topk retrieves the top-k logit values in descending order, and torch.where sets every logit smaller than the k-th largest value to negative infinity. This ensures that only the top-k logits contribute to the probability distribution.
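A quick usage sketch on the same toy logits as before (with a batch dimension added):

toy_logits = torch.tensor([[0.1, -0.2, 0.3, -0.2, 0.5]])
print(top_k_logits(toy_logits, k=2))
# Only the two largest logits (0.3 and 0.5) stay finite; the rest become -inf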

Top-p (Nucleus Sampling)

While top-k gives us the ability to select the top-k tokens to consider in the sampling process, top-p dynamically selects the top tokens whose cumulative probability exceeds a certain threshold, denoted by p. Instead of a fixed number k, it adapts based on the distribution.

For example, we first sort the tokens by probability:

Token A: 0.40  
Token B: 0.30  
Token C: 0.20  
Token D: 0.05  
Token E: 0.05  

Next, we compute the cumulative probability:

Token A: 0.40  
Token A + B: 0.70  
Token A + B + C: 0.90   (stop here, because we reached p=0.9)
Token A + B + C + D: 0.95  
Token A + B + C + D + E: 1.00  

We keep only tokens A, B, and C, since their cumulative probability reaches the threshold p=0.9. The rest are discarded.

Here’s how it is implemented in code:

def top_p_logits(logits: torch.Tensor, p: float) -> torch.Tensor:
    if p == 1.0:
        return logits

    # Nucleus sampling
    sorted_logits, _ = torch.sort(logits, dim=-1, descending=True)                 # Sort logits in descending order
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)  # Compute cumulative probabilities

    # Index of the last token to keep per row, keeping at least one
    num_to_keep = torch.clamp((cumulative_probs <= p).sum(dim=-1) - 1, min=0)
    min_value = sorted_logits.gather(-1, num_to_keep.unsqueeze(-1))                # Smallest kept logit, shape [B, 1]
    return torch.where(logits < min_value, torch.full_like(logits, float('-inf')), logits)
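A quick usage sketch mirroring the worked example above (p is nudged to 0.92 so a floating-point tie at exactly 0.9 cannot flip the result):

probs = torch.tensor([[0.40, 0.30, 0.20, 0.05, 0.05]])
toy_logits = probs.log()   # log-probabilities behave like logits here, since softmax(log p) = p

print(top_p_logits(toy_logits, p=0.92))
# Tokens D and E fall outside the nucleus and are masked to -inf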

Sampling code

Now, let’s integrate all the sampling strategies we’ve discussed and implement a sample() function, similar to the one in GPT-2. You can find the original GPT-2 implementation here: sample.py

def sample(
    max_new_tokens: int,
    context_length: int,
    start_token: int | None = None,
    context: torch.Tensor | None = None,
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
) -> torch.Tensor:
    """
    Generate text tokens autoregressively from the model.
    """
    if start_token is None:
       assert context is not None, 'Specify exactly one of start_token and context!'
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = torch.full((1, 1), start_token, dtype=torch.long, device=device)
    
    model.eval()

    for _ in range(max_new_tokens):
        with torch.no_grad():
            idx_cond = context[:, -context_length:]
            logits = model(idx_cond)                                # [B, N, vocab_size]

            logits = logits[:, -1, :] / temperature                 # Scale logits by temperature
            logits = top_k_logits(logits, k=top_k)                  # Apply top-k filtering
            logits = top_p_logits(logits, p=top_p)                  # Apply top-p (nucleus) filtering

            probs = torch.softmax(logits, dim=-1)                   # Convert logits to probabilities
            next_token = torch.multinomial(probs, num_samples=1)    # Sample from the distribution 
            context = torch.cat((context, next_token), dim=1)       # [B=1, N+1]
    
    return context

Using this function, we can start generating text either using a start token or an initial prompt.

tokens = enc.encode("Hello, I'm a language model,")
prompt = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)    # Adds the batch dimension(B=1, N)

next_tokens = sample(context=prompt, max_new_tokens=20, context_length=GPT_CONFIG_124M["context_length"], top_p=0.9, top_k=40, temperature=1.0)
print(enc.decode(next_tokens[0].tolist()))       

Here’s the output:

Hello, I'm a language model, and I wanted to do something more powerful than just an English translation of an English sentence."

And with that, we’ve explored some of the most exciting text generation techniques that can bring out the true creativity of language models. Now, it’s your turn to experiment and create your own stories! Happy coding!

Summary of GPT-2 and GPT-3 models

GPT-2 model overview (2019)

Below are some key implementation details of the GPT-2 model:

  • Tokenizer: Uses Byte Pair Encoding (BPE) with a vocabulary size of 50,257.
  • Context length: $N = 1024$ tokens.
  • Architecture hyperparameters (see the config sketch after this list):
| Model | Parameters | Layers | Hidden Size | Attention Heads |
|---|---|---|---|---|
| GPT-2 Small | 124 million | 12 | 768 | 12 |
| GPT-2 Medium | 355 million | 24 | 1024 | 16 |
| GPT-2 Large | 774 million | 36 | 1280 | 20 |
| GPT-2 XL | 1.5 billion | 48 | 1600 | 25 |
  • Pre-Training: Trained on the WebText dataset created by OpenAI, which contains about 40GB of text from roughly 8 million web pages, collected from outbound links shared on Reddit (links with at least 3 karma).
    • Batch size: 512
    • Initialization: similar to GPT-1
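For reference, here is how the table above maps onto our cfg dictionary (a sketch keyed by the corresponding Hugging Face model names; dropout is our own choice rather than something specified per model size):

GPT2_CONFIGS = {
    "gpt2":        dict(n_layers=12, n_heads=12, emb_dim=768),    # 124M
    "gpt2-medium": dict(n_layers=24, n_heads=16, emb_dim=1024),   # 355M
    "gpt2-large":  dict(n_layers=36, n_heads=20, emb_dim=1280),   # 774M
    "gpt2-xl":     dict(n_layers=48, n_heads=25, emb_dim=1600),   # 1.5B
}

# All variants share the same tokenizer and context length
for cfg in GPT2_CONFIGS.values():
    cfg.update(vocab_size=50257, context_length=1024, dropout=0.1)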

While the GPT-2 paper fully describes the tokenizer and model architecture, it provides few details about training. The code that OpenAI released, and that we have been referring to so far, is inference-only code with no training specifics.

GPT-3 model overview (2020)

Architecturally, GPT-3 is identical to GPT-2, except for one key difference: it employs alternating dense and locally banded sparse attention patterns in its transformer layers. Therefore, the authors do not revisit the architectural details already specified in the GPT-2 paper but instead focus on the training details:

  • Context length: $N = 2048$ tokens.

  • Architecture hyperparameters:

    Sizes, architectures, and learning hyper-parameters (batch size in tokens and learning rate) of GPT-3 models.

  • Pre-training: Trained on a dataset comprising filtered Common Crawl, WebText2 (an expanded version of WebText), Books1 and Books2 (two internet-based book corpora), and English-language Wikipedia.

    • Optimizer: AdamW

      • $\beta_1 = 0.9, \beta_2 = 0.95, \epsilon = 10^{-8}$
      • Weight decay: $0.1$
      • Batch size:
        • Data is sampled without replacement.
        • Batch size increases linearly from a small value (32k tokens) to the full value over the first 4–12 billion tokens of training, depending on model size.
    • Gradient Clipping: The global norm of the gradient is clipped at 1.0.

    • Learning rate (see the code sketch at the end of this overview):

      • Increased linearly from zero to a peak value over the first 375M tokens.
      • Then annealed to 10% of its peak value using a cosine schedule over 260B tokens.
      • Training continues at this reduced learning rate.
  • Fine-tuning: GPT-2 and GPT-3 models were not explicitly fine-tuned for specific tasks. Instead, they leveraged in-context learning, relying on zero-shot or few-shot prompting to adapt to various tasks without the need for task-specific fine-tuning.

Comparison of fine-tuning with zero-shot, one-shot, and few-shot learning using an English-to-French translation example.
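The learning-rate schedule described above can be sketched as a function of tokens seen (a minimal sketch; the peak learning rate depends on the model size, so it is left as a parameter, and we assume the 260B-token decay window starts after warmup):

import math

def gpt3_lr(tokens_seen: int, peak_lr: float,
            warmup_tokens: int = 375_000_000,
            decay_tokens: int = 260_000_000_000) -> float:
    # Linear warmup from 0 to peak_lr over the first 375M tokens
    if tokens_seen < warmup_tokens:
        return peak_lr * tokens_seen / warmup_tokens
    # Cosine decay from peak_lr down to 10% of peak over 260B tokens,
    # after which training continues at the 0.1 * peak_lr floor
    progress = min((tokens_seen - warmup_tokens) / decay_tokens, 1.0)
    return 0.1 * peak_lr + 0.9 * peak_lr * 0.5 * (1 + math.cos(math.pi * progress))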

References