(microgpt のモデル本体(gpt))= # microgpt のモデル本体(gpt) ## この章で学ぶこと - `gpt` の埋め込みループ、Attention、MLP、`lm_head` までのデータの流れ - {numref}`microgpt の構造` **B5** に対応する (1)~(7) の行ブロック読み ## `gpt`(モデル本体) 埋め込みループ(Attention と MLP)、最後の `lm_head` までをひとまとまりにしたのが `gpt` です。上の `linear` / `softmax` / `rmsnorm` が、この関数の中で繰り返し使われます。まず **`gpt` 関数全体**(`microgpt.py` **L146~L189**)を引用します。続けて **ブロック図と各ブロックの一行概要**を置き、そのあと **項 (1)~(7)** として行ブロックごとに説明します(各項の直後に**同じ行範囲**を重複して示します)。 ```{literalinclude} ../../../src/part3/microgpt.py :language: python :lines: 146-189 :lineno-match: ``` ## `gpt` 内のブロック図と概要 次の図は、上記ソースの制御の流れを **(1)~(7)** の項番号に対応させたものです。`n_layer > 1` のときは、**(3)~(6)** のかたまりをインデックス `li` で繰り返します。 :::: {container} mermaid-flow-half ```{mermaid} flowchart TB IN["token_id , pos_id , keys , values"] B1["(1) wte / wpe で埋め込み結合 → 最初の rmsnorm"] B2["(2) for li in range(n_layer)"] subgraph layer["各層 li(Attention → MLP)"] B3["(3) 残差退避・rmsnorm・Q/K/V・keys/values へ追記"] B4["(4) ヘッドループ: スケール付き内積 → softmax → V の重み付き和"] B5["(5) attn_wo でヘッド結合 → 残差加算"] B6["(6) MLP: 残差・rmsnorm・fc1 → ReLU → fc2 → 残差"] end B7["(7) lm_head → logits を return"] IN --> B1 --> B2 --> B3 --> B4 --> B5 --> B6 B6 -->|次の li があれば| B2 B6 -->|すべての層が終了後| B7 ``` ::::
図-1: `gpt` 内のブロックと項 (1)~(7) の対応
| 項 | おおまかな役割 | |---|----------------| | (1) | 引数の意味(docstring)と、`wte` / `wpe` から埋め込みベクトルを作り、ループに入る前の `rmsnorm` まで。 | | (2) | `n_layer` 回、Attention ブロックと MLP ブロックの組を繰り返す `for` 。 | | (3) | Attention の入口で残差用に `x` を退避し、正規化後に Q/K/V を計算し、過去位置分の `k` / `v` をキャッシュに追加。 | | (4) | ヘッドごとに、過去の K から注意重みを `softmax` し、V を重み付き和して `x_attn` に連結。 | | (5) | `attn_wo` でヘッド出力を埋め込み幅に写し、Attention 前の残差を加算。 | | (6) | MLP でも同様に残差を取り、`rmsnorm` のあと、`mlp_fc1` → ReLU → `mlp_fc2` → 残差。 | | (7) | 層ループ後の `x` を `lm_head` で語彙サイズの logits に写像して返す。 | ## (1) L146〜L154 — 関数定義・ドキュストリングと入力の準備 ```{literalinclude} ../../../src/part3/microgpt.py :language: python :lines: 146-154 :lineno-match: ``` - **L146** は関数定義行(引数は現在のトークン ID・位置 ID、各層の K/V リスト)です。 - **L147〜L150** は docstring で、順伝播の役割と、`keys` / `values` が層ごとに「過去の K/V を溜めるキャッシュ」であることを説明しています。 - **L151〜L154** で、トークン ID に対応する行を `wte` から、位置 ID に対応する行を `wpe` から取り出し、成分ごとに足してひとつの埋め込みベクトル `x` にします。続けて **最初の `rmsnorm`** をかけ、以降のブロックに渡すスケールにそろえます。 ## (2) L156 — 層を繰り返す ```{literalinclude} ../../../src/part3/microgpt.py :language: python :lines: 156-156 :lineno-match: ``` **`for li in range(n_layer):`** で、下記の Attention ブロックと MLP ブロックの組を **`n_layer` 回**繰り返します。`li` は 0 から始まる層インデックスです。 ## (3) L157〜L164 — Attention ブロックの前半(残差・正規化・Q/K/V とキャッシュ) ```{literalinclude} ../../../src/part3/microgpt.py :language: python :lines: 157-164 :lineno-match: ``` - **L157** はコメント行です。 - **L158** でこのサブブロックに入る直前の `x` を `x_residual` に退避し、あとで残差接続に使います。 - **L159** で `rmsnorm` し、**L160〜L162** で同じ正規化済み `x` から Query・Key・Value をそれぞれ別の重み行列で線形変換します。 - **L163〜L164** で、計算した `k` と `v` を、その層 `li` 用のリスト `keys[li]` / `values[li]` の末尾に追加します。学習ループでは位置が時間順に進むので、ここに「これまでの位置の K/V」が溜まり、因果的な注意(過去だけ見る)に使えます。 ## (4) L165〜L176 — ヘッドごとのアテンション(スケール付き内積・softmax・V への重み付き和) ```{literalinclude} ../../../src/part3/microgpt.py :language: python :lines: 165-176 :lineno-match: ``` - **L165** で、ヘッドの出力を足すリスト `x_attn` を空で用意します。 - **L166** でヘッド番号 `h` のループに入ります。 - **L167〜L170** では、埋め込み全体を `head_dim` 幅に分割したうえで、**現在位置の** `q` のスライス `q_h` と、**これまでに `keys[li]` / `values[li]` に蓄積された各時刻の** K/V の同じヘッド部分 `k_h` / `v_h` を取り出します(過去の長さぶんのベクトル列になります)。 - **L172〜L173** では、現在の `q_h` と各過去時刻の `k_h` との**スケール付き内積**をとって `attn_logits` にし、**`softmax`** で「どの過去位置をどれだけ見るか」の重み `attn_weights` にします。 - **L175〜L176** では、その重みで各時刻の `v_h` を足し合わせ、ヘッド `h` の出力ベクトル `head_out` を求め、`x_attn` に連結して足していきます。 ## (5) L177〜L178 — Attention ブロックの後半(ヘッド結合と残差) ```{literalinclude} ../../../src/part3/microgpt.py :language: python :lines: 177-178 :lineno-match: ``` - **L177** で、すべてのヘッドからつながった `x_attn` を `attn_wo` で線形変換し、埋め込み次元にまとめます。 - **L178** で、入り口で退避した `x_residual` に足し合わせ、Attention サブブロックを通した表現に更新します。 ## (6) L180〜L186 — MLP ブロック ```{literalinclude} ../../../src/part3/microgpt.py :language: python :lines: 180-186 :lineno-match: ``` - **L180** はコメント行です。 - **L181** で再度 `x` を退避し、 - **L182** で `rmsnorm`、 - **L183〜L185** で `mlp_fc1` → **ReLU** → `mlp_fc2` という二つの線形と非線形を通し、 - **L186** で残差を加算してこの層の出力を確定します。 :::{note} **ReLU(Rectified Linear Unit)とは** 各成分で「値が負なら 0、そうでなければそのまま」にする活性化です。実装では `xi.relu()` と書いており、`Value` のスカラーに対して `max(0, x)` に相当します。線形変換だけを重ねると、まとめてひとつの線形変換と同じになってしまいますが、そのあいだに ReLU を挟むと**非線形**になり、表現力が増します。本コードでは GPT-2 系でよく使う GeLU の代わりに ReLU を採用しており、モデル部コメントにも「GeLU→ReLU」とあるとおり、数値互換より**単純さ**を優先した差分です。 ::: ## (7) L188〜L189 — 語彙への写像と返り値 ```{literalinclude} ../../../src/part3/microgpt.py :language: python :lines: 188-189 :lineno-match: ``` - **L188〜L189:** 層ループを抜けたあとの `x` に対し、`lm_head` をかけて語彙サイズ長の **logits** にして返します(行 `i` と `x` の内積が logit `i`)。