23. microgpt の学習ループ
23.1. この章で学ぶこと
トークン列での損失、逆伝播、Adam 更新
18 章 B6 に対応する (1)〜(7) の読み方
23.2. 学習ループ
学習では、現在のトークンから次トークンを予測し、そのズレを損失として計算します。損失から逆伝播で得た勾配に基づき、パラメータを更新するのに Adam という最適化手法を使います(Adam とは何かは直下の注記)。まず Adam 用のバッファ宣言からログまで(microgpt.py L193〜L233)を引用します。続けて ブロック図と各ブロックの一行概要を置き、そのあと 項 (1)~(7) で論理ブロックごとに読みます。各ブロックのあとに同じ範囲の部分抜粋を重複して示します。数値例は **gpt(モデル本体)**の節と同様、理解用の仮定です。
注釈
Adam(Adaptive Moment Estimation)とは
ニューラルネットでは、損失を小さくするために各パラメータ(重み)の値を少しずつ変える必要があります。いちばん素朴なのは勾配降下で、各パラメータから「そのパラメータに関する損失の勾配」に比例して引きます。ところがパラメータごとに勾配の大きさや向きが違うため、どのパラメータをどれだけ動かすかをうまく調整したい場面が多く、いくつかの適応的な手法が提案されています。
Adamは、その代表のひとつで、各パラメータについて 勾配そのものの指数移動平均(第 1 モーメント、コードでは m)と、勾配の二乗の指数移動平均(第 2 モーメント、コードでは v)を蓄え、それらからパラメータごとの更新量のスケールを決めます。学習のはじめで m や v が小さすぎる偏りを直すバイアス補正(m_hat / v_hat)も、標準的な Adam の式に含まれます。
本コードの learning_rate(全体の学習率)、beta1 / beta2(モーメントの減衰率)、eps_adam(分母の安定化)は Adam のハイパーパラメータです。(1) で m / v のリストを用意し、(6) で各 params[i] に対して Adam の更新式を適用しています。
ハイパーパラメータ |
既定値 |
意味 |
|---|---|---|
|
0.01 |
全体の学習率 |
|
0.85 |
第 1 モーメント(勾配)の減衰率 |
|
0.99 |
第 2 モーメント(勾配²)の減衰率 |
|
1e-8 |
分母安定化用の小定数 |
193# -----------------------------------------------------------------------------
194learning_rate, beta1, beta2, eps_adam = 0.01, 0.85, 0.99, 1e-8
195m = [0.0] * len(params) # 第1モーメント(勾配の移動平均)
196v = [0.0] * len(params) # 第2モーメント(勾配の二乗の移動平均)
197
198# -----------------------------------------------------------------------------
199# 学習ループ
200# -----------------------------------------------------------------------------
201num_steps = 1000 # 学習ステップ数
202for step in range(num_steps):
203
204 # 1文書を取得し、トークン化。前後にBOSを付与
205 doc = docs[step % len(docs)]
206 tokens = [BOS] + [uchars.index(ch) for ch in doc] + [BOS]
207 n = min(block_size, len(tokens) - 1) # 予測する位置数
208
209 # 順伝播: 各位置で次トークンを予測し、負の対数尤度(交差エントロピー)を計算
210 keys, values = [[] for _ in range(n_layer)], [[] for _ in range(n_layer)]
211 losses = []
212 for pos_id in range(n):
213 token_id, target_id = tokens[pos_id], tokens[pos_id + 1]
214 logits = gpt(token_id, pos_id, keys, values)
215 probs = softmax(logits)
216 loss_t = -probs[target_id].log() # 正解トークンの負の対数尤度
217 losses.append(loss_t)
218 loss = (1 / n) * sum(losses) # 文書全体の平均損失
219
220 # 逆伝播: 全パラメータに対する勾配を計算
221 loss.backward()
222
223 # Adam更新: 勾配に基づいてパラメータを更新
224 lr_t = learning_rate * (1 - step / num_steps) # 線形学習率減衰
225 for i, p in enumerate(params):
226 m[i] = beta1 * m[i] + (1 - beta1) * p.grad
227 v[i] = beta2 * v[i] + (1 - beta2) * p.grad ** 2
228 m_hat = m[i] / (1 - beta1 ** (step + 1)) # バイアス補正
229 v_hat = v[i] / (1 - beta2 ** (step + 1))
230 p.data -= lr_t * m_hat / (v_hat ** 0.5 + eps_adam) # Adam更新式
231 p.grad = 0 # 次ステップ用に勾配をリセット
232
233 print(f"step {step+1:4d} / {num_steps:4d} | loss {loss.data:.4f}")
23.3. 学習ループ内のブロック図と概要
次の図は、上記ソースの制御の流れを (1)~(7) の項番号に対応させたものです。(1) は for step の外側で一度だけ実行され、(2)~(7) は各学習ステップで繰り返されます。(4) の中に pos_id による内側ループがあります。
図-1: 学習ループのブロックと項 (1)~(7) の対応
項 |
おおまかな役割 |
|---|---|
(1) |
|
(2) |
|
(3) |
|
(4) |
各層の |
(5) |
|
(6) |
|
(7) |
人間向けに |
23.4. (1) L194〜L196 — Adam 用ハイパーパラメータとモーメントのバッファ
194learning_rate, beta1, beta2, eps_adam = 0.01, 0.85, 0.99, 1e-8
195m = [0.0] * len(params) # 第1モーメント(勾配の移動平均)
196v = [0.0] * len(params) # 第2モーメント(勾配の二乗の移動平均)
L194 で学習率
learning_rateと Adam のbeta1(第 1 モーメントの減衰)、beta2(第 2 モーメントの減衰)、数値安定用のeps_adamを置きます。L195〜L196 で、パラメータ 1 個につき 第 1 モーメント
m[i]と 第 2 モーメントv[i]を用意し、初期値はいずれも 0 です。リストの長さはlen(params)(学習対象のValueの個数)です。
23.5. (2) L201〜L202 — 学習ステップ数と外側のループ
201num_steps = 1000 # 学習ステップ数
202for step in range(num_steps):
num_steps 回、step = 0 … num_steps-1 でループします。1 周のなかで「文書を取る → 損失 → 逆伝播 → Adam」がまとめて 1 ステップです。
23.6. (3) L205〜L207 — 文書の取得・トークン列・予測する位置数 n
205 doc = docs[step % len(docs)]
206 tokens = [BOS] + [uchars.index(ch) for ch in doc] + [BOS]
207 n = min(block_size, len(tokens) - 1) # 予測する位置数
L205 で
docs[step % len(docs)]により、ステップごとに 1 文書を選びます(文書が何件あっても必ずどれかが選ばれる)。L206 で文頭・文末に BOS を付けたトークン列
tokensを作ります(実装は上のliteralincludeのとおり)。L207 で
n = min(block_size, len(tokens) - 1)とし、一度の順伝播で何文字分「次トークン」を予測するかの上限を決めます(コンテキスト長block_sizeを超えない)。
23.7. (4) L210〜L218 — K/V キャッシュの初期化・位置ループ・平均損失
210 keys, values = [[] for _ in range(n_layer)], [[] for _ in range(n_layer)]
211 losses = []
212 for pos_id in range(n):
213 token_id, target_id = tokens[pos_id], tokens[pos_id + 1]
214 logits = gpt(token_id, pos_id, keys, values)
215 probs = softmax(logits)
216 loss_t = -probs[target_id].log() # 正解トークンの負の対数尤度
217 losses.append(loss_t)
218 loss = (1 / n) * sum(losses) # 文書全体の平均損失
L210 で各層の
keys/valuesを空リストのリストで初期化します(この文書の順伝播を始める前にキャッシュを空にする)。L212〜L217 で
pos_idごとに、現在トークンtokens[pos_id]と 正解の次トークンtokens[pos_id+1]を取り、gpt→softmax→ 正解ラベルでの負の対数尤度をlossesに追加します。L218 で
n個の損失の平均をスカラーlossにします(1 文書あたりの目的関数)。
23.8. (5) L221 — 逆伝播
221 loss.backward()
loss.backward() により、loss に関与した計算グラフを逆に辿り、各パラメータ p の p.grad に勾配が入ります。
23.9. (6) L224〜L231 — 線形学習率減衰と Adam 更新
224 lr_t = learning_rate * (1 - step / num_steps) # 線形学習率減衰
225 for i, p in enumerate(params):
226 m[i] = beta1 * m[i] + (1 - beta1) * p.grad
227 v[i] = beta2 * v[i] + (1 - beta2) * p.grad ** 2
228 m_hat = m[i] / (1 - beta1 ** (step + 1)) # バイアス補正
229 v_hat = v[i] / (1 - beta2 ** (step + 1))
230 p.data -= lr_t * m_hat / (v_hat ** 0.5 + eps_adam) # Adam更新式
231 p.grad = 0 # 次ステップ用に勾配をリセット
L224 で
lr_t = learning_rate * (1 - step / num_steps)とし、ステップが進むほど学習率を線形に下げます。L225〜L231 で各
params[i]についてm[i],v[i]を更新し、バイアス補正付きのm_hat,v_hatからp.dataを減算します(Adam の標準形)。L231 で
p.grad = 0とし、次ステップのために勾配を消します。
23.10. (7) L233 — ログ出力
233 print(f"step {step+1:4d} / {num_steps:4d} | loss {loss.data:.4f}")
現在のステップ番号(人間向けに step+1)と、loss.data(スカラーに化した平均損失)を表示します(例: step 1 / 1000 | loss 2.3456)。
要約の流れは次のとおりです。
順伝播
↓
損失計算
↓
逆伝播
↓
重み更新
図-2: 学習ループのメッセージの流れ
深層学習の基本が、そのまま小さい形で入っています。 上のブロック図は制御の流れ、下のシーケンス図はデータの往復の補足です。