22. microgpt のモデル本体(gpt)
22.1. この章で学ぶこと
gptの埋め込みループ、Attention、MLP、lm_headまでのデータの流れ18 章 B5 に対応する (1)~(7) の行ブロック読み
22.2. gpt(モデル本体)
埋め込みループ(Attention と MLP)、最後の lm_head までをひとまとまりにしたのが gpt です。上の linear / softmax / rmsnorm が、この関数の中で繰り返し使われます。まず gpt 関数全体(microgpt.py L146~L189)を引用します。続けて ブロック図と各ブロックの一行概要を置き、そのあと 項 (1)~(7) として行ブロックごとに説明します(各項の直後に同じ行範囲を重複して示します)。
146def gpt(token_id, pos_id, keys, values):
147 """
148 GPTの順伝播。現在のトークンと位置から、次トークンのlogitsを出力。
149 keys, values: 過去のK/Vをキャッシュ(推論時の効率化、学習時は因果マスク用)
150 """
151 tok_emb = state_dict['wte'][token_id] # トークン埋め込み
152 pos_emb = state_dict['wpe'][pos_id] # 位置埋め込み
153 x = [t + p for t, p in zip(tok_emb, pos_emb)] # トークン+位置の結合埋め込み
154 x = rmsnorm(x) # 初期正規化(残差接続経由で勾配が流れるため冗長ではない)
155
156 for li in range(n_layer):
157 # --- 1) マルチヘッドアテンションブロック ---
158 x_residual = x
159 x = rmsnorm(x)
160 q = linear(x, state_dict[f'layer{li}.attn_wq']) # Query
161 k = linear(x, state_dict[f'layer{li}.attn_wk']) # Key
162 v = linear(x, state_dict[f'layer{li}.attn_wv']) # Value
163 keys[li].append(k)
164 values[li].append(v)
165 x_attn = []
166 for h in range(n_head):
167 hs = h * head_dim
168 q_h = q[hs:hs+head_dim]
169 k_h = [ki[hs:hs+head_dim] for ki in keys[li]]
170 v_h = [vi[hs:hs+head_dim] for vi in values[li]]
171 # スケール付き内積: attn = softmax(QK^T / sqrt(d_k))
172 attn_logits = [sum(q_h[j] * k_h[t][j] for j in range(head_dim)) / head_dim**0.5 for t in range(len(k_h))]
173 attn_weights = softmax(attn_logits)
174 # 重み付き和: output = attn @ V
175 head_out = [sum(attn_weights[t] * v_h[t][j] for t in range(len(v_h))) for j in range(head_dim)]
176 x_attn.extend(head_out)
177 x = linear(x_attn, state_dict[f'layer{li}.attn_wo']) # ヘッド結合
178 x = [a + b for a, b in zip(x, x_residual)] # 残差接続
179
180 # --- 2) MLPブロック ---
181 x_residual = x
182 x = rmsnorm(x)
183 x = linear(x, state_dict[f'layer{li}.mlp_fc1'])
184 x = [xi.relu() for xi in x]
185 x = linear(x, state_dict[f'layer{li}.mlp_fc2'])
186 x = [a + b for a, b in zip(x, x_residual)] # 残差接続
187
188 logits = linear(x, state_dict['lm_head']) # 語彙サイズ次元のlogits
189 return logits
22.3. gpt 内のブロック図と概要
次の図は、上記ソースの制御の流れを (1)~(7) の項番号に対応させたものです。n_layer > 1 のときは、(3)~(6) のかたまりをインデックス li で繰り返します。
図-1: `gpt` 内のブロックと項 (1)~(7) の対応
項 |
おおまかな役割 |
|---|---|
(1) |
引数の意味(docstring)と、 |
(2) |
|
(3) |
Attention の入口で残差用に |
(4) |
ヘッドごとに、過去の K から注意重みを |
(5) |
|
(6) |
MLP でも同様に残差を取り、 |
(7) |
層ループ後の |
22.4. (1) L146〜L154 — 関数定義・ドキュストリングと入力の準備
146def gpt(token_id, pos_id, keys, values):
147 """
148 GPTの順伝播。現在のトークンと位置から、次トークンのlogitsを出力。
149 keys, values: 過去のK/Vをキャッシュ(推論時の効率化、学習時は因果マスク用)
150 """
151 tok_emb = state_dict['wte'][token_id] # トークン埋め込み
152 pos_emb = state_dict['wpe'][pos_id] # 位置埋め込み
153 x = [t + p for t, p in zip(tok_emb, pos_emb)] # トークン+位置の結合埋め込み
154 x = rmsnorm(x) # 初期正規化(残差接続経由で勾配が流れるため冗長ではない)
L146 は関数定義行(引数は現在のトークン ID・位置 ID、各層の K/V リスト)です。
L147〜L150 は docstring で、順伝播の役割と、
keys/valuesが層ごとに「過去の K/V を溜めるキャッシュ」であることを説明しています。L151〜L154 で、トークン ID に対応する行を
wteから、位置 ID に対応する行をwpeから取り出し、成分ごとに足してひとつの埋め込みベクトルxにします。続けて 最初のrmsnormをかけ、以降のブロックに渡すスケールにそろえます。
22.5. (2) L156 — 層を繰り返す
156 for li in range(n_layer):
for li in range(n_layer): で、下記の Attention ブロックと MLP ブロックの組を n_layer 回繰り返します。li は 0 から始まる層インデックスです。
22.6. (3) L157〜L164 — Attention ブロックの前半(残差・正規化・Q/K/V とキャッシュ)
157 # --- 1) マルチヘッドアテンションブロック ---
158 x_residual = x
159 x = rmsnorm(x)
160 q = linear(x, state_dict[f'layer{li}.attn_wq']) # Query
161 k = linear(x, state_dict[f'layer{li}.attn_wk']) # Key
162 v = linear(x, state_dict[f'layer{li}.attn_wv']) # Value
163 keys[li].append(k)
164 values[li].append(v)
L157 はコメント行です。
L158 でこのサブブロックに入る直前の
xをx_residualに退避し、あとで残差接続に使います。L159 で
rmsnormし、L160〜L162 で同じ正規化済みxから Query・Key・Value をそれぞれ別の重み行列で線形変換します。L163〜L164 で、計算した
kとvを、その層li用のリストkeys[li]/values[li]の末尾に追加します。学習ループでは位置が時間順に進むので、ここに「これまでの位置の K/V」が溜まり、因果的な注意(過去だけ見る)に使えます。
22.7. (4) L165〜L176 — ヘッドごとのアテンション(スケール付き内積・softmax・V への重み付き和)
165 x_attn = []
166 for h in range(n_head):
167 hs = h * head_dim
168 q_h = q[hs:hs+head_dim]
169 k_h = [ki[hs:hs+head_dim] for ki in keys[li]]
170 v_h = [vi[hs:hs+head_dim] for vi in values[li]]
171 # スケール付き内積: attn = softmax(QK^T / sqrt(d_k))
172 attn_logits = [sum(q_h[j] * k_h[t][j] for j in range(head_dim)) / head_dim**0.5 for t in range(len(k_h))]
173 attn_weights = softmax(attn_logits)
174 # 重み付き和: output = attn @ V
175 head_out = [sum(attn_weights[t] * v_h[t][j] for t in range(len(v_h))) for j in range(head_dim)]
176 x_attn.extend(head_out)
L165 で、ヘッドの出力を足すリスト
x_attnを空で用意します。L166 でヘッド番号
hのループに入ります。L167〜L170 では、埋め込み全体を
head_dim幅に分割したうえで、現在位置のqのスライスq_hと、これまでにkeys[li]/values[li]に蓄積された各時刻の K/V の同じヘッド部分k_h/v_hを取り出します(過去の長さぶんのベクトル列になります)。L172〜L173 では、現在の
q_hと各過去時刻のk_hとのスケール付き内積をとってattn_logitsにし、softmaxで「どの過去位置をどれだけ見るか」の重みattn_weightsにします。L175〜L176 では、その重みで各時刻の
v_hを足し合わせ、ヘッドhの出力ベクトルhead_outを求め、x_attnに連結して足していきます。
22.8. (5) L177〜L178 — Attention ブロックの後半(ヘッド結合と残差)
177 x = linear(x_attn, state_dict[f'layer{li}.attn_wo']) # ヘッド結合
178 x = [a + b for a, b in zip(x, x_residual)] # 残差接続
L177 で、すべてのヘッドからつながった
x_attnをattn_woで線形変換し、埋め込み次元にまとめます。L178 で、入り口で退避した
x_residualに足し合わせ、Attention サブブロックを通した表現に更新します。
22.9. (6) L180〜L186 — MLP ブロック
180 # --- 2) MLPブロック ---
181 x_residual = x
182 x = rmsnorm(x)
183 x = linear(x, state_dict[f'layer{li}.mlp_fc1'])
184 x = [xi.relu() for xi in x]
185 x = linear(x, state_dict[f'layer{li}.mlp_fc2'])
186 x = [a + b for a, b in zip(x, x_residual)] # 残差接続
L180 はコメント行です。
L181 で再度
xを退避し、L182 で
rmsnorm、L183〜L185 で
mlp_fc1→ ReLU →mlp_fc2という二つの線形と非線形を通し、L186 で残差を加算してこの層の出力を確定します。
注釈
ReLU(Rectified Linear Unit)とは
各成分で「値が負なら 0、そうでなければそのまま」にする活性化です。実装では xi.relu() と書いており、Value のスカラーに対して max(0, x) に相当します。線形変換だけを重ねると、まとめてひとつの線形変換と同じになってしまいますが、そのあいだに ReLU を挟むと非線形になり、表現力が増します。本コードでは GPT-2 系でよく使う GeLU の代わりに ReLU を採用しており、モデル部コメントにも「GeLU→ReLU」とあるとおり、数値互換より単純さを優先した差分です。
22.10. (7) L188〜L189 — 語彙への写像と返り値
188 logits = linear(x, state_dict['lm_head']) # 語彙サイズ次元のlogits
189 return logits
L188〜L189: 層ループを抜けたあとの
xに対し、lm_headをかけて語彙サイズ長の logits にして返します(行iとxの内積が logiti)。