4. Attention Mechanisms

Attention Mechanisms and Self-Attention in Neural Networks

Attention mechanisms allow neural networks to focus on specific parts of the input when generating each part of the output. それらは異なる入力に異なる重みを割り当て、モデルが現在のタスクに最も関連する入力を決定するのを助けます。これは、正確な翻訳のために文全体の文脈を理解する必要がある機械翻訳のようなタスクでは重要です。

この第4段階の目標は非常にシンプルです: いくつかの注意メカニズムを適用すること。これらは、語彙内の単語と現在の文の隣接単語との関係を捉えるための多くの繰り返し層になります。 これには多くの層が使用されるため、多くの学習可能なパラメータがこの情報を捉えることになります。

Understanding Attention Mechanisms

従来のシーケンスからシーケンスへのモデルは、入力シーケンスを固定サイズのコンテキストベクトルにエンコードします。しかし、このアプローチは長い文に対しては苦労します。なぜなら、固定サイズのコンテキストベクトルは必要なすべての情報を捉えられない可能性があるからです。注意メカニズムは、モデルが各出力トークンを生成する際にすべての入力トークンを考慮できるようにすることで、この制限に対処します。

Example: Machine Translation

ドイツ語の文「Kannst du mir helfen diesen Satz zu übersetzen」を英語に翻訳することを考えてみましょう。単語ごとの翻訳では、言語間の文法構造の違いにより、文法的に正しい英語の文は生成されません。注意メカニズムは、出力文の各単語を生成する際に入力文の関連部分に焦点を当てることを可能にし、より正確で一貫した翻訳を実現します。

Introduction to Self-Attention

自己注意(Self-attention)または内部注意(intra-attention)は、注意が単一のシーケンス内で適用され、そのシーケンスの表現を計算するメカニズムです。これにより、シーケンス内の各トークンが他のすべてのトークンに注意を向けることができ、モデルがトークン間の依存関係を距離に関係なく捉えるのを助けます。

Key Concepts

  • Tokens: 入力シーケンスの個々の要素(例: 文中の単語)。

  • Embeddings: トークンのベクトル表現で、意味情報を捉えます。

  • Attention Weights: 他のトークンに対する各トークンの重要性を決定する値。

Calculating Attention Weights: A Step-by-Step Example

"Hello shiny sun!" を考え、各単語を3次元の埋め込みで表現します:

  • Hello: [0.34, 0.22, 0.54]

  • shiny: [0.53, 0.34, 0.98]

  • sun: [0.29, 0.54, 0.93]

私たちの目標は、自己注意を使用して単語 "shiny"コンテキストベクトル を計算することです。

Step 1: Compute Attention Scores

各次元のクエリの値を関連するトークンの値と掛け算し、結果を加算します。トークンのペアごとに1つの値が得られます。

文中の各単語について、shiny に対する 注意スコア を、その埋め込みのドット積を計算することで求めます。

"Hello" と "shiny" の注意スコア

"shiny" と "shiny" の注意スコア

"sun" と "shiny" の注意スコア

Step 2: Normalize Attention Scores to Obtain Attention Weights

数学用語に迷わないでください。この関数の目標はシンプルです。すべての重みを正規化して、合計が1になるようにします

さらに、softmax 関数が使用されるのは、指数部分によって違いを強調し、有用な値を検出しやすくするためです。

注意スコアに softmax関数 を適用して、合計が1になる注意重みを得ます。

指数を計算します:

合計を計算します:

注意重みを計算します:

Step 3: Compute the Context Vector

各注意重みを関連するトークンの次元に掛け算し、すべての次元を合計して1つのベクトル(コンテキストベクトル)を得ます。

コンテキストベクトル は、すべての単語の埋め込みの重み付き合計として計算され、注意重みを使用します。

各成分を計算します:

  • "Hello" の重み付き埋め込み

* **"shiny" の重み付き埋め込み**:

* **"sun" の重み付き埋め込み**:

重み付き埋め込みを合計します:

context vector=[0.0779+0.2156+0.1057, 0.0504+0.1382+0.1972, 0.1237+0.3983+0.3390]=[0.3992,0.3858,0.8610]

このコンテキストベクトルは、文中のすべての単語からの情報を取り入れた「shiny」の強化された埋め込みを表します。

Summary of the Process

  1. 注意スコアを計算する: ターゲット単語の埋め込みとシーケンス内のすべての単語の埋め込みとの間のドット積を使用します。

  2. スコアを正規化して注意重みを得る: 注意スコアにsoftmax関数を適用して、合計が1になる重みを得ます。

  3. コンテキストベクトルを計算する: 各単語の埋め込みをその注意重みで掛け算し、結果を合計します。

Self-Attention with Trainable Weights

実際には、自己注意メカニズムは学習可能な重みを使用して、クエリ、キー、および値の最適な表現を学習します。これには、3つの重み行列を導入します:

クエリは以前と同様に使用するデータであり、キーと値の行列は単にランダムに学習可能な行列です。

Step 1: Compute Queries, Keys, and Values

各トークンは、定義された行列でその次元値を掛け算することによって、独自のクエリ、キー、および値の行列を持ちます:

これらの行列は、元の埋め込みを注意を計算するのに適した新しい空間に変換します。

次のように仮定します:

  • 入力次元 din=3(埋め込みサイズ)

  • 出力次元 dout=2(クエリ、キー、および値のための希望する次元)

重み行列を初期化します:

import torch.nn as nn

d_in = 3
d_out = 2

W_query = nn.Parameter(torch.rand(d_in, d_out))
W_key = nn.Parameter(torch.rand(d_in, d_out))
W_value = nn.Parameter(torch.rand(d_in, d_out))

クエリ、キー、値を計算する:

queries = torch.matmul(inputs, W_query)
keys = torch.matmul(inputs, W_key)
values = torch.matmul(inputs, W_value)

ステップ 2: スケーリングされたドット積アテンションの計算

アテンションスコアの計算

以前の例と似ていますが、今回はトークンの次元の値を使用するのではなく、トークンのキー行列を使用します(すでに次元を使用して計算されています)。したがって、各クエリ qi​ とキー kj​ に対して:

スコアのスケーリング

ドット積が大きくなりすぎないように、キー次元 dk​ の平方根でスケーリングします:

スコアは次元の平方根で割られます。なぜなら、ドット積が非常に大きくなる可能性があり、これがそれらを調整するのに役立つからです。

アテンションウェイトを得るためにソフトマックスを適用: 初期の例と同様に、すべての値を正規化して合計が1になるようにします。

ステップ 3: コンテキストベクトルの計算

初期の例と同様に、すべての値行列をそのアテンションウェイトで掛けて合計します:

コード例

https://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/01_main-chapter-code/ch03.ipynb から例を取得すると、私たちが話した自己注意機能を実装するこのクラスを確認できます:

import torch

inputs = torch.tensor(
[[0.43, 0.15, 0.89], # Your     (x^1)
[0.55, 0.87, 0.66], # journey  (x^2)
[0.57, 0.85, 0.64], # starts   (x^3)
[0.22, 0.58, 0.33], # with     (x^4)
[0.77, 0.25, 0.10], # one      (x^5)
[0.05, 0.80, 0.55]] # step     (x^6)
)

import torch.nn as nn
class SelfAttention_v2(nn.Module):

def __init__(self, d_in, d_out, qkv_bias=False):
super().__init__()
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

def forward(self, x):
keys = self.W_key(x)
queries = self.W_query(x)
values = self.W_value(x)

attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

context_vec = attn_weights @ values
return context_vec

d_in=3
d_out=2
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

注意:行列をランダムな値で初期化する代わりに、nn.Linearを使用してすべての重みをトレーニングするパラメータとしてマークします。

因果注意:未来の単語を隠す

LLMでは、モデルが現在の位置の前に出現するトークンのみを考慮して次のトークンを予測することを望みます。因果注意、またはマスク付き注意は、注意メカニズムを変更して未来のトークンへのアクセスを防ぐことによってこれを実現します。

因果注意マスクの適用

因果注意を実装するために、ソフトマックス操作の前に注意スコアにマスクを適用します。これにより、残りのスコアは合計1になります。このマスクは、未来のトークンの注意スコアを負の無限大に設定し、ソフトマックス後にその注意重みがゼロになることを保証します。

手順

  1. 注意スコアの計算:以前と同様。

  2. マスクの適用:対角線の上に負の無限大で満たされた上三角行列を使用します。

mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * float('-inf')
masked_scores = attention_scores + mask
  1. ソフトマックスの適用:マスクされたスコアを使用して注意重みを計算します。

attention_weights = torch.softmax(masked_scores, dim=-1)

ドロップアウトによる追加の注意重みのマスキング

過学習を防ぐために、ソフトマックス操作の後に注意重みにドロップアウトを適用できます。ドロップアウトは、トレーニング中に注意重みの一部をランダムにゼロにします

dropout = nn.Dropout(p=0.5)
attention_weights = dropout(attention_weights)

通常のドロップアウトは約10-20%です。

Code Example

Code example from https://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/01_main-chapter-code/ch03.ipynb:

import torch
import torch.nn as nn

inputs = torch.tensor(
[[0.43, 0.15, 0.89], # Your     (x^1)
[0.55, 0.87, 0.66], # journey  (x^2)
[0.57, 0.85, 0.64], # starts   (x^3)
[0.22, 0.58, 0.33], # with     (x^4)
[0.77, 0.25, 0.10], # one      (x^5)
[0.05, 0.80, 0.55]] # step     (x^6)
)

batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)

class CausalAttention(nn.Module):

def __init__(self, d_in, d_out, context_length,
dropout, qkv_bias=False):
super().__init__()
self.d_out = d_out
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.dropout = nn.Dropout(dropout)
self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) # New

def forward(self, x):
b, num_tokens, d_in = x.shape
# b is the num of batches
# num_tokens is the number of tokens per batch
# d_in is the dimensions er token

keys = self.W_key(x) # This generates the keys of the tokens
queries = self.W_query(x)
values = self.W_value(x)

attn_scores = queries @ keys.transpose(1, 2) # Moves the third dimension to the second one and the second one to the third one to be able to multiply
attn_scores.masked_fill_(  # New, _ ops are in-place
self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)  # `:num_tokens` to account for cases where the number of tokens in the batch is smaller than the supported context_size
attn_weights = torch.softmax(
attn_scores / keys.shape[-1]**0.5, dim=-1
)
attn_weights = self.dropout(attn_weights)

context_vec = attn_weights @ values
return context_vec

torch.manual_seed(123)

context_length = batch.shape[1]
d_in = 3
d_out = 2
ca = CausalAttention(d_in, d_out, context_length, 0.0)

context_vecs = ca(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

シングルヘッドアテンションからマルチヘッドアテンションへの拡張

マルチヘッドアテンションは、実際には複数のインスタンスの自己アテンション関数を実行し、それぞれが独自の重みを持つことで、異なる最終ベクトルが計算されることを意味します。

コード例

前のコードを再利用し、ラッパーを追加して何度も実行することも可能ですが、これはhttps://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/01_main-chapter-code/ch03.ipynbからの最適化されたバージョンで、すべてのヘッドを同時に処理します(高価なforループの数を減らします)。コードに示されているように、各トークンの次元はヘッドの数に応じて異なる次元に分割されます。このように、トークンが8次元を持ち、3つのヘッドを使用したい場合、次元は4次元の2つの配列に分割され、各ヘッドはそのうちの1つを使用します:

class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
super().__init__()
assert (d_out % num_heads == 0), \
"d_out must be divisible by num_heads"

self.d_out = d_out
self.num_heads = num_heads
self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim

self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
self.dropout = nn.Dropout(dropout)
self.register_buffer(
"mask",
torch.triu(torch.ones(context_length, context_length),
diagonal=1)
)

def forward(self, x):
b, num_tokens, d_in = x.shape
# b is the num of batches
# num_tokens is the number of tokens per batch
# d_in is the dimensions er token

keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
queries = self.W_query(x)
values = self.W_value(x)

# We implicitly split the matrix by adding a `num_heads` dimension
# Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
values = values.view(b, num_tokens, self.num_heads, self.head_dim)
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

# Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
keys = keys.transpose(1, 2)
queries = queries.transpose(1, 2)
values = values.transpose(1, 2)

# Compute scaled dot-product attention (aka self-attention) with a causal mask
attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

# Original mask truncated to the number of tokens and converted to boolean
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

# Use the mask to fill attention scores
attn_scores.masked_fill_(mask_bool, -torch.inf)

attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
attn_weights = self.dropout(attn_weights)

# Shape: (b, num_tokens, num_heads, head_dim)
context_vec = (attn_weights @ values).transpose(1, 2)

# Combine heads, where self.d_out = self.num_heads * self.head_dim
context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
context_vec = self.out_proj(context_vec) # optional projection

return context_vec

torch.manual_seed(123)

batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)

context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

別のコンパクトで効率的な実装のために、PyTorchのtorch.nn.MultiheadAttentionクラスを使用することができます。

ChatGPTによる、トークンの次元をヘッド間で分割する方が、各ヘッドがすべてのトークンのすべての次元をチェックするよりも良い理由の短い回答:

各ヘッドがすべての埋め込み次元を処理できるようにすることは、各ヘッドが完全な情報にアクセスできるため有利に思えるかもしれませんが、標準的な実践は埋め込み次元をヘッド間で分割することです。このアプローチは、計算効率とモデルのパフォーマンスのバランスを取り、各ヘッドが多様な表現を学ぶことを促します。したがって、埋め込み次元を分割することは、一般的に各ヘッドがすべての次元をチェックするよりも好まれます。

References

Last updated