23. microgpt の学習ループ

23.1. この章で学ぶこと

  • トークン列での損失、逆伝播、Adam 更新

  • 18 章 B6 に対応する (1)〜(7) の読み方

23.2. 学習ループ

学習では、現在のトークンから次トークンを予測し、そのズレを損失として計算します。損失から逆伝播で得た勾配に基づき、パラメータを更新するのに Adam という最適化手法を使います(Adam とは何かは直下の注記)。まず Adam 用のバッファ宣言からログまでmicrogpt.py L193〜L233)を引用します。続けて ブロック図と各ブロックの一行概要を置き、そのあと 項 (1)~(7) で論理ブロックごとに読みます。各ブロックのあとに同じ範囲の部分抜粋を重複して示します。数値例は **gpt(モデル本体)**の節と同様、理解用の仮定です。

注釈

Adam(Adaptive Moment Estimation)とは

ニューラルネットでは、損失を小さくするために各パラメータ(重み)の値を少しずつ変える必要があります。いちばん素朴なのは勾配降下で、各パラメータから「そのパラメータに関する損失の勾配」に比例して引きます。ところがパラメータごとに勾配の大きさや向きが違うため、どのパラメータをどれだけ動かすかをうまく調整したい場面が多く、いくつかの適応的な手法が提案されています。

Adamは、その代表のひとつで、各パラメータについて 勾配そのものの指数移動平均(第 1 モーメント、コードでは mと、勾配の二乗の指数移動平均(第 2 モーメント、コードでは vを蓄え、それらからパラメータごとの更新量のスケールを決めます。学習のはじめで mv が小さすぎる偏りを直すバイアス補正m_hat / v_hat)も、標準的な Adam の式に含まれます。

本コードの learning_rate(全体の学習率)、beta1 / beta2(モーメントの減衰率)、eps_adam(分母の安定化)は Adam のハイパーパラメータです。(1)m / v のリストを用意し、(6) で各 params[i] に対して Adam の更新式を適用しています。

ハイパーパラメータ

既定値

意味

learning_rate

0.01

全体の学習率

beta1

0.85

第 1 モーメント(勾配)の減衰率

beta2

0.99

第 2 モーメント(勾配²)の減衰率

eps_adam

1e-8

分母安定化用の小定数

193# -----------------------------------------------------------------------------
194learning_rate, beta1, beta2, eps_adam = 0.01, 0.85, 0.99, 1e-8
195m = [0.0] * len(params)  # 第1モーメント(勾配の移動平均)
196v = [0.0] * len(params)  # 第2モーメント(勾配の二乗の移動平均)
197
198# -----------------------------------------------------------------------------
199# 学習ループ
200# -----------------------------------------------------------------------------
201num_steps = 1000  # 学習ステップ数
202for step in range(num_steps):
203
204    # 1文書を取得し、トークン化。前後にBOSを付与
205    doc = docs[step % len(docs)]
206    tokens = [BOS] + [uchars.index(ch) for ch in doc] + [BOS]
207    n = min(block_size, len(tokens) - 1)  # 予測する位置数
208
209    # 順伝播: 各位置で次トークンを予測し、負の対数尤度(交差エントロピー)を計算
210    keys, values = [[] for _ in range(n_layer)], [[] for _ in range(n_layer)]
211    losses = []
212    for pos_id in range(n):
213        token_id, target_id = tokens[pos_id], tokens[pos_id + 1]
214        logits = gpt(token_id, pos_id, keys, values)
215        probs = softmax(logits)
216        loss_t = -probs[target_id].log()  # 正解トークンの負の対数尤度
217        losses.append(loss_t)
218    loss = (1 / n) * sum(losses)  # 文書全体の平均損失
219
220    # 逆伝播: 全パラメータに対する勾配を計算
221    loss.backward()
222
223    # Adam更新: 勾配に基づいてパラメータを更新
224    lr_t = learning_rate * (1 - step / num_steps)  # 線形学習率減衰
225    for i, p in enumerate(params):
226        m[i] = beta1 * m[i] + (1 - beta1) * p.grad
227        v[i] = beta2 * v[i] + (1 - beta2) * p.grad ** 2
228        m_hat = m[i] / (1 - beta1 ** (step + 1))  # バイアス補正
229        v_hat = v[i] / (1 - beta2 ** (step + 1))
230        p.data -= lr_t * m_hat / (v_hat ** 0.5 + eps_adam)  # Adam更新式
231        p.grad = 0  # 次ステップ用に勾配をリセット
232
233    print(f"step {step+1:4d} / {num_steps:4d} | loss {loss.data:.4f}")

23.3. 学習ループ内のブロック図と概要

次の図は、上記ソースの制御の流れを (1)~(7) の項番号に対応させたものです。(1)for step の外側で一度だけ実行され、(2)~(7) は各学習ステップで繰り返されます。(4) の中に pos_id による内側ループがあります。

diagram

図-1: 学習ループのブロックと項 (1)~(7) の対応

おおまかな役割

(1)

learning_rate / beta1 / beta2 / eps_adam の設定と、パラメータ数ぶんの第 1・第 2 モーメント mv の初期化。

(2)

num_steps 回の外側ループ。1 周で 1 文書分の順伝播・逆伝播・更新がまとまる。

(3)

docs から文書を選び BOS 付き tokens を作り、n で当ステップの予測位置数を決める。

(4)

各層の keys/values を空にし、各 pos_idgptsoftmax と正解トークンでの負の対数尤度を積み、平均 loss を得る。

(5)

loss.backward() で全パラメータの grad を計算する。

(6)

lr_t を線形に下げつつ、各 params[i] に Adam 更新を適用し、grad をリセットする。

(7)

人間向けに step+1loss.data を表示する。

23.4. (1) L194〜L196 — Adam 用ハイパーパラメータとモーメントのバッファ

194learning_rate, beta1, beta2, eps_adam = 0.01, 0.85, 0.99, 1e-8
195m = [0.0] * len(params)  # 第1モーメント(勾配の移動平均)
196v = [0.0] * len(params)  # 第2モーメント(勾配の二乗の移動平均)
  • L194 で学習率 learning_rate と Adam の beta1(第 1 モーメントの減衰)beta2(第 2 モーメントの減衰)、数値安定用の eps_adam を置きます。

  • L195〜L196 で、パラメータ 1 個につき 第 1 モーメント m[i]第 2 モーメント v[i] を用意し、初期値はいずれも 0 です。リストの長さは len(params)(学習対象の Value の個数)です。

23.5. (2) L201〜L202 — 学習ステップ数と外側のループ

201num_steps = 1000  # 学習ステップ数
202for step in range(num_steps):

num_steps 回、step = 0 num_steps-1 でループします。1 周のなかで「文書を取る → 損失 → 逆伝播 → Adam」がまとめて 1 ステップです。

23.6. (3) L205〜L207 — 文書の取得・トークン列・予測する位置数 n

205    doc = docs[step % len(docs)]
206    tokens = [BOS] + [uchars.index(ch) for ch in doc] + [BOS]
207    n = min(block_size, len(tokens) - 1)  # 予測する位置数
  • L205docs[step % len(docs)] により、ステップごとに 1 文書を選びます(文書が何件あっても必ずどれかが選ばれる)。

  • L206 で文頭・文末に BOS を付けたトークン列 tokens を作ります(実装は上の literalinclude のとおり)。

  • L207n = min(block_size, len(tokens) - 1) とし、一度の順伝播で何文字分「次トークン」を予測するかの上限を決めます(コンテキスト長 block_size を超えない)。

23.7. (4) L210〜L218 — K/V キャッシュの初期化・位置ループ・平均損失

210    keys, values = [[] for _ in range(n_layer)], [[] for _ in range(n_layer)]
211    losses = []
212    for pos_id in range(n):
213        token_id, target_id = tokens[pos_id], tokens[pos_id + 1]
214        logits = gpt(token_id, pos_id, keys, values)
215        probs = softmax(logits)
216        loss_t = -probs[target_id].log()  # 正解トークンの負の対数尤度
217        losses.append(loss_t)
218    loss = (1 / n) * sum(losses)  # 文書全体の平均損失
  • L210 で各層の keys / values を空リストのリストで初期化します(この文書の順伝播を始める前にキャッシュを空にする)。

  • L212〜L217pos_id ごとに、現在トークン tokens[pos_id]正解の次トークン tokens[pos_id+1] を取り、gptsoftmax → 正解ラベルでの負の対数尤度losses に追加します。

  • L218n 個の損失の平均をスカラー loss にします(1 文書あたりの目的関数)。

23.8. (5) L221 — 逆伝播

221    loss.backward()

loss.backward() により、loss に関与した計算グラフを逆に辿り、各パラメータ pp.grad に勾配が入ります。

23.9. (6) L224〜L231 — 線形学習率減衰と Adam 更新

224    lr_t = learning_rate * (1 - step / num_steps)  # 線形学習率減衰
225    for i, p in enumerate(params):
226        m[i] = beta1 * m[i] + (1 - beta1) * p.grad
227        v[i] = beta2 * v[i] + (1 - beta2) * p.grad ** 2
228        m_hat = m[i] / (1 - beta1 ** (step + 1))  # バイアス補正
229        v_hat = v[i] / (1 - beta2 ** (step + 1))
230        p.data -= lr_t * m_hat / (v_hat ** 0.5 + eps_adam)  # Adam更新式
231        p.grad = 0  # 次ステップ用に勾配をリセット
  • L224lr_t = learning_rate * (1 - step / num_steps) とし、ステップが進むほど学習率を線形に下げます。

  • L225〜L231 で各 params[i] について m[i], v[i] を更新し、バイアス補正付きの m_hat, v_hat から p.data を減算します(Adam の標準形)。

  • L231p.grad = 0 とし、次ステップのために勾配を消します。

23.10. (7) L233 — ログ出力

233    print(f"step {step+1:4d} / {num_steps:4d} | loss {loss.data:.4f}")

現在のステップ番号(人間向けに step+1)と、loss.data(スカラーに化した平均損失)を表示します(例: step    1 / 1000 | loss 2.3456)。

要約の流れは次のとおりです。

順伝播
↓
損失計算
↓
逆伝播
↓
重み更新
diagram

図-2: 学習ループのメッセージの流れ

深層学習の基本が、そのまま小さい形で入っています。 上のブロック図は制御の流れ、下のシーケンス図はデータの往復の補足です。