(microgpt の学習ループ)= # microgpt の学習ループ ## この章で学ぶこと - トークン列での損失、逆伝播、Adam 更新 - {numref}`microgpt の構造` **B6** に対応する (1)〜(7) の読み方 ## 学習ループ 学習では、現在のトークンから次トークンを予測し、そのズレを損失として計算します。損失から逆伝播で得た勾配に基づき、パラメータを更新するのに **Adam** という最適化手法を使います(**Adam とは何か**は直下の注記)。まず **Adam 用のバッファ宣言からログまで**(`microgpt.py` **L193〜L233**)を引用します。続けて **ブロック図と各ブロックの一行概要**を置き、そのあと **項 (1)~(7)** で論理ブロックごとに読みます。各ブロックのあとに**同じ範囲の部分抜粋**を重複して示します。数値例は **`gpt`(モデル本体)**の節と同様、理解用の仮定です。 :::{note} **Adam(Adaptive Moment Estimation)とは** ニューラルネットでは、損失を小さくするために**各パラメータ(重み)の値を少しずつ変える**必要があります。いちばん素朴なのは**勾配降下**で、各パラメータから「そのパラメータに関する損失の勾配」に比例して引きます。ところがパラメータごとに勾配の大きさや向きが違うため、**どのパラメータをどれだけ動かすか**をうまく調整したい場面が多く、いくつかの**適応的**な手法が提案されています。 **Adam**は、その代表のひとつで、各パラメータについて **勾配そのものの指数移動平均(第 1 モーメント、コードでは `m`)**と、**勾配の二乗の指数移動平均(第 2 モーメント、コードでは `v`)**を蓄え、それらから**パラメータごとの更新量のスケール**を決めます。学習のはじめで `m` や `v` が小さすぎる偏りを直す**バイアス補正**(`m_hat` / `v_hat`)も、標準的な Adam の式に含まれます。 本コードの **`learning_rate`**(全体の学習率)、**`beta1`** / **`beta2`**(モーメントの減衰率)、**`eps_adam`**(分母の安定化)は Adam のハイパーパラメータです。**(1)** で `m` / `v` のリストを用意し、**(6)** で各 `params[i]` に対して Adam の更新式を適用しています。 | ハイパーパラメータ | 既定値 | 意味 | |----------------|--------|-----| | `learning_rate` | 0.01 | 全体の学習率 | | `beta1` | 0.85 | 第 1 モーメント(勾配)の減衰率 | | `beta2` | 0.99 | 第 2 モーメント(勾配²)の減衰率 | | `eps_adam` | 1e-8 | 分母安定化用の小定数 | ::: ```{literalinclude} ../../../src/part3/microgpt.py :language: python :lines: 193-233 :lineno-match: ``` ## 学習ループ内のブロック図と概要 次の図は、上記ソースの制御の流れを **(1)~(7)** の項番号に対応させたものです。**(1)** は `for step` の外側で一度だけ実行され、**(2)~(7)** は各学習ステップで繰り返されます。**(4)** の中に **`pos_id`** による内側ループがあります。 :::: {container} mermaid-flow-half ```{mermaid} flowchart TB B1["(1) Adam ハイパーパラメータと m・v バッファの宣言"] B2["(2) for step in range(num_steps)"] B3["(3) 文書取得・トークン列・予測位置数 n"] B4["(4) K/V 初期化・各 pos で gpt→損失積み上げ・平均 loss"] B5["(5) loss.backward"] B6["(6) 線形学習率減衰と Adam による params 更新"] B7["(7) step と loss のログ出力"] B1 --> B2 --> B3 --> B4 --> B5 --> B6 --> B7 B7 -->|次の step| B2 ``` ::::

図-1: 学習ループのブロックと項 (1)~(7) の対応

| 項 | おおまかな役割 | |---|----------------| | (1) | `learning_rate` / `beta1` / `beta2` / `eps_adam` の設定と、パラメータ数ぶんの第 1・第 2 モーメント `m`・`v` の初期化。 | | (2) | `num_steps` 回の外側ループ。1 周で 1 文書分の順伝播・逆伝播・更新がまとまる。 | | (3) | `docs` から文書を選び BOS 付き `tokens` を作り、`n` で当ステップの予測位置数を決める。 | | (4) | 各層の `keys`/`values` を空にし、各 `pos_id` で `gpt`・`softmax` と正解トークンでの負の対数尤度を積み、平均 `loss` を得る。 | | (5) | `loss.backward()` で全パラメータの `grad` を計算する。 | | (6) | `lr_t` を線形に下げつつ、各 `params[i]` に Adam 更新を適用し、`grad` をリセットする。 | | (7) | 人間向けに `step+1` と `loss.data` を表示する。 | ## (1) L194〜L196 — Adam 用ハイパーパラメータとモーメントのバッファ ```{literalinclude} ../../../src/part3/microgpt.py :language: python :lines: 194-196 :lineno-match: ``` - **L194** で学習率 `learning_rate` と Adam の **`beta1`(第 1 モーメントの減衰)**、**`beta2`(第 2 モーメントの減衰)**、数値安定用の **`eps_adam`** を置きます。 - **L195〜L196** で、パラメータ 1 個につき **第 1 モーメント `m[i]`** と **第 2 モーメント `v[i]`** を用意し、初期値はいずれも **0** です。リストの長さは **`len(params)`**(学習対象の `Value` の個数)です。 ## (2) L201〜L202 — 学習ステップ数と外側のループ ```{literalinclude} ../../../src/part3/microgpt.py :language: python :lines: 201-202 :lineno-match: ``` **`num_steps`** 回、**`step = 0 … num_steps-1`** でループします。1 周のなかで「文書を取る → 損失 → 逆伝播 → Adam」がまとめて 1 **ステップ**です。 ## (3) L205〜L207 — 文書の取得・トークン列・予測する位置数 `n` ```{literalinclude} ../../../src/part3/microgpt.py :language: python :lines: 205-207 :lineno-match: ``` - **L205** で **`docs[step % len(docs)]`** により、ステップごとに 1 文書を選びます(文書が何件あっても必ずどれかが選ばれる)。 - **L206** で文頭・文末に **BOS** を付けたトークン列 **`tokens`** を作ります(実装は上の `literalinclude` のとおり)。 - **L207** で **`n = min(block_size, len(tokens) - 1)`** とし、**一度の順伝播で何文字分「次トークン」を予測するか**の上限を決めます(コンテキスト長 `block_size` を超えない)。 ## (4) L210〜L218 — K/V キャッシュの初期化・位置ループ・平均損失 ```{literalinclude} ../../../src/part3/microgpt.py :language: python :lines: 210-218 :lineno-match: ``` - **L210** で各層の **`keys` / `values`** を空リストのリストで初期化します(この文書の順伝播を始める前にキャッシュを空にする)。 - **L212〜L217** で **`pos_id`** ごとに、**現在トークン** `tokens[pos_id]` と **正解の次トークン** `tokens[pos_id+1]` を取り、**`gpt` → `softmax` → 正解ラベルでの負の対数尤度**を `losses` に追加します。 - **L218** で **`n` 個の損失の平均**をスカラー `loss` にします(1 文書あたりの目的関数)。 ## (5) L221 — 逆伝播 ```{literalinclude} ../../../src/part3/microgpt.py :language: python :lines: 221-221 :lineno-match: ``` **`loss.backward()`** により、**`loss`** に関与した計算グラフを逆に辿り、**各パラメータ `p` の `p.grad`** に勾配が入ります。 ## (6) L224〜L231 — 線形学習率減衰と Adam 更新 ```{literalinclude} ../../../src/part3/microgpt.py :language: python :lines: 224-231 :lineno-match: ``` - **L224** で **`lr_t = learning_rate * (1 - step / num_steps)`** とし、ステップが進むほど学習率を線形に下げます。 - **L225〜L231** で各 **`params[i]`** について **`m[i]`**, **`v[i]`** を更新し、**バイアス補正**付きの **`m_hat`**, **`v_hat`** から **`p.data`** を減算します(Adam の標準形)。 - **L231** で **`p.grad = 0`** とし、次ステップのために勾配を消します。 ## (7) L233 — ログ出力 ```{literalinclude} ../../../src/part3/microgpt.py :language: python :lines: 233-233 :lineno-match: ``` **現在のステップ番号**(人間向けに **`step+1`**)と、**`loss.data`**(スカラーに化した平均損失)を表示します(例: `step 1 / 1000 | loss 2.3456`)。 要約の流れは次のとおりです。 ```text 順伝播 ↓ 損失計算 ↓ 逆伝播 ↓ 重み更新 ``` ```{mermaid} sequenceDiagram autonumber participant W as 重み state_dict participant F as gpt 順伝播 participant S as 損失 participant B as backward participant A as Adam loop 学習ステップ W->>F: パラメータを読む F->>S: logits と正解から loss S->>B: loss.backward B->>W: 各 Value へ勾配伝播 A->>W: 重みを更新 end ```

図-2: 学習ループのメッセージの流れ

深層学習の基本が、そのまま小さい形で入っています。 上のブロック図は制御の流れ、下のシーケンス図はデータの往復の補足です。