22. microgpt のモデル本体(gpt)

22.1. この章で学ぶこと

  • gpt の埋め込みループ、Attention、MLP、lm_head までのデータの流れ

  • 18 章 B5 に対応する (1)~(7) の行ブロック読み

22.2. gpt(モデル本体)

埋め込みループ(Attention と MLP)、最後の lm_head までをひとまとまりにしたのが gpt です。上の linear / softmax / rmsnorm が、この関数の中で繰り返し使われます。まず gpt 関数全体microgpt.py L146~L189)を引用します。続けて ブロック図と各ブロックの一行概要を置き、そのあと 項 (1)~(7) として行ブロックごとに説明します(各項の直後に同じ行範囲を重複して示します)。

146def gpt(token_id, pos_id, keys, values):
147    """
148    GPTの順伝播。現在のトークンと位置から、次トークンのlogitsを出力。
149    keys, values: 過去のK/Vをキャッシュ(推論時の効率化、学習時は因果マスク用)
150    """
151    tok_emb = state_dict['wte'][token_id]   # トークン埋め込み
152    pos_emb = state_dict['wpe'][pos_id]     # 位置埋め込み
153    x = [t + p for t, p in zip(tok_emb, pos_emb)]  # トークン+位置の結合埋め込み
154    x = rmsnorm(x)  # 初期正規化(残差接続経由で勾配が流れるため冗長ではない)
155
156    for li in range(n_layer):
157        # --- 1) マルチヘッドアテンションブロック ---
158        x_residual = x
159        x = rmsnorm(x)
160        q = linear(x, state_dict[f'layer{li}.attn_wq'])  # Query
161        k = linear(x, state_dict[f'layer{li}.attn_wk'])  # Key
162        v = linear(x, state_dict[f'layer{li}.attn_wv'])  # Value
163        keys[li].append(k)
164        values[li].append(v)
165        x_attn = []
166        for h in range(n_head):
167            hs = h * head_dim
168            q_h = q[hs:hs+head_dim]
169            k_h = [ki[hs:hs+head_dim] for ki in keys[li]]
170            v_h = [vi[hs:hs+head_dim] for vi in values[li]]
171            # スケール付き内積: attn = softmax(QK^T / sqrt(d_k))
172            attn_logits = [sum(q_h[j] * k_h[t][j] for j in range(head_dim)) / head_dim**0.5 for t in range(len(k_h))]
173            attn_weights = softmax(attn_logits)
174            # 重み付き和: output = attn @ V
175            head_out = [sum(attn_weights[t] * v_h[t][j] for t in range(len(v_h))) for j in range(head_dim)]
176            x_attn.extend(head_out)
177        x = linear(x_attn, state_dict[f'layer{li}.attn_wo'])  # ヘッド結合
178        x = [a + b for a, b in zip(x, x_residual)]  # 残差接続
179
180        # --- 2) MLPブロック ---
181        x_residual = x
182        x = rmsnorm(x)
183        x = linear(x, state_dict[f'layer{li}.mlp_fc1'])
184        x = [xi.relu() for xi in x]
185        x = linear(x, state_dict[f'layer{li}.mlp_fc2'])
186        x = [a + b for a, b in zip(x, x_residual)]  # 残差接続
187
188    logits = linear(x, state_dict['lm_head'])  # 語彙サイズ次元のlogits
189    return logits

22.3. gpt 内のブロック図と概要

次の図は、上記ソースの制御の流れを (1)~(7) の項番号に対応させたものです。n_layer > 1 のときは、(3)~(6) のかたまりをインデックス li で繰り返します。

diagram

図-1: `gpt` 内のブロックと項 (1)~(7) の対応

おおまかな役割

(1)

引数の意味(docstring)と、wte / wpe から埋め込みベクトルを作り、ループに入る前の rmsnorm まで。

(2)

n_layer 回、Attention ブロックと MLP ブロックの組を繰り返す for

(3)

Attention の入口で残差用に x を退避し、正規化後に Q/K/V を計算し、過去位置分の k / v をキャッシュに追加。

(4)

ヘッドごとに、過去の K から注意重みを softmax し、V を重み付き和して x_attn に連結。

(5)

attn_wo でヘッド出力を埋め込み幅に写し、Attention 前の残差を加算。

(6)

MLP でも同様に残差を取り、rmsnorm のあと、mlp_fc1 → ReLU → mlp_fc2 → 残差。

(7)

層ループ後の xlm_head で語彙サイズの logits に写像して返す。

22.4. (1) L146〜L154 — 関数定義・ドキュストリングと入力の準備

146def gpt(token_id, pos_id, keys, values):
147    """
148    GPTの順伝播。現在のトークンと位置から、次トークンのlogitsを出力。
149    keys, values: 過去のK/Vをキャッシュ(推論時の効率化、学習時は因果マスク用)
150    """
151    tok_emb = state_dict['wte'][token_id]   # トークン埋め込み
152    pos_emb = state_dict['wpe'][pos_id]     # 位置埋め込み
153    x = [t + p for t, p in zip(tok_emb, pos_emb)]  # トークン+位置の結合埋め込み
154    x = rmsnorm(x)  # 初期正規化(残差接続経由で勾配が流れるため冗長ではない)
  • L146 は関数定義行(引数は現在のトークン ID・位置 ID、各層の K/V リスト)です。

  • L147〜L150 は docstring で、順伝播の役割と、keys / values が層ごとに「過去の K/V を溜めるキャッシュ」であることを説明しています。

  • L151〜L154 で、トークン ID に対応する行を wte から、位置 ID に対応する行を wpe から取り出し、成分ごとに足してひとつの埋め込みベクトル x にします。続けて 最初の rmsnorm をかけ、以降のブロックに渡すスケールにそろえます。

22.5. (2) L156 — 層を繰り返す

156    for li in range(n_layer):

for li in range(n_layer): で、下記の Attention ブロックと MLP ブロックの組を n_layer繰り返します。li は 0 から始まる層インデックスです。

22.6. (3) L157〜L164 — Attention ブロックの前半(残差・正規化・Q/K/V とキャッシュ)

157        # --- 1) マルチヘッドアテンションブロック ---
158        x_residual = x
159        x = rmsnorm(x)
160        q = linear(x, state_dict[f'layer{li}.attn_wq'])  # Query
161        k = linear(x, state_dict[f'layer{li}.attn_wk'])  # Key
162        v = linear(x, state_dict[f'layer{li}.attn_wv'])  # Value
163        keys[li].append(k)
164        values[li].append(v)
  • L157 はコメント行です。

  • L158 でこのサブブロックに入る直前の xx_residual に退避し、あとで残差接続に使います。

  • L159rmsnorm し、L160〜L162 で同じ正規化済み x から Query・Key・Value をそれぞれ別の重み行列で線形変換します。

  • L163〜L164 で、計算した kv を、その層 li 用のリスト keys[li] / values[li] の末尾に追加します。学習ループでは位置が時間順に進むので、ここに「これまでの位置の K/V」が溜まり、因果的な注意(過去だけ見る)に使えます。

22.7. (4) L165〜L176 — ヘッドごとのアテンション(スケール付き内積・softmax・V への重み付き和)

165        x_attn = []
166        for h in range(n_head):
167            hs = h * head_dim
168            q_h = q[hs:hs+head_dim]
169            k_h = [ki[hs:hs+head_dim] for ki in keys[li]]
170            v_h = [vi[hs:hs+head_dim] for vi in values[li]]
171            # スケール付き内積: attn = softmax(QK^T / sqrt(d_k))
172            attn_logits = [sum(q_h[j] * k_h[t][j] for j in range(head_dim)) / head_dim**0.5 for t in range(len(k_h))]
173            attn_weights = softmax(attn_logits)
174            # 重み付き和: output = attn @ V
175            head_out = [sum(attn_weights[t] * v_h[t][j] for t in range(len(v_h))) for j in range(head_dim)]
176            x_attn.extend(head_out)
  • L165 で、ヘッドの出力を足すリスト x_attn を空で用意します。

  • L166 でヘッド番号 h のループに入ります。

  • L167〜L170 では、埋め込み全体を head_dim 幅に分割したうえで、現在位置の q のスライス q_h と、これまでに keys[li] / values[li] に蓄積された各時刻の K/V の同じヘッド部分 k_h / v_h を取り出します(過去の長さぶんのベクトル列になります)。

  • L172〜L173 では、現在の q_h と各過去時刻の k_h とのスケール付き内積をとって attn_logits にし、softmax で「どの過去位置をどれだけ見るか」の重み attn_weights にします。

  • L175〜L176 では、その重みで各時刻の v_h を足し合わせ、ヘッド h の出力ベクトル head_out を求め、x_attn に連結して足していきます。

22.8. (5) L177〜L178 — Attention ブロックの後半(ヘッド結合と残差)

177        x = linear(x_attn, state_dict[f'layer{li}.attn_wo'])  # ヘッド結合
178        x = [a + b for a, b in zip(x, x_residual)]  # 残差接続
  • L177 で、すべてのヘッドからつながった x_attnattn_wo で線形変換し、埋め込み次元にまとめます。

  • L178 で、入り口で退避した x_residual に足し合わせ、Attention サブブロックを通した表現に更新します。

22.9. (6) L180〜L186 — MLP ブロック

180        # --- 2) MLPブロック ---
181        x_residual = x
182        x = rmsnorm(x)
183        x = linear(x, state_dict[f'layer{li}.mlp_fc1'])
184        x = [xi.relu() for xi in x]
185        x = linear(x, state_dict[f'layer{li}.mlp_fc2'])
186        x = [a + b for a, b in zip(x, x_residual)]  # 残差接続
  • L180 はコメント行です。

  • L181 で再度 x を退避し、

  • L182rmsnorm

  • L183〜L185mlp_fc1ReLUmlp_fc2 という二つの線形と非線形を通し、

  • L186 で残差を加算してこの層の出力を確定します。

注釈

ReLU(Rectified Linear Unit)とは

各成分で「値が負なら 0、そうでなければそのまま」にする活性化です。実装では xi.relu() と書いており、Value のスカラーに対して max(0, x) に相当します。線形変換だけを重ねると、まとめてひとつの線形変換と同じになってしまいますが、そのあいだに ReLU を挟むと非線形になり、表現力が増します。本コードでは GPT-2 系でよく使う GeLU の代わりに ReLU を採用しており、モデル部コメントにも「GeLU→ReLU」とあるとおり、数値互換より単純さを優先した差分です。

22.10. (7) L188〜L189 — 語彙への写像と返り値

188    logits = linear(x, state_dict['lm_head'])  # 語彙サイズ次元のlogits
189    return logits
  • L188〜L189: 層ループを抜けたあとの x に対し、lm_head をかけて語彙サイズ長の logits にして返します(行 ix の内積が logit i)。