5. LLM Architecture

LLM Architecture

рдЗрд╕ рдкрд╛рдВрдЪрд╡реЗ рдЪрд░рдг рдХрд╛ рд▓рдХреНрд╖реНрдп рдмрд╣реБрдд рд╕рд░рд▓ рд╣реИ: рдкреВрд░реНрдг LLM рдХреА рдЖрд░реНрдХрд┐рдЯреЗрдХреНрдЪрд░ рд╡рд┐рдХрд╕рд┐рдд рдХрд░рдирд╛ред рд╕рдм рдХреБрдЫ рдПрдХ рд╕рд╛рде рд░рдЦреЗрдВ, рд╕рднреА рдкрд░рддреЛрдВ рдХреЛ рд▓рд╛рдЧреВ рдХрд░реЗрдВ рдФрд░ рдкрд╛рда рдЙрддреНрдкрдиреНрди рдХрд░рдиреЗ рдпрд╛ рдкрд╛рда рдХреЛ IDs рдореЗрдВ рдФрд░ рдкреАрдЫреЗ рдХреА рдУрд░ рдмрджрд▓рдиреЗ рдХреЗ рд▓рд┐рдП рд╕рднреА рдХрд╛рд░реНрдпреЛрдВ рдХреЛ рдмрдирд╛рдПрдВред

рдпрд╣ рдЖрд░реНрдХрд┐рдЯреЗрдХреНрдЪрд░ рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдФрд░ рднрд╡рд┐рд╖реНрдпрд╡рд╛рдгреА рджреЛрдиреЛрдВ рдХреЗ рд▓рд┐рдП рдЙрдкрдпреЛрдЧ рдХрд┐рдпрд╛ рдЬрд╛рдПрдЧрд╛, рдЬрдм рдЗрд╕реЗ рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рдХрд┐рдпрд╛ рдЧрдпрд╛ рд╣реЛред

LLM рдЖрд░реНрдХрд┐рдЯреЗрдХреНрдЪрд░ рдХрд╛ рдЙрджрд╛рд╣рд░рдг https://github.com/rasbt/LLMs-from-scratch/blob/main/ch04/01_main-chapter-code/ch04.ipynb:

рдПрдХ рдЙрдЪреНрдЪ рд╕реНрддрд░ рдХрд╛ рдкреНрд░рддрд┐рдирд┐рдзрд┐рддреНрд╡ рджреЗрдЦрд╛ рдЬрд╛ рд╕рдХрддрд╛ рд╣реИ:

  1. Input (Tokenized Text): рдкреНрд░рдХреНрд░рд┐рдпрд╛ рдЯреЛрдХрдирдпреБрдХреНрдд рдкрд╛рда рдХреЗ рд╕рд╛рде рд╢реБрд░реВ рд╣реЛрддреА рд╣реИ, рдЬрд┐рд╕реЗ рд╕рдВрдЦреНрдпрд╛рддреНрдордХ рдкреНрд░рддрд┐рдирд┐рдзрд┐рддреНрд╡ рдореЗрдВ рдкрд░рд┐рд╡рд░реНрддрд┐рдд рдХрд┐рдпрд╛ рдЬрд╛рддрд╛ рд╣реИред

  2. Token Embedding and Positional Embedding Layer: рдЯреЛрдХрдирдпреБрдХреНрдд рдкрд╛рда рдХреЛ рдЯреЛрдХрди рдПрдореНрдмреЗрдбрд┐рдВрдЧ рдкрд░рдд рдФрд░ рдкреЛрдЬрд┐рд╢рдирд▓ рдПрдореНрдмреЗрдбрд┐рдВрдЧ рдкрд░рдд рдХреЗ рдорд╛рдзреНрдпрдо рд╕реЗ рднреЗрдЬрд╛ рдЬрд╛рддрд╛ рд╣реИ, рдЬреЛ рдЕрдиреБрдХреНрд░рдо рдореЗрдВ рдЯреЛрдХрдиреЛрдВ рдХреА рд╕реНрдерд┐рддрд┐ рдХреЛ рдХреИрдкреНрдЪрд░ рдХрд░рддрд╛ рд╣реИ, рдЬреЛ рд╢рдмреНрдж рдХреНрд░рдо рдХреЛ рд╕рдордЭрдиреЗ рдХреЗ рд▓рд┐рдП рдорд╣рддреНрд╡рдкреВрд░реНрдг рд╣реИред

  3. Transformer Blocks: рдореЙрдбрд▓ рдореЗрдВ 12 рдЯреНрд░рд╛рдВрд╕рдлрд╛рд░реНрдорд░ рдмреНрд▓реЙрдХреНрд╕ рд╣реЛрддреЗ рд╣реИрдВ, рдкреНрд░рддреНрдпреЗрдХ рдореЗрдВ рдХрдИ рдкрд░рддреЗрдВ рд╣реЛрддреА рд╣реИрдВред рдпреЗ рдмреНрд▓реЙрдХреНрд╕ рдирд┐рдореНрдирд▓рд┐рдЦрд┐рдд рдЕрдиреБрдХреНрд░рдо рдХреЛ рджреЛрд╣рд░рд╛рддреЗ рд╣реИрдВ:

  • Masked Multi-Head Attention: рдореЙрдбрд▓ рдХреЛ рдПрдХ рдмрд╛рд░ рдореЗрдВ рдЗрдирдкреБрдЯ рдкрд╛рда рдХреЗ рд╡рд┐рднрд┐рдиреНрди рднрд╛рдЧреЛрдВ рдкрд░ рдзреНрдпрд╛рди рдХреЗрдВрджреНрд░рд┐рдд рдХрд░рдиреЗ рдХреА рдЕрдиреБрдорддрд┐ рджреЗрддрд╛ рд╣реИред

  • Layer Normalization: рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЛ рд╕реНрдерд┐рд░ рдФрд░ рд╕реБрдзрд╛рд░рдиреЗ рдХреЗ рд▓рд┐рдП рдПрдХ рд╕рд╛рдорд╛рдиреНрдпреАрдХрд░рдг рдХрджрдоред

  • Feed Forward Layer: рдзреНрдпрд╛рди рдкрд░рдд рд╕реЗ рдЬрд╛рдирдХрд╛рд░реА рдХреЛ рд╕рдВрд╕рд╛рдзрд┐рдд рдХрд░рдиреЗ рдФрд░ рдЕрдЧрд▓реЗ рдЯреЛрдХрди рдХреЗ рдмрд╛рд░реЗ рдореЗрдВ рднрд╡рд┐рд╖реНрдпрд╡рд╛рдгреА рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП рдЬрд┐рдореНрдореЗрджрд╛рд░ред

  • Dropout Layers: рдпреЗ рдкрд░рддреЗрдВ рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЗ рджреМрд░рд╛рди рдпрд╛рджреГрдЪреНрдЫрд┐рдХ рд░реВрдк рд╕реЗ рдЗрдХрд╛рдЗрдпреЛрдВ рдХреЛ рдЫреЛрдбрд╝рдХрд░ рдУрд╡рд░рдлрд┐рдЯрд┐рдВрдЧ рдХреЛ рд░реЛрдХрддреА рд╣реИрдВред

  1. Final Output Layer: рдореЙрдбрд▓ рдПрдХ 4x50,257-рдЖрдпрд╛рдореА рдЯреЗрдиреНрд╕рд░ рдЖрдЙрдЯрдкреБрдЯ рдХрд░рддрд╛ рд╣реИ, рдЬрд╣рд╛рдБ 50,257 рд╢рдмреНрджрд╛рд╡рд▓реА рдХреЗ рдЖрдХрд╛рд░ рдХрд╛ рдкреНрд░рддрд┐рдирд┐рдзрд┐рддреНрд╡ рдХрд░рддрд╛ рд╣реИред рдЗрд╕ рдЯреЗрдиреНрд╕рд░ рдореЗрдВ рдкреНрд░рддреНрдпреЗрдХ рдкрдВрдХреНрддрд┐ рдПрдХ рд╡реЗрдХреНрдЯрд░ рдХреЗ рдЕрдиреБрд░реВрдк рд╣реЛрддреА рд╣реИ рдЬрд┐рд╕рдХрд╛ рдЙрдкрдпреЛрдЧ рдореЙрдбрд▓ рдЕрдиреБрдХреНрд░рдо рдореЗрдВ рдЕрдЧрд▓реЗ рд╢рдмреНрдж рдХреА рднрд╡рд┐рд╖реНрдпрд╡рд╛рдгреА рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП рдХрд░рддрд╛ рд╣реИред

  2. Goal: рдЙрджреНрджреЗрд╢реНрдп рдЗрди рдПрдореНрдмреЗрдбрд┐рдВрдЧ рдХреЛ рд▓реЗрдирд╛ рдФрд░ рдЙрдиреНрд╣реЗрдВ рдлрд┐рд░ рд╕реЗ рдкрд╛рда рдореЗрдВ рдкрд░рд┐рд╡рд░реНрддрд┐рдд рдХрд░рдирд╛ рд╣реИред рд╡рд┐рд╢реЗрд╖ рд░реВрдк рд╕реЗ, рдЖрдЙрдЯрдкреБрдЯ рдХреА рдЕрдВрддрд┐рдо рдкрдВрдХреНрддрд┐ рдХрд╛ рдЙрдкрдпреЛрдЧ рдЕрдЧрд▓реЗ рд╢рдмреНрдж рдХреЛ рдЙрддреНрдкрдиреНрди рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП рдХрд┐рдпрд╛ рдЬрд╛рддрд╛ рд╣реИ, рдЬрд┐рд╕реЗ рдЗрд╕ рдЖрд░реЗрдЦ рдореЗрдВ "рдЖрдЧреЗ" рдХреЗ рд░реВрдк рдореЗрдВ рджрд░реНрд╢рд╛рдпрд╛ рдЧрдпрд╛ рд╣реИред

Code representation

import torch
import torch.nn as nn
import tiktoken

class GELU(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return 0.5 * x * (1 + torch.tanh(
torch.sqrt(torch.tensor(2.0 / torch.pi)) *
(x + 0.044715 * torch.pow(x, 3))
))

class FeedForward(nn.Module):
def __init__(self, cfg):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
GELU(),
nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
)

def forward(self, x):
return self.layers(x)

class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
super().__init__()
assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

self.d_out = d_out
self.num_heads = num_heads
self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim

self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
self.dropout = nn.Dropout(dropout)
self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

def forward(self, x):
b, num_tokens, d_in = x.shape

keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
queries = self.W_query(x)
values = self.W_value(x)

# We implicitly split the matrix by adding a `num_heads` dimension
# Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
values = values.view(b, num_tokens, self.num_heads, self.head_dim)
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

# Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
keys = keys.transpose(1, 2)
queries = queries.transpose(1, 2)
values = values.transpose(1, 2)

# Compute scaled dot-product attention (aka self-attention) with a causal mask
attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

# Original mask truncated to the number of tokens and converted to boolean
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

# Use the mask to fill attention scores
attn_scores.masked_fill_(mask_bool, -torch.inf)

attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
attn_weights = self.dropout(attn_weights)

# Shape: (b, num_tokens, num_heads, head_dim)
context_vec = (attn_weights @ values).transpose(1, 2)

# Combine heads, where self.d_out = self.num_heads * self.head_dim
context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
context_vec = self.out_proj(context_vec) # optional projection

return context_vec

class LayerNorm(nn.Module):
def __init__(self, emb_dim):
super().__init__()
self.eps = 1e-5
self.scale = nn.Parameter(torch.ones(emb_dim))
self.shift = nn.Parameter(torch.zeros(emb_dim))

def forward(self, x):
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
norm_x = (x - mean) / torch.sqrt(var + self.eps)
return self.scale * norm_x + self.shift

class TransformerBlock(nn.Module):
def __init__(self, cfg):
super().__init__()
self.att = MultiHeadAttention(
d_in=cfg["emb_dim"],
d_out=cfg["emb_dim"],
context_length=cfg["context_length"],
num_heads=cfg["n_heads"],
dropout=cfg["drop_rate"],
qkv_bias=cfg["qkv_bias"])
self.ff = FeedForward(cfg)
self.norm1 = LayerNorm(cfg["emb_dim"])
self.norm2 = LayerNorm(cfg["emb_dim"])
self.drop_shortcut = nn.Dropout(cfg["drop_rate"])

def forward(self, x):
# Shortcut connection for attention block
shortcut = x
x = self.norm1(x)
x = self.att(x)  # Shape [batch_size, num_tokens, emb_size]
x = self.drop_shortcut(x)
x = x + shortcut  # Add the original input back

# Shortcut connection for feed forward block
shortcut = x
x = self.norm2(x)
x = self.ff(x)
x = self.drop_shortcut(x)
x = x + shortcut  # Add the original input back

return x


class GPTModel(nn.Module):
def __init__(self, cfg):
super().__init__()
self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
self.drop_emb = nn.Dropout(cfg["drop_rate"])

self.trf_blocks = nn.Sequential(
*[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])

self.final_norm = LayerNorm(cfg["emb_dim"])
self.out_head = nn.Linear(
cfg["emb_dim"], cfg["vocab_size"], bias=False
)

def forward(self, in_idx):
batch_size, seq_len = in_idx.shape
tok_embeds = self.tok_emb(in_idx)
pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
x = tok_embeds + pos_embeds  # Shape [batch_size, num_tokens, emb_size]
x = self.drop_emb(x)
x = self.trf_blocks(x)
x = self.final_norm(x)
logits = self.out_head(x)
return logits

GPT_CONFIG_124M = {
"vocab_size": 50257,    # Vocabulary size
"context_length": 1024, # Context length
"emb_dim": 768,         # Embedding dimension
"n_heads": 12,          # Number of attention heads
"n_layers": 12,         # Number of layers
"drop_rate": 0.1,       # Dropout rate
"qkv_bias": False       # Query-Key-Value bias
}

torch.manual_seed(123)
model = GPTModel(GPT_CONFIG_124M)
out = model(batch)
print("Input batch:\n", batch)
print("\nOutput shape:", out.shape)
print(out)

GELU рд╕рдХреНрд░рд┐рдпрдг рдлрд╝рдВрдХреНрд╢рди

# From https://github.com/rasbt/LLMs-from-scratch/tree/main/ch04
class GELU(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return 0.5 * x * (1 + torch.tanh(
torch.sqrt(torch.tensor(2.0 / torch.pi)) *
(x + 0.044715 * torch.pow(x, 3))
))

рдЙрджреНрджреЗрд╢реНрдп рдФрд░ рдХрд╛рд░реНрдпрдХреНрд╖рдорддрд╛

  • GELU (Gaussian Error Linear Unit): рдПрдХ рд╕рдХреНрд░рд┐рдпрдг рдлрд╝рдВрдХреНрд╢рди рдЬреЛ рдореЙрдбрд▓ рдореЗрдВ рдЧреИрд░-рд░реЗрдЦреАрдпрддрд╛ рдХреЛ рдкреЗрд╢ рдХрд░рддрд╛ рд╣реИред

  • рд╕реНрдореВрдж рд╕рдХреНрд░рд┐рдпрдг: ReLU рдХреЗ рд╡рд┐рдкрд░реАрдд, рдЬреЛ рдирдХрд╛рд░рд╛рддреНрдордХ рдЗрдирдкреБрдЯ рдХреЛ рд╢реВрдиреНрдп рдХрд░ рджреЗрддрд╛ рд╣реИ, GELU рдЗрдирдкреБрдЯ рдХреЛ рдЖрдЙрдЯрдкреБрдЯ рдореЗрдВ рд╕реНрдореВрдж рддрд░реАрдХреЗ рд╕реЗ рдореИрдк рдХрд░рддрд╛ рд╣реИ, рдЬрд┐рд╕рд╕реЗ рдирдХрд╛рд░рд╛рддреНрдордХ рдЗрдирдкреБрдЯ рдХреЗ рд▓рд┐рдП рдЫреЛрдЯреЗ, рдЧреИрд░-рд╢реВрдиреНрдп рдорд╛рдиреЛрдВ рдХреА рдЕрдиреБрдорддрд┐ рдорд┐рд▓рддреА рд╣реИред

  • рдЧрдгрд┐рддреАрдп рдкрд░рд┐рднрд╛рд╖рд╛:

рдЗрд╕ рдлрд╝рдВрдХреНрд╢рди рдХрд╛ рдЙрдкрдпреЛрдЧ FeedForward рд▓реЗрдпрд░ рдХреЗ рдЕрдВрджрд░ рд░реЗрдЦреАрдп рдкрд░рддреЛрдВ рдХреЗ рдмрд╛рдж рдХрд░рдиреЗ рдХрд╛ рд▓рдХреНрд╖реНрдп рд░реЗрдЦреАрдп рдбреЗрдЯрд╛ рдХреЛ рдЧреИрд░-рд░реЗрдЦреАрдп рдореЗрдВ рдмрджрд▓рдирд╛ рд╣реИ рддрд╛рдХрд┐ рдореЙрдбрд▓ рдЬрдЯрд┐рд▓, рдЧреИрд░-рд░реЗрдЦреАрдп рд╕рдВрдмрдВрдзреЛрдВ рдХреЛ рд╕реАрдЦ рд╕рдХреЗред

FeedForward рдиреНрдпреВрд░рд▓ рдиреЗрдЯрд╡рд░реНрдХ

рдЖрдХреГрддрд┐рдпреЛрдВ рдХреЛ рдореИрдЯреНрд░рд┐рд╕ рдХреЗ рдЖрдХрд╛рд░ рдХреЛ рдмреЗрд╣рддрд░ рд╕рдордЭрдиреЗ рдХреЗ рд▓рд┐рдП рдЯрд┐рдкреНрдкрдгрд┐рдпреЛрдВ рдХреЗ рд░реВрдк рдореЗрдВ рдЬреЛрдбрд╝рд╛ рдЧрдпрд╛ рд╣реИ:

# From https://github.com/rasbt/LLMs-from-scratch/tree/main/ch04
class FeedForward(nn.Module):
def __init__(self, cfg):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
GELU(),
nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
)

def forward(self, x):
# x shape: (batch_size, seq_len, emb_dim)

x = self.layers[0](x)# x shape: (batch_size, seq_len, 4 * emb_dim)
x = self.layers[1](x) # x shape remains: (batch_size, seq_len, 4 * emb_dim)
x = self.layers[2](x) # x shape: (batch_size, seq_len, emb_dim)
return x  # Output shape: (batch_size, seq_len, emb_dim)

рдЙрджреНрджреЗрд╢реНрдп рдФрд░ рдХрд╛рд░реНрдпрдХреНрд╖рдорддрд╛

  • рдкреЛрдЬреАрд╢рди-рд╡рд╛рдЗрдЬ рдлреАрдбрдлреЙрд░рд╡рд░реНрдб рдиреЗрдЯрд╡рд░реНрдХ: рдкреНрд░рддреНрдпреЗрдХ рдкреЛрдЬреАрд╢рди рдкрд░ рдЕрд▓рдЧ-рдЕрд▓рдЧ рдФрд░ рд╕рдорд╛рди рд░реВрдк рд╕реЗ рджреЛ-рдкрд░рддреЛрдВ рд╡рд╛рд▓рд╛ рдкреВрд░реА рддрд░рд╣ рд╕реЗ рдЬреБрдбрд╝реЗ рдиреЗрдЯрд╡рд░реНрдХ рд▓рд╛рдЧреВ рдХрд░рддрд╛ рд╣реИред

  • рдкрд░рдд рд╡рд┐рд╡рд░рдг:

  • рдкрд╣рд▓реА рд░реИрдЦрд┐рдХ рдкрд░рдд: emb_dim рд╕реЗ 4 * emb_dim рддрдХ рдЖрдпрд╛рдо рдХрд╛ рд╡рд┐рд╕реНрддрд╛рд░ рдХрд░рддрд╛ рд╣реИред

  • GELU рд╕рдХреНрд░рд┐рдпрдг: рдЧреИрд░-рд░реЗрдЦреАрдпрддрд╛ рд▓рд╛рдЧреВ рдХрд░рддрд╛ рд╣реИред

  • рджреВрд╕рд░реА рд░реИрдЦрд┐рдХ рдкрд░рдд: рдЖрдпрд╛рдо рдХреЛ рдлрд┐рд░ рд╕реЗ emb_dim рддрдХ рдХрдо рдХрд░рддрд╛ рд╣реИред

рдЬреИрд╕рд╛ рдХрд┐ рдЖрдк рджреЗрдЦ рд╕рдХрддреЗ рд╣реИрдВ, рдлреАрдб рдлреЙрд░рд╡рд░реНрдб рдиреЗрдЯрд╡рд░реНрдХ 3 рдкрд░рддреЛрдВ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рддрд╛ рд╣реИред рдкрд╣рд▓реА рдПрдХ рд░реИрдЦрд┐рдХ рдкрд░рдд рд╣реИ рдЬреЛ рд░реИрдЦрд┐рдХ рд╡рдЬрди (рдореЙрдбрд▓ рдХреЗ рдЕрдВрджрд░ рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП рдкреИрд░рд╛рдореАрдЯрд░) рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдХреЗ рдЖрдпрд╛рдореЛрдВ рдХреЛ 4 рд╕реЗ рдЧреБрдгрд╛ рдХрд░реЗрдЧреАред рдлрд┐рд░, GELU рдлрд╝рдВрдХреНрд╢рди рдХрд╛ рдЙрдкрдпреЛрдЧ рдЙрди рд╕рднреА рдЖрдпрд╛рдореЛрдВ рдореЗрдВ рдЧреИрд░-рд░реЗрдЦреАрдп рднрд┐рдиреНрдирддрд╛рдУрдВ рдХреЛ рд▓рд╛рдЧреВ рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП рдХрд┐рдпрд╛ рдЬрд╛рддрд╛ рд╣реИ рддрд╛рдХрд┐ рд╕рдореГрджреНрдз рдкреНрд░рддрд┐рдирд┐рдзрд┐рддреНрд╡ рдХреЛ рдХреИрдкреНрдЪрд░ рдХрд┐рдпрд╛ рдЬрд╛ рд╕рдХреЗ рдФрд░ рдЕрдВрддрддрдГ рдПрдХ рдФрд░ рд░реИрдЦрд┐рдХ рдкрд░рдд рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдХреЗ рдореВрд▓ рдЖрдпрд╛рдо рдХреЗ рдЖрдХрд╛рд░ рдкрд░ рд╡рд╛рдкрд╕ рд▓рд╛рдпрд╛ рдЬрд╛рддрд╛ рд╣реИред

рдорд▓реНрдЯреА-рд╣реЗрдб рдзреНрдпрд╛рди рддрдВрддреНрд░

рдпрд╣ рдкрд╣рд▓реЗ рдХреЗ рдЦрдВрдб рдореЗрдВ рдкрд╣рд▓реЗ рд╣реА рд╕рдордЭрд╛рдпрд╛ рдЧрдпрд╛ рдерд╛ред

рдЙрджреНрджреЗрд╢реНрдп рдФрд░ рдХрд╛рд░реНрдпрдХреНрд╖рдорддрд╛

  • рдорд▓реНрдЯреА-рд╣реЗрдб рд╕реЗрд▓реНрдл-рдЕрдЯреЗрдВрд╢рди: рдореЙрдбрд▓ рдХреЛ рдЯреЛрдХрди рдХреЛ рдПрдиреНрдХреЛрдб рдХрд░рддреЗ рд╕рдордп рдЗрдирдкреБрдЯ рдЕрдиреБрдХреНрд░рдо рдХреЗ рднреАрддрд░ рд╡рд┐рднрд┐рдиреНрди рдкреЛрдЬреАрд╢рдиреЛрдВ рдкрд░ рдзреНрдпрд╛рди рдХреЗрдВрджреНрд░рд┐рдд рдХрд░рдиреЗ рдХреА рдЕрдиреБрдорддрд┐ рджреЗрддрд╛ рд╣реИред

  • рдореБрдЦреНрдп рдШрдЯрдХ:

  • рдХреНрд╡реЗрд░реА, рдХреА, рд╡реИрд▓реНрдпреВ: рдЗрдирдкреБрдЯ рдХреЗ рд░реИрдЦрд┐рдХ рдкреНрд░рдХреНрд╖рд┐рдкреНрддрд┐рдпрд╛рдБ, рдЬреЛ рдзреНрдпрд╛рди рд╕реНрдХреЛрд░ рдХреА рдЧрдгрдирд╛ рдХреЗ рд▓рд┐рдП рдЙрдкрдпреЛрдЧ рдХреА рдЬрд╛рддреА рд╣реИрдВред

  • рд╣реЗрдбреНрд╕: рд╕рдорд╛рдирд╛рдВрддрд░ рдореЗрдВ рдЪрд▓рдиреЗ рд╡рд╛рд▓реЗ рдХрдИ рдзреНрдпрд╛рди рддрдВрддреНрд░ (num_heads), рдкреНрд░рддреНрдпреЗрдХ рдХреЗ рд╕рд╛рде рдПрдХ рдХрдо рдЖрдпрд╛рдо (head_dim)ред

  • рдзреНрдпрд╛рди рд╕реНрдХреЛрд░: рдХреНрд╡реЗрд░реА рдФрд░ рдХреАрдЬрд╝ рдХреЗ рдбреЙрдЯ рдЙрддреНрдкрд╛рдж рдХреЗ рд░реВрдк рдореЗрдВ рдЧрдгрдирд╛ рдХреА рдЬрд╛рддреА рд╣реИ, рд╕реНрдХреЗрд▓ рдФрд░ рдорд╛рд╕реНрдХ рдХреА рдЬрд╛рддреА рд╣реИред

  • рдорд╛рд╕реНрдХрд┐рдВрдЧ: рднрд╡рд┐рд╖реНрдп рдХреЗ рдЯреЛрдХрдиреЛрдВ рдкрд░ рдзреНрдпрд╛рди рдХреЗрдВрджреНрд░рд┐рдд рдХрд░рдиреЗ рд╕реЗ рд░реЛрдХрдиреЗ рдХреЗ рд▓рд┐рдП рдПрдХ рдХрд╛рд░рдгрд╛рддреНрдордХ рдорд╛рд╕реНрдХ рд▓рд╛рдЧреВ рдХрд┐рдпрд╛ рдЬрд╛рддрд╛ рд╣реИ (GPT рдЬреИрд╕реЗ рдСрдЯреЛрд░рд┐рдЧреНрд░реЗрд╕рд┐рд╡ рдореЙрдбрд▓реЛрдВ рдХреЗ рд▓рд┐рдП рдорд╣рддреНрд╡рдкреВрд░реНрдг)ред

  • рдзреНрдпрд╛рди рд╡рдЬрди: рдорд╛рд╕реНрдХ рдХрд┐рдП рдЧрдП рдФрд░ рд╕реНрдХреЗрд▓ рдХрд┐рдП рдЧрдП рдзреНрдпрд╛рди рд╕реНрдХреЛрд░ рдХрд╛ рд╕реЙрдлреНрдЯрдореИрдХреНрд╕ред

  • рд╕рдВрджрд░реНрдн рд╡реЗрдХреНрдЯрд░: рдзреНрдпрд╛рди рд╡рдЬрди рдХреЗ рдЕрдиреБрд╕рд╛рд░ рдорд╛рдиреЛрдВ рдХрд╛ рднрд╛рд░рд┐рдд рдпреЛрдЧред

  • рдЖрдЙрдЯрдкреБрдЯ рдкреНрд░рдХреНрд╖рд┐рдкреНрддрд┐: рд╕рднреА рд╣реЗрдбреНрд╕ рдХреЗ рдЖрдЙрдЯрдкреБрдЯ рдХреЛ рд╕рдВрдпреЛрдЬрд┐рдд рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП рд░реИрдЦрд┐рдХ рдкрд░рддред

рдЗрд╕ рдиреЗрдЯрд╡рд░реНрдХ рдХрд╛ рд▓рдХреНрд╖реНрдп рдПрдХ рд╣реА рд╕рдВрджрд░реНрдн рдореЗрдВ рдЯреЛрдХрдиреЛрдВ рдХреЗ рдмреАрдЪ рд╕рдВрдмрдВрдзреЛрдВ рдХреЛ рдЦреЛрдЬрдирд╛ рд╣реИред рдЗрд╕рдХреЗ рдЕрд▓рд╛рд╡рд╛, рдЯреЛрдХрдиреЛрдВ рдХреЛ рд╡рд┐рднрд┐рдиреНрди рд╣реЗрдбреНрд╕ рдореЗрдВ рд╡рд┐рднрд╛рдЬрд┐рдд рдХрд┐рдпрд╛ рдЬрд╛рддрд╛ рд╣реИ рддрд╛рдХрд┐ рдУрд╡рд░рдлрд┐рдЯрд┐рдВрдЧ рдХреЛ рд░реЛрдХрд╛ рдЬрд╛ рд╕рдХреЗ, рд╣рд╛рд▓рд╛рдВрдХрд┐ рдкреНрд░рддреНрдпреЗрдХ рд╣реЗрдб рдореЗрдВ рдкрд╛рдП рдЧрдП рдЕрдВрддрд┐рдо рд╕рдВрдмрдВрдзреЛрдВ рдХреЛ рдЗрд╕ рдиреЗрдЯрд╡рд░реНрдХ рдХреЗ рдЕрдВрдд рдореЗрдВ рд╕рдВрдпреЛрдЬрд┐рдд рдХрд┐рдпрд╛ рдЬрд╛рддрд╛ рд╣реИред

рдЗрд╕рдХреЗ рдЕрд▓рд╛рд╡рд╛, рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЗ рджреМрд░рд╛рди рдПрдХ рдХрд╛рд░рдгрд╛рддреНрдордХ рдорд╛рд╕реНрдХ рд▓рд╛рдЧреВ рдХрд┐рдпрд╛ рдЬрд╛рддрд╛ рд╣реИ рддрд╛рдХрд┐ рдмрд╛рдж рдХреЗ рдЯреЛрдХрдиреЛрдВ рдХреЛ рдПрдХ рдЯреЛрдХрди рдХреЗ рд▓рд┐рдП рд╡рд┐рд╢рд┐рд╖реНрдЯ рд╕рдВрдмрдВрдзреЛрдВ рдХреЛ рджреЗрдЦрддреЗ рд╕рдордп рдзреНрдпрд╛рди рдореЗрдВ рди рд▓рд┐рдпрд╛ рдЬрд╛рдП рдФрд░ рдХреБрдЫ рдбреНрд░реЙрдкрдЖрдЙрдЯ рднреА рд▓рд╛рдЧреВ рдХрд┐рдпрд╛ рдЬрд╛рддрд╛ рд╣реИ рддрд╛рдХрд┐ рдУрд╡рд░рдлрд┐рдЯрд┐рдВрдЧ рдХреЛ рд░реЛрдХрд╛ рдЬрд╛ рд╕рдХреЗред

рдкрд░рдд рд╕рд╛рдорд╛рдиреНрдпреАрдХрд░рдг

# From https://github.com/rasbt/LLMs-from-scratch/tree/main/ch04
class LayerNorm(nn.Module):
def __init__(self, emb_dim):
super().__init__()
self.eps = 1e-5 # Prevent division by zero during normalization.
self.scale = nn.Parameter(torch.ones(emb_dim))
self.shift = nn.Parameter(torch.zeros(emb_dim))

def forward(self, x):
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
norm_x = (x - mean) / torch.sqrt(var + self.eps)
return self.scale * norm_x + self.shift

рдЙрджреНрджреЗрд╢реНрдп рдФрд░ рдХрд╛рд░реНрдпрдХреНрд╖рдорддрд╛

  • рд▓реЗрдпрд░ рдиреЙрд░реНрдорд▓рд╛рдЗрдЬреЗрд╢рди: рдПрдХ рддрдХрдиреАрдХ рдЬреЛ рдмреИрдЪ рдореЗрдВ рдкреНрд░рддреНрдпреЗрдХ рд╡реНрдпрдХреНрддрд┐рдЧрдд рдЙрджрд╛рд╣рд░рдг рдХреЗ рд▓рд┐рдП рд╡рд┐рд╢реЗрд╖рддрд╛рдУрдВ (рдПрдореНрдмреЗрдбрд┐рдВрдЧ рдЖрдпрд╛рдо) рдХреЗ рдмреАрдЪ рдЗрдирдкреБрдЯ рдХреЛ рд╕рд╛рдорд╛рдиреНрдп рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП рдЙрдкрдпреЛрдЧ рдХреА рдЬрд╛рддреА рд╣реИред

  • рдШрдЯрдХ:

  • eps: рдПрдХ рдЫреЛрдЯрд╛ рд╕реНрдерд┐рд░рд╛рдВрдХ (1e-5) рдЬреЛ рд╕рд╛рдорд╛рдиреНрдпреАрдХрд░рдг рдХреЗ рджреМрд░рд╛рди рд╢реВрдиреНрдп рд╕реЗ рд╡рд┐рднрд╛рдЬрди рдХреЛ рд░реЛрдХрдиреЗ рдХреЗ рд▓рд┐рдП рд╡реИрд░рд┐рдПрдВрд╕ рдореЗрдВ рдЬреЛрдбрд╝рд╛ рдЬрд╛рддрд╛ рд╣реИред

  • scale рдФрд░ shift: рд╕реАрдЦрдиреЗ рдпреЛрдЧреНрдп рдкреИрд░рд╛рдореАрдЯрд░ (nn.Parameter) рдЬреЛ рдореЙрдбрд▓ рдХреЛ рд╕рд╛рдорд╛рдиреНрдпреАрдХреГрдд рдЖрдЙрдЯрдкреБрдЯ рдХреЛ рд╕реНрдХреЗрд▓ рдФрд░ рд╢рд┐рдлреНрдЯ рдХрд░рдиреЗ рдХреА рдЕрдиреБрдорддрд┐ рджреЗрддреЗ рд╣реИрдВред рдЗрдиреНрд╣реЗрдВ рдХреНрд░рдорд╢рдГ рдПрдХ рдФрд░ рд╢реВрдиреНрдп рд╕реЗ рдкреНрд░рд╛рд░рдВрдн рдХрд┐рдпрд╛ рдЬрд╛рддрд╛ рд╣реИред

  • рд╕рд╛рдорд╛рдиреНрдпреАрдХрд░рдг рдкреНрд░рдХреНрд░рд┐рдпрд╛:

  • рдореАрди рдХреА рдЧрдгрдирд╛ (mean): рдПрдореНрдмреЗрдбрд┐рдВрдЧ рдЖрдпрд╛рдо (dim=-1) рдХреЗ рдкрд╛рд░ рдЗрдирдкреБрдЯ x рдХрд╛ рдФрд╕рдд рдирд┐рдХрд╛рд▓рддрд╛ рд╣реИ, рдкреНрд░рд╕рд╛рд░ рдХреЗ рд▓рд┐рдП рдЖрдпрд╛рдо рдХреЛ рдмрдирд╛рдП рд░рдЦрддреЗ рд╣реБрдП (keepdim=True)ред

  • рд╡реИрд░рд┐рдПрдВрд╕ рдХреА рдЧрдгрдирд╛ (var): рдПрдореНрдмреЗрдбрд┐рдВрдЧ рдЖрдпрд╛рдо рдХреЗ рдкрд╛рд░ x рдХрд╛ рд╡реИрд░рд┐рдПрдВрд╕ рдирд┐рдХрд╛рд▓рддрд╛ рд╣реИ, рдЖрдпрд╛рдо рдХреЛ рднреА рдмрдирд╛рдП рд░рдЦрддреЗ рд╣реБрдПред unbiased=False рдкреИрд░рд╛рдореАрдЯрд░ рдпрд╣ рд╕реБрдирд┐рд╢реНрдЪрд┐рдд рдХрд░рддрд╛ рд╣реИ рдХрд┐ рд╡реИрд░рд┐рдПрдВрд╕ рдкреВрд░реНрд╡рд╛рдЧреНрд░рд╣рд┐рдд рдЕрдиреБрдорд╛рдирдХ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдХреЗ рдирд┐рдХрд╛рд▓рд╛ рдЬрд╛рддрд╛ рд╣реИ (N рдХреЗ рдмрдЬрд╛рдп N-1 рд╕реЗ рд╡рд┐рднрд╛рдЬрд┐рдд рдХрд░рдирд╛), рдЬреЛ рд╡рд┐рд╢реЗрд╖рддрд╛рдУрдВ рдХреЗ рдмрдЬрд╛рдп рдирдореВрдиреЛрдВ рдХреЗ рдКрдкрд░ рд╕рд╛рдорд╛рдиреНрдпреАрдХреГрдд рдХрд░рддреЗ рд╕рдордп рдЙрдкрдпреБрдХреНрдд рд╣реИред

  • рдиреЙрд░реНрдорд▓рд╛рдЗрдЬ (norm_x): x рд╕реЗ рдореАрди рдШрдЯрд╛рддрд╛ рд╣реИ рдФрд░ рд╡реИрд░рд┐рдПрдВрд╕ рдХреЗ рд╡рд░реНрдЧрдореВрд▓ рдХреЗ рд╕рд╛рде eps рдХреЛ рдЬреЛрдбрд╝рдХрд░ рд╡рд┐рднрд╛рдЬрд┐рдд рдХрд░рддрд╛ рд╣реИред

  • рд╕реНрдХреЗрд▓ рдФрд░ рд╢рд┐рдлреНрдЯ: рд╕рд╛рдорд╛рдиреНрдпреАрдХреГрдд рдЖрдЙрдЯрдкреБрдЯ рдкрд░ рд╕реАрдЦрдиреЗ рдпреЛрдЧреНрдп scale рдФрд░ shift рдкреИрд░рд╛рдореАрдЯрд░ рд▓рд╛рдЧреВ рдХрд░рддрд╛ рд╣реИред

рд▓рдХреНрд╖реНрдп рдпрд╣ рд╕реБрдирд┐рд╢реНрдЪрд┐рдд рдХрд░рдирд╛ рд╣реИ рдХрд┐ рдПрдХ рд╣реА рдЯреЛрдХрди рдХреЗ рд╕рднреА рдЖрдпрд╛рдореЛрдВ рдореЗрдВ 0 рдХрд╛ рдФрд╕рдд рдФрд░ 1 рдХрд╛ рд╡реИрд░рд┐рдПрдВрд╕ рд╣реЛред рдЗрд╕рдХрд╛ рд▓рдХреНрд╖реНрдп рдЧрд╣рд░реЗ рдиреНрдпреВрд░рд▓ рдиреЗрдЯрд╡рд░реНрдХ рдХреЗ рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЛ рд╕реНрдерд┐рд░ рдХрд░рдирд╛ рд╣реИ, рдЬреЛ рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЗ рджреМрд░рд╛рди рдкреИрд░рд╛рдореАрдЯрд░ рдХреЗ рдЕрджреНрдпрддрди рдХреЗ рдХрд╛рд░рдг рдиреЗрдЯрд╡рд░реНрдХ рд╕рдХреНрд░рд┐рдпрдг рдХреЗ рд╡рд┐рддрд░рдг рдореЗрдВ рдкрд░рд┐рд╡рд░реНрддрди рдХреЛ рд╕рдВрджрд░реНрднрд┐рдд рдХрд░рддрд╛ рд╣реИред

рдЯреНрд░рд╛рдВрд╕рдлрд╛рд░реНрдорд░ рдмреНрд▓реЙрдХ

рдЖрдХреГрддрд┐рдпреЛрдВ рдХреЛ рдореИрдЯреНрд░рд┐рд╕реЗрд╕ рдХреЗ рдЖрдХрд╛рд░ рдХреЛ рдмреЗрд╣рддрд░ рд╕рдордЭрдиреЗ рдХреЗ рд▓рд┐рдП рдЯрд┐рдкреНрдкрдгрд┐рдпреЛрдВ рдХреЗ рд░реВрдк рдореЗрдВ рдЬреЛрдбрд╝рд╛ рдЧрдпрд╛ рд╣реИ:

# From https://github.com/rasbt/LLMs-from-scratch/tree/main/ch04

class TransformerBlock(nn.Module):
def __init__(self, cfg):
super().__init__()
self.att = MultiHeadAttention(
d_in=cfg["emb_dim"],
d_out=cfg["emb_dim"],
context_length=cfg["context_length"],
num_heads=cfg["n_heads"],
dropout=cfg["drop_rate"],
qkv_bias=cfg["qkv_bias"]
)
self.ff = FeedForward(cfg)
self.norm1 = LayerNorm(cfg["emb_dim"])
self.norm2 = LayerNorm(cfg["emb_dim"])
self.drop_shortcut = nn.Dropout(cfg["drop_rate"])

def forward(self, x):
# x shape: (batch_size, seq_len, emb_dim)

# Shortcut connection for attention block
shortcut = x  # shape: (batch_size, seq_len, emb_dim)
x = self.norm1(x)  # shape remains (batch_size, seq_len, emb_dim)
x = self.att(x)    # shape: (batch_size, seq_len, emb_dim)
x = self.drop_shortcut(x)  # shape remains (batch_size, seq_len, emb_dim)
x = x + shortcut   # shape: (batch_size, seq_len, emb_dim)

# Shortcut connection for feedforward block
shortcut = x       # shape: (batch_size, seq_len, emb_dim)
x = self.norm2(x)  # shape remains (batch_size, seq_len, emb_dim)
x = self.ff(x)     # shape: (batch_size, seq_len, emb_dim)
x = self.drop_shortcut(x)  # shape remains (batch_size, seq_len, emb_dim)
x = x + shortcut   # shape: (batch_size, seq_len, emb_dim)

return x  # Output shape: (batch_size, seq_len, emb_dim)

рдЙрджреНрджреЗрд╢реНрдп рдФрд░ рдХрд╛рд░реНрдпрдХреНрд╖рдорддрд╛

  • рдкрд░рддреЛрдВ рдХреА рд╕рдВрд░рдЪрдирд╛: рдорд▓реНрдЯреА-рд╣реЗрдб рдзреНрдпрд╛рди, рдлреАрдбрдлреЙрд░рд╡рд░реНрдб рдиреЗрдЯрд╡рд░реНрдХ, рдкрд░рдд рд╕рд╛рдорд╛рдиреНрдпреАрдХрд░рдг, рдФрд░ рдЕрд╡рд╢рд┐рд╖реНрдЯ рдХрдиреЗрдХреНрд╢рди рдХреЛ рдЬреЛрдбрд╝рддрд╛ рд╣реИред

  • рдкрд░рдд рд╕рд╛рдорд╛рдиреНрдпреАрдХрд░рдг: рд╕реНрдерд┐рд░ рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЗ рд▓рд┐рдП рдзреНрдпрд╛рди рдФрд░ рдлреАрдбрдлреЙрд░рд╡рд░реНрдб рдкрд░рддреЛрдВ рд╕реЗ рдкрд╣рд▓реЗ рд▓рд╛рдЧреВ рдХрд┐рдпрд╛ рдЬрд╛рддрд╛ рд╣реИред

  • рдЕрд╡рд╢рд┐рд╖реНрдЯ рдХрдиреЗрдХреНрд╢рди (рд╢реЙрд░реНрдЯрдХрдЯ): рдЧреНрд░реЗрдбрд┐рдПрдВрдЯ рдкреНрд░рд╡рд╛рд╣ рдореЗрдВ рд╕реБрдзрд╛рд░ рдХрд░рдиреЗ рдФрд░ рдЧрд╣рд░реЗ рдиреЗрдЯрд╡рд░реНрдХ рдХреЗ рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЛ рд╕рдХреНрд╖рдо рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП рдПрдХ рдкрд░рдд рдХреЗ рдЗрдирдкреБрдЯ рдХреЛ рдЗрд╕рдХреЗ рдЖрдЙрдЯрдкреБрдЯ рдореЗрдВ рдЬреЛрдбрд╝рддрд╛ рд╣реИред

  • рдбреНрд░реЙрдкрдЖрдЙрдЯ: рдирд┐рдпрдорд┐рддреАрдХрд░рдг рдХреЗ рд▓рд┐рдП рдзреНрдпрд╛рди рдФрд░ рдлреАрдбрдлреЙрд░рд╡рд░реНрдб рдкрд░рддреЛрдВ рдХреЗ рдмрд╛рдж рд▓рд╛рдЧреВ рдХрд┐рдпрд╛ рдЬрд╛рддрд╛ рд╣реИред

рдЪрд░рдг-рджрд░-рдЪрд░рдг рдХрд╛рд░реНрдпрдХреНрд╖рдорддрд╛

  1. рдкрд╣рд▓рд╛ рдЕрд╡рд╢рд┐рд╖реНрдЯ рдкрде (рд╕реЗрд▓реНрдл-рдЕрдЯреЗрдВрд╢рди):

  • рдЗрдирдкреБрдЯ (shortcut): рдЕрд╡рд╢рд┐рд╖реНрдЯ рдХрдиреЗрдХреНрд╢рди рдХреЗ рд▓рд┐рдП рдореВрд▓ рдЗрдирдкреБрдЯ рдХреЛ рд╕рд╣реЗрдЬреЗрдВред

  • рд▓реЗрдпрд░ рдиреЙрд░реНрдо (norm1): рдЗрдирдкреБрдЯ рдХреЛ рд╕рд╛рдорд╛рдиреНрдпреАрдХреГрдд рдХрд░реЗрдВред

  • рдорд▓реНрдЯреА-рд╣реЗрдб рдЕрдЯреЗрдВрд╢рди (att): рд╕реЗрд▓реНрдл-рдЕрдЯреЗрдВрд╢рди рд▓рд╛рдЧреВ рдХрд░реЗрдВред

  • рдбреНрд░реЙрдкрдЖрдЙрдЯ (drop_shortcut): рдирд┐рдпрдорд┐рддреАрдХрд░рдг рдХреЗ рд▓рд┐рдП рдбреНрд░реЙрдкрдЖрдЙрдЯ рд▓рд╛рдЧреВ рдХрд░реЗрдВред

  • рдЕрд╡рд╢рд┐рд╖реНрдЯ рдЬреЛрдбрд╝реЗрдВ (x + shortcut): рдореВрд▓ рдЗрдирдкреБрдЯ рдХреЗ рд╕рд╛рде рдорд┐рд▓рд╛рдПрдВред

  1. рджреВрд╕рд░рд╛ рдЕрд╡рд╢рд┐рд╖реНрдЯ рдкрде (рдлреАрдбрдлреЙрд░рд╡рд░реНрдб):

  • рдЗрдирдкреБрдЯ (shortcut): рдЕрдЧрд▓реЗ рдЕрд╡рд╢рд┐рд╖реНрдЯ рдХрдиреЗрдХреНрд╢рди рдХреЗ рд▓рд┐рдП рдЕрдкрдбреЗрдЯреЗрдб рдЗрдирдкреБрдЯ рдХреЛ рд╕рд╣реЗрдЬреЗрдВред

  • рд▓реЗрдпрд░ рдиреЙрд░реНрдо (norm2): рдЗрдирдкреБрдЯ рдХреЛ рд╕рд╛рдорд╛рдиреНрдпреАрдХреГрдд рдХрд░реЗрдВред

  • рдлреАрдбрдлреЙрд░рд╡рд░реНрдб рдиреЗрдЯрд╡рд░реНрдХ (ff): рдлреАрдбрдлреЙрд░рд╡рд░реНрдб рдкрд░рд┐рд╡рд░реНрддрди рд▓рд╛рдЧреВ рдХрд░реЗрдВред

  • рдбреНрд░реЙрдкрдЖрдЙрдЯ (drop_shortcut): рдбреНрд░реЙрдкрдЖрдЙрдЯ рд▓рд╛рдЧреВ рдХрд░реЗрдВред

  • рдЕрд╡рд╢рд┐рд╖реНрдЯ рдЬреЛрдбрд╝реЗрдВ (x + shortcut): рдкрд╣рд▓реЗ рдЕрд╡рд╢рд┐рд╖реНрдЯ рдкрде рд╕реЗ рдЗрдирдкреБрдЯ рдХреЗ рд╕рд╛рде рдорд┐рд▓рд╛рдПрдВред

рдЯреНрд░рд╛рдВрд╕рдлрд╛рд░реНрдорд░ рдмреНрд▓реЙрдХ рд╕рднреА рдиреЗрдЯрд╡рд░реНрдХ рдХреЛ рдПрдХ рд╕рд╛рде рд╕рдореВрд╣рд┐рдд рдХрд░рддрд╛ рд╣реИ рдФрд░ рдкреНрд░рд╢рд┐рдХреНрд╖рдг рд╕реНрдерд┐рд░рддрд╛ рдФрд░ рдкрд░рд┐рдгрд╛рдореЛрдВ рдореЗрдВ рд╕реБрдзрд╛рд░ рдХреЗ рд▓рд┐рдП рдХреБрдЫ рд╕рд╛рдорд╛рдиреНрдпреАрдХрд░рдг рдФрд░ рдбреНрд░реЙрдкрдЖрдЙрдЯ рд▓рд╛рдЧреВ рдХрд░рддрд╛ рд╣реИред рдзреНрдпрд╛рди рджреЗрдВ рдХрд┐ рдбреНрд░реЙрдкрдЖрдЙрдЯ рдкреНрд░рддреНрдпреЗрдХ рдиреЗрдЯрд╡рд░реНрдХ рдХреЗ рдЙрдкрдпреЛрдЧ рдХреЗ рдмрд╛рдж рдХрд┐рдпрд╛ рдЬрд╛рддрд╛ рд╣реИ рдЬрдмрдХрд┐ рд╕рд╛рдорд╛рдиреНрдпреАрдХрд░рдг рдкрд╣рд▓реЗ рд▓рд╛рдЧреВ рдХрд┐рдпрд╛ рдЬрд╛рддрд╛ рд╣реИред

рдЗрд╕рдХреЗ рдЕрд▓рд╛рд╡рд╛, рдпрд╣ рд╢реЙрд░реНрдЯрдХрдЯ рдХрд╛ рднреА рдЙрдкрдпреЛрдЧ рдХрд░рддрд╛ рд╣реИ рдЬрд┐рд╕рдореЗрдВ рдПрдХ рдиреЗрдЯрд╡рд░реНрдХ рдХреЗ рдЖрдЙрдЯрдкреБрдЯ рдХреЛ рдЗрд╕рдХреЗ рдЗрдирдкреБрдЯ рдХреЗ рд╕рд╛рде рдЬреЛрдбрд╝рдирд╛ рд╢рд╛рдорд┐рд▓ рд╣реИред рдпрд╣ рд╕реБрдирд┐рд╢реНрдЪрд┐рдд рдХрд░рдХреЗ рд╡реИрдирд┐рд╢рд┐рдВрдЧ рдЧреНрд░реЗрдбрд┐рдПрдВрдЯ рд╕рдорд╕реНрдпрд╛ рдХреЛ рд░реЛрдХрдиреЗ рдореЗрдВ рдорджрдж рдХрд░рддрд╛ рд╣реИ рдХрд┐ рдкреНрд░рд╛рд░рдВрднрд┐рдХ рдкрд░рддреЗрдВ "рдЬрд┐рддрдирд╛" рдпреЛрдЧрджрд╛рди рдХрд░рддреА рд╣реИрдВ рдЬрд┐рддрдирд╛ рдХрд┐ рдЕрдВрддрд┐рдо рдкрд░рддреЗрдВред

GPTModel

рдЖрдХреГрддрд┐рдпреЛрдВ рдХреЛ рдореИрдЯреНрд░рд┐рд╕реЗрд╕ рдХреЗ рдЖрдХрд╛рд░ рдХреЛ рдмреЗрд╣рддрд░ рд╕рдордЭрдиреЗ рдХреЗ рд▓рд┐рдП рдЯрд┐рдкреНрдкрдгрд┐рдпреЛрдВ рдХреЗ рд░реВрдк рдореЗрдВ рдЬреЛрдбрд╝рд╛ рдЧрдпрд╛ рд╣реИ:

# From https://github.com/rasbt/LLMs-from-scratch/tree/main/ch04
class GPTModel(nn.Module):
def __init__(self, cfg):
super().__init__()
self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
# shape: (vocab_size, emb_dim)

self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
# shape: (context_length, emb_dim)

self.drop_emb = nn.Dropout(cfg["drop_rate"])

self.trf_blocks = nn.Sequential(
*[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
)
# Stack of TransformerBlocks

self.final_norm = LayerNorm(cfg["emb_dim"])
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
# shape: (emb_dim, vocab_size)

def forward(self, in_idx):
# in_idx shape: (batch_size, seq_len)
batch_size, seq_len = in_idx.shape

# Token embeddings
tok_embeds = self.tok_emb(in_idx)
# shape: (batch_size, seq_len, emb_dim)

# Positional embeddings
pos_indices = torch.arange(seq_len, device=in_idx.device)
# shape: (seq_len,)
pos_embeds = self.pos_emb(pos_indices)
# shape: (seq_len, emb_dim)

# Add token and positional embeddings
x = tok_embeds + pos_embeds  # Broadcasting over batch dimension
# x shape: (batch_size, seq_len, emb_dim)

x = self.drop_emb(x)  # Dropout applied
# x shape remains: (batch_size, seq_len, emb_dim)

x = self.trf_blocks(x)  # Pass through Transformer blocks
# x shape remains: (batch_size, seq_len, emb_dim)

x = self.final_norm(x)  # Final LayerNorm
# x shape remains: (batch_size, seq_len, emb_dim)

logits = self.out_head(x)  # Project to vocabulary size
# logits shape: (batch_size, seq_len, vocab_size)

return logits  # Output shape: (batch_size, seq_len, vocab_size)

рдЙрджреНрджреЗрд╢реНрдп рдФрд░ рдХрд╛рд░реНрдпрдХреНрд╖рдорддрд╛

  • Embedding Layers:

  • Token Embeddings (tok_emb): рдЯреЛрдХрди рдЕрдиреБрдХреНрд░рдорд╛рдВрдХ рдХреЛ рдПрдореНрдмреЗрдбрд┐рдВрдЧ рдореЗрдВ рдкрд░рд┐рд╡рд░реНрддрд┐рдд рдХрд░рддрд╛ рд╣реИред рдпрд╛рдж рджрд┐рд▓рд╛рдиреЗ рдХреЗ рд▓рд┐рдП, рдпреЗ рд╢рдмреНрджрд╛рд╡рд▓реА рдореЗрдВ рдкреНрд░рддреНрдпреЗрдХ рдЯреЛрдХрди рдХреЗ рдкреНрд░рддреНрдпреЗрдХ рдЖрдпрд╛рдо рдХреЛ рджрд┐рдП рдЧрдП рд╡рдЬрди рд╣реИрдВред

  • Positional Embeddings (pos_emb): рдПрдореНрдмреЗрдбрд┐рдВрдЧ рдореЗрдВ рд╕реНрдерд┐рддрд┐ рд╕рдВрдмрдВрдзреА рдЬрд╛рдирдХрд╛рд░реА рдЬреЛрдбрд╝рддрд╛ рд╣реИ рддрд╛рдХрд┐ рдЯреЛрдХрдиреЛрдВ рдХреЗ рдХреНрд░рдо рдХреЛ рдХреИрдкреНрдЪрд░ рдХрд┐рдпрд╛ рдЬрд╛ рд╕рдХреЗред рдпрд╛рдж рджрд┐рд▓рд╛рдиреЗ рдХреЗ рд▓рд┐рдП, рдпреЗ рдЯреЛрдХрди рдХреЛ рдЙрд╕рдХреЗ рдкрд╛рда рдореЗрдВ рд╕реНрдерд┐рддрд┐ рдХреЗ рдЕрдиреБрд╕рд╛рд░ рджрд┐рдП рдЧрдП рд╡рдЬрди рд╣реИрдВред

  • Dropout (drop_emb): рдирд┐рдпрдорд┐рддреАрдХрд░рдг рдХреЗ рд▓рд┐рдП рдПрдореНрдмреЗрдбрд┐рдВрдЧ рдкрд░ рд▓рд╛рдЧреВ рдХрд┐рдпрд╛ рдЬрд╛рддрд╛ рд╣реИред

  • Transformer Blocks (trf_blocks): рдПрдореНрдмреЗрдбрд┐рдВрдЧ рдХреЛ рдкреНрд░реЛрд╕реЗрд╕ рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП n_layers рдЯреНрд░рд╛рдВрд╕рдлрд╛рд░реНрдорд░ рдмреНрд▓реЙрдХреЛрдВ рдХрд╛ рд╕реНрдЯреИрдХред

  • Final Normalization (final_norm): рдЖрдЙрдЯрдкреБрдЯ рд▓реЗрдпрд░ рд╕реЗ рдкрд╣рд▓реЗ рд▓реЗрдпрд░ рдиреЙрд░реНрдорд▓рд╛рдЗрдЬреЗрд╢рдиред

  • Output Layer (out_head): рдЕрдВрддрд┐рдо рдЫрд┐рдкреЗ рд╣реБрдП рд░рд╛рдЬреНрдпреЛрдВ рдХреЛ рд╢рдмреНрджрд╛рд╡рд▓реА рдХреЗ рдЖрдХрд╛рд░ рдореЗрдВ рдкреНрд░рдХреНрд╖рд┐рдкреНрдд рдХрд░рддрд╛ рд╣реИ рддрд╛рдХрд┐ рднрд╡рд┐рд╖реНрдпрд╡рд╛рдгреА рдХреЗ рд▓рд┐рдП рд▓реЙрдЬрд┐рдЯреНрд╕ рдЙрддреНрдкрдиреНрди рд╣реЛ рд╕рдХреЗрдВред

рдЗрд╕ рд╡рд░реНрдЧ рдХрд╛ рд▓рдХреНрд╖реНрдп рдЕрдиреБрдХреНрд░рдо рдореЗрдВ рдЕрдЧрд▓реЗ рдЯреЛрдХрди рдХреА рднрд╡рд┐рд╖реНрдпрд╡рд╛рдгреА рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП рд╕рднреА рдЕрдиреНрдп рдЙрд▓реНрд▓реЗрдЦрд┐рдд рдиреЗрдЯрд╡рд░реНрдХ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдирд╛ рд╣реИ, рдЬреЛ рдкрд╛рда рдирд┐рд░реНрдорд╛рдг рдЬреИрд╕реЗ рдХрд╛рд░реНрдпреЛрдВ рдХреЗ рд▓рд┐рдП рдореМрд▓рд┐рдХ рд╣реИред

рдзреНрдпрд╛рди рджреЗрдВ рдХрд┐ рдпрд╣ рдЬрд┐рддрдиреЗ рдЯреНрд░рд╛рдВрд╕рдлрд╛рд░реНрдорд░ рдмреНрд▓реЙрдХ рдирд┐рд░реНрджрд┐рд╖реНрдЯ рд╣реИрдВ рдЙрддрдиреЗ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░реЗрдЧрд╛ рдФрд░ рдкреНрд░рддреНрдпреЗрдХ рдЯреНрд░рд╛рдВрд╕рдлрд╛рд░реНрдорд░ рдмреНрд▓реЙрдХ рдПрдХ рдорд▓реНрдЯреА-рд╣реЗрдб рдЕрдЯреЗрдВрд╢рди рдиреЗрдЯ, рдПрдХ рдлреАрдб рдлреЙрд░рд╡рд░реНрдб рдиреЗрдЯ рдФрд░ рдХрдИ рдиреЙрд░реНрдорд▓рд╛рдЗрдЬреЗрд╢рди рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░ рд░рд╣рд╛ рд╣реИред рдЗрд╕рд▓рд┐рдП рдпрджрд┐ 12 рдЯреНрд░рд╛рдВрд╕рдлрд╛рд░реНрдорд░ рдмреНрд▓реЙрдХ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд┐рдпрд╛ рдЬрд╛рддрд╛ рд╣реИ, рддреЛ рдЗрд╕реЗ 12 рд╕реЗ рдЧреБрдгрд╛ рдХрд░реЗрдВред

рдЗрд╕рдХреЗ рдЕрд▓рд╛рд╡рд╛, рдПрдХ рдиреЙрд░реНрдорд▓рд╛рдЗрдЬреЗрд╢рди рд▓реЗрдпрд░ рдЖрдЙрдЯрдкреБрдЯ рд╕реЗ рдкрд╣рд▓реЗ рдЬреЛрдбрд╝реА рдЬрд╛рддреА рд╣реИ рдФрд░ рдЕрдВрдд рдореЗрдВ рдкрд░рд┐рдгрд╛рдореЛрдВ рдХреЛ рдЙрдЪрд┐рдд рдЖрдпрд╛рдореЛрдВ рдХреЗ рд╕рд╛рде рдкреНрд░рд╛рдкреНрдд рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП рдПрдХ рдЕрдВрддрд┐рдо рд░реИрдЦрд┐рдХ рд▓реЗрдпрд░ рд▓рд╛рдЧреВ рдХреА рдЬрд╛рддреА рд╣реИред рдзреНрдпрд╛рди рджреЗрдВ рдХрд┐ рдкреНрд░рддреНрдпреЗрдХ рдЕрдВрддрд┐рдо рд╡реЗрдХреНрдЯрд░ рдХрд╛ рдЖрдХрд╛рд░ рдЙрдкрдпреЛрдЧ рдХреА рдЧрдИ рд╢рдмреНрджрд╛рд╡рд▓реА рдХреЗ рдЖрдХрд╛рд░ рдХреЗ рдмрд░рд╛рдмрд░ рд╣реИред рдЗрд╕рдХрд╛ рдХрд╛рд░рдг рдпрд╣ рд╣реИ рдХрд┐ рдпрд╣ рд╢рдмреНрджрд╛рд╡рд▓реА рдХреЗ рднреАрддрд░ рд╕рдВрднрд╛рд╡рд┐рдд рдЯреЛрдХрди рдХреЗ рд▓рд┐рдП рдПрдХ рд╕рдВрднрд╛рд╡рдирд╛ рдкреНрд░рд╛рдкреНрдд рдХрд░рдиреЗ рдХреА рдХреЛрд╢рд┐рд╢ рдХрд░ рд░рд╣рд╛ рд╣реИред

рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП рдкреИрд░рд╛рдореАрдЯрд░ рдХреА рд╕рдВрдЦреНрдпрд╛

GPT рд╕рдВрд░рдЪрдирд╛ рдХреЛ рдкрд░рд┐рднрд╛рд╖рд┐рдд рдХрд░рдиреЗ рдХреЗ рдмрд╛рдж, рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП рдкреИрд░рд╛рдореАрдЯрд░ рдХреА рд╕рдВрдЦреНрдпрд╛ рдкрддрд╛ рд▓рдЧрд╛рдирд╛ рд╕рдВрднрд╡ рд╣реИ:

GPT_CONFIG_124M = {
"vocab_size": 50257,    # Vocabulary size
"context_length": 1024, # Context length
"emb_dim": 768,         # Embedding dimension
"n_heads": 12,          # Number of attention heads
"n_layers": 12,         # Number of layers
"drop_rate": 0.1,       # Dropout rate
"qkv_bias": False       # Query-Key-Value bias
}

model = GPTModel(GPT_CONFIG_124M)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params:,}")
# Total number of parameters: 163,009,536

рдЪрд░рдг-рджрд░-рдЪрд░рдг рдЧрдгрдирд╛

1. рдПрдореНрдмреЗрдбрд┐рдВрдЧ рдкрд░рддреЗрдВ: рдЯреЛрдХрди рдПрдореНрдмреЗрдбрд┐рдВрдЧ рдФрд░ рд╕реНрдерд┐рддрд┐ рдПрдореНрдмреЗрдбрд┐рдВрдЧ

  • рдкрд░рдд: nn.Embedding(vocab_size, emb_dim)

  • рдкреИрд░рд╛рдореАрдЯрд░: vocab_size * emb_dim

token_embedding_params = 50257 * 768 = 38,597,376
  • рд▓реЗрдпрд░: nn.Embedding(context_length, emb_dim)

  • рдкреИрд░рд╛рдореАрдЯрд░реНрд╕: context_length * emb_dim

position_embedding_params = 1024 * 768 = 786,432

рдХреБрд▓ рдПрдореНрдмреЗрдбрд┐рдВрдЧ рдкреИрд░рд╛рдореАрдЯрд░

embedding_params = token_embedding_params + position_embedding_params
embedding_params = 38,597,376 + 786,432 = 39,383,808

2. Transformer Blocks

рдпрд╣рд╛рдВ 12 рдЯреНрд░рд╛рдВрд╕рдлрд╛рд░реНрдорд░ рдмреНрд▓реЙрдХ рд╣реИрдВ, рдЗрд╕рд▓рд┐рдП рд╣рдо рдПрдХ рдмреНрд▓реЙрдХ рдХреЗ рд▓рд┐рдП рдкреИрд░рд╛рдореАрдЯрд░ рдХреА рдЧрдгрдирд╛ рдХрд░реЗрдВрдЧреЗ рдФрд░ рдлрд┐рд░ 12 рд╕реЗ рдЧреБрдгрд╛ рдХрд░реЗрдВрдЧреЗред

рдкреНрд░рддреНрдпреЗрдХ рдЯреНрд░рд╛рдВрд╕рдлрд╛рд░реНрдорд░ рдмреНрд▓реЙрдХ рдХреЗ рд▓рд┐рдП рдкреИрд░рд╛рдореАрдЯрд░

a. рдорд▓реНрдЯреА-рд╣реЗрдб рдЕрдЯреЗрдВрд╢рди

  • рдШрдЯрдХ:

  • рдХреНрд╡реЗрд░реА рд▓реАрдирд┐рдпрд░ рд▓реЗрдпрд░ (W_query): nn.Linear(emb_dim, emb_dim, bias=False)

  • рдХреА рд▓реАрдирд┐рдпрд░ рд▓реЗрдпрд░ (W_key): nn.Linear(emb_dim, emb_dim, bias=False)

  • рд╡реИрд▓реНрдпреВ рд▓реАрдирд┐рдпрд░ рд▓реЗрдпрд░ (W_value): nn.Linear(emb_dim, emb_dim, bias=False)

  • рдЖрдЙрдЯрдкреБрдЯ рдкреНрд░реЛрдЬреЗрдХреНрд╢рди (out_proj): nn.Linear(emb_dim, emb_dim)

  • рдЧрдгрдирд╛рдПрдБ:

  • W_query, W_key, W_value рдореЗрдВ рд╕реЗ рдкреНрд░рддреНрдпреЗрдХ:

qkv_params = emb_dim * emb_dim = 768 * 768 = 589,824

рдЪреВрдВрдХрд┐ рдРрд╕реА рддреАрди рд▓реЗрдпрд░ рд╣реИрдВ:

total_qkv_params = 3 * qkv_params = 3 * 589,824 = 1,769,472
  • рдЖрдЙрдЯрдкреБрдЯ рдкреНрд░реЛрдЬреЗрдХреНрд╢рди (out_proj):

out_proj_params = (emb_dim * emb_dim) + emb_dim = (768 * 768) + 768 = 589,824 + 768 = 590,592
  • рдХреБрд▓ рдорд▓реНрдЯреА-рд╣реЗрдб рдЕрдЯреЗрдВрд╢рди рдкреИрд░рд╛рдореАрдЯрд░:

mha_params = total_qkv_params + out_proj_params
mha_params = 1,769,472 + 590,592 = 2,360,064

b. рдлреАрдбрдлреЙрд░рд╡рд░реНрдб рдиреЗрдЯрд╡рд░реНрдХ

  • рдШрдЯрдХ:

  • рдкрд╣рд▓реА рд▓реАрдирд┐рдпрд░ рд▓реЗрдпрд░: nn.Linear(emb_dim, 4 * emb_dim)

  • рджреВрд╕рд░реА рд▓реАрдирд┐рдпрд░ рд▓реЗрдпрд░: nn.Linear(4 * emb_dim, emb_dim)

  • рдЧрдгрдирд╛рдПрдБ:

  • рдкрд╣рд▓реА рд▓реАрдирд┐рдпрд░ рд▓реЗрдпрд░:

ff_first_layer_params = (emb_dim * 4 * emb_dim) + (4 * emb_dim)
ff_first_layer_params = (768 * 3072) + 3072 = 2,359,296 + 3,072 = 2,362,368
  • рджреВрд╕рд░реА рд▓реАрдирд┐рдпрд░ рд▓реЗрдпрд░:

ff_second_layer_params = (4 * emb_dim * emb_dim) + emb_dim
ff_second_layer_params = (3072 * 768) + 768 = 2,359,296 + 768 = 2,360,064
  • рдХреБрд▓ рдлреАрдбрдлреЙрд░рд╡рд░реНрдб рдкреИрд░рд╛рдореАрдЯрд░:

ff_params = ff_first_layer_params + ff_second_layer_params
ff_params = 2,362,368 + 2,360,064 = 4,722,432

c. рд▓реЗрдпрд░ рдиреЙрд░реНрдорд▓рд╛рдЗрдЬреЗрд╢рди

  • рдШрдЯрдХ:

  • рдкреНрд░рддреНрдпреЗрдХ рдмреНрд▓реЙрдХ рдореЗрдВ рджреЛ LayerNorm рдЙрджрд╛рд╣рд░рдгред

  • рдкреНрд░рддреНрдпреЗрдХ LayerNorm рдореЗрдВ 2 * emb_dim рдкреИрд░рд╛рдореАрдЯрд░ рд╣реЛрддреЗ рд╣реИрдВ (рд╕реНрдХреЗрд▓ рдФрд░ рд╢рд┐рдлреНрдЯ)ред

  • рдЧрдгрдирд╛рдПрдБ:

layer_norm_params_per_block = 2 * (2 * emb_dim) = 2 * 768 * 2 = 3,072

d. рдкреНрд░рддреНрдпреЗрдХ рдЯреНрд░рд╛рдВрд╕рдлрд╛рд░реНрдорд░ рдмреНрд▓реЙрдХ рдХреЗ рд▓рд┐рдП рдХреБрд▓ рдкреИрд░рд╛рдореАрдЯрд░

pythonCopy codeparams_per_block = mha_params + ff_params + layer_norm_params_per_block
params_per_block = 2,360,064 + 4,722,432 + 3,072 = 7,085,568

рд╕рднреА рдЯреНрд░рд╛рдВрд╕рдлрд╛рд░реНрдорд░ рдмреНрд▓реЙрдХреНрд╕ рдХреЗ рд▓рд┐рдП рдХреБрд▓ рдкреИрд░рд╛рдореАрдЯрд░

pythonCopy codetotal_transformer_blocks_params = params_per_block * n_layers
total_transformer_blocks_params = 7,085,568 * 12 = 85,026,816

3. рдЕрдВрддрд┐рдо рдкрд░рддреЗрдВ

рдХ. рдЕрдВрддрд┐рдо рдкрд░рдд рд╕рд╛рдорд╛рдиреНрдпреАрдХрд░рдг

  • рдкреИрд░рд╛рдореАрдЯрд░: 2 * emb_dim (рд╕реНрдХреЗрд▓ рдФрд░ рд╢рд┐рдлреНрдЯ)

pythonCopy codefinal_layer_norm_params = 2 * 768 = 1,536

b. рдЖрдЙрдЯрдкреБрдЯ рдкреНрд░реЛрдЬреЗрдХреНрд╢рди рд▓реЗрдпрд░ (out_head)

  • рд▓реЗрдпрд░: nn.Linear(emb_dim, vocab_size, bias=False)

  • рдкреИрд░рд╛рдореАрдЯрд░реНрд╕: emb_dim * vocab_size

pythonCopy codeoutput_projection_params = 768 * 50257 = 38,597,376

4. рд╕рднреА рдкреИрд░рд╛рдореАрдЯрд░ рдХрд╛ рд╕рд╛рд░рд╛рдВрд╢

pythonCopy codetotal_params = (
embedding_params +
total_transformer_blocks_params +
final_layer_norm_params +
output_projection_params
)
total_params = (
39,383,808 +
85,026,816 +
1,536 +
38,597,376
)
total_params = 163,009,536

Generate Text

рдПрдХ рдРрд╕рд╛ рдореЙрдбрд▓ рд╣реЛрдирд╛ рдЬреЛ рдЕрдЧрд▓реЗ рдЯреЛрдХрди рдХреА рднрд╡рд┐рд╖реНрдпрд╡рд╛рдгреА рдХрд░рддрд╛ рд╣реИ рдЬреИрд╕реЗ рдХрд┐ рдкрд╣рд▓реЗ рд╡рд╛рд▓рд╛, рдмрд╕ рдЕрдВрддрд┐рдо рдЯреЛрдХрди рдорд╛рдиреЛрдВ рдХреЛ рдЖрдЙрдЯрдкреБрдЯ рд╕реЗ рд▓реЗрдирд╛ рдЖрд╡рд╢реНрдпрдХ рд╣реИ (рдХреНрдпреЛрдВрдХрд┐ рд╡реЗ рднрд╡рд┐рд╖реНрдпрд╡рд╛рдгреА рдХрд┐рдП рдЧрдП рдЯреЛрдХрди рдХреЗ рд╣реЛрдВрдЧреЗ), рдЬреЛ рдХрд┐ рд╢рдмреНрджрд╛рд╡рд▓реА рдореЗрдВ рдкреНрд░рддреНрдпреЗрдХ рдкреНрд░рд╡рд┐рд╖реНрдЯрд┐ рдХреЗ рд▓рд┐рдП рдПрдХ рдорд╛рди рд╣реЛрдЧрд╛ рдФрд░ рдлрд┐рд░ softmax рдлрд╝рдВрдХреНрд╢рди рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдХреЗ рдЖрдпрд╛рдореЛрдВ рдХреЛ 1 рдХреЗ рдпреЛрдЧ рдореЗрдВ рд╕рдВрднрд╛рд╡рдирд╛рдУрдВ рдореЗрдВ рд╕рд╛рдорд╛рдиреНрдпреАрдХреГрдд рдХрд░рдирд╛ рд╣реЛрдЧрд╛ рдФрд░ рдлрд┐рд░ рд╕рдмрд╕реЗ рдмрдбрд╝реЗ рдкреНрд░рд╡рд┐рд╖реНрдЯрд┐ рдХрд╛ рдЕрдиреБрдХреНрд░рдорд╛рдВрдХ рдкреНрд░рд╛рдкреНрдд рдХрд░рдирд╛ рд╣реЛрдЧрд╛, рдЬреЛ рдХрд┐ рд╢рдмреНрджрд╛рд╡рд▓реА рдХреЗ рднреАрддрд░ рд╢рдмреНрдж рдХрд╛ рдЕрдиреБрдХреНрд░рдорд╛рдВрдХ рд╣реЛрдЧрд╛ред

Code from https://github.com/rasbt/LLMs-from-scratch/blob/main/ch04/01_main-chapter-code/ch04.ipynb:

def generate_text_simple(model, idx, max_new_tokens, context_size):
# idx is (batch, n_tokens) array of indices in the current context
for _ in range(max_new_tokens):

# Crop current context if it exceeds the supported context size
# E.g., if LLM supports only 5 tokens, and the context size is 10
# then only the last 5 tokens are used as context
idx_cond = idx[:, -context_size:]

# Get the predictions
with torch.no_grad():
logits = model(idx_cond)

# Focus only on the last time step
# (batch, n_tokens, vocab_size) becomes (batch, vocab_size)
logits = logits[:, -1, :]

# Apply softmax to get probabilities
probas = torch.softmax(logits, dim=-1)  # (batch, vocab_size)

# Get the idx of the vocab entry with the highest probability value
idx_next = torch.argmax(probas, dim=-1, keepdim=True)  # (batch, 1)

# Append sampled index to the running sequence
idx = torch.cat((idx, idx_next), dim=1)  # (batch, n_tokens+1)

return idx


start_context = "Hello, I am"

encoded = tokenizer.encode(start_context)
print("encoded:", encoded)

encoded_tensor = torch.tensor(encoded).unsqueeze(0)
print("encoded_tensor.shape:", encoded_tensor.shape)

model.eval() # disable dropout

out = generate_text_simple(
model=model,
idx=encoded_tensor,
max_new_tokens=6,
context_size=GPT_CONFIG_124M["context_length"]
)

print("Output:", out)
print("Output length:", len(out[0]))

рд╕рдВрджрд░реНрдн

Last updated