21. microgpt のユーティリティ関数
21.1. この章で学ぶこと
linear/softmax/rmsnormの役割と実装の読み方22 章 の
gptから何度も呼ばれる部品としての位置づけ
21.1.1. linear
入力ベクトル x と重み行列 w から、行ごとに内積(線形変換)を計算します。出力の各成分は Value の演算でつながります。
129def linear(x, w):
130 """線形変換: y = Wx。入力xと重みwから出力ベクトルを計算"""
131 return [sum(wi * xi for wi, xi in zip(wo, x)) for wo in w]
行ごとの意味
L129
def linear(x, w):— 入力ベクトルxと重みw(行が出力次元、列が入力次元のリストのリスト)を受け取る。L130 ドキュストリング —
y = Wxの対応関係を説明。L131 戻り値 — 各行
woについてsum(wi * xi)をValueのまま計算し、出力ベクトルを返す。
最小例で Value のつながりを追う
入力が 2 次元で、出力が 1 ニューロンだけのとき、重みは 1 行 2 列の行列です。たとえば x = [Value(1.0), Value(2.0)]、w = [[Value(3.0), Value(4.0)]] とすると、linear(x, w) の唯一の成分は
y[0] = 3×1 + 4×2 = 11
となります(実際の .data はこの値)。式の中身はすべて Value の * と + なので、y[0] は x と w の各要素を子に持つ計算グラフの根になります。損失から backward() すると、この内積に関わる重み・入力方向へ勾配が流れます。次元を増やしても「行ごとに同じ内積」が並ぶだけです。
21.1.2. softmax
ここでいう logits(ロジット) は、「まだ確率になっていない、各候補に対する生のスコア」の列です。語彙サイズが N なら長さ N のベクトルで、成分が大きいほど「そのトークンらしい」という強さを表します。スコアは負でもよく、足して 1 になる必要もありません。学習では lm_head などの線形層が、隠れ状態からこのスコア列を計算します。softmax は、その logits を 0 以上で、すべて足すと 1 になる確率に変換する役割です。
この関数は logits のベクトルを確率分布に変換します。最大値を引いてから指数するので、オーバーフローを抑えた実装です。
133def softmax(logits):
134 """ソフトマックス: logitsを確率分布に変換。数値安定性のためmaxで引いてからexp"""
135 max_val = max(val.data for val in logits)
136 exps = [(val - max_val).exp() for val in logits]
137 total = sum(exps)
138 return [e / total for e in exps]
行ごとの意味
L133
def softmax(logits):—Valueのリストを受け取る。L134 ドキュストリング — logits を確率化し、
maxで引いてからexpする方針。L135
max_val = max(val.data for val in logits)— 全 logits の最大(スカラー)を取る。L136
exps = [...]— 各 logit からmax_valを引いてからexp(Valueの演算)。L137
total = sum(exps)— 正規化の分母。L138 各指数を
totalで割り、和が 1 になる確率ベクトルを返す。
最小例([Value(1.0), Value(2.0)]):
exps[0] ∝ exp(1 - 2) = exp(-1), exps[1] ∝ exp(2 - 2) = exp(0) = 1
total = exps[0] + exps[1]
出力[i] = exps[i] / total (和は 1)
2 の方が大きいので確率もより大きくなります。exp / + / / は Value の演算なので確率ベクトルが損失へつながります。
21.1.3. rmsnorm
ベクトルの RMS(二乗平均の平方根)でスケールし、ベクトルの大きさを揃えます。同じ「層の中でベクトルを正規化する」系に LayerNorm がありますが、本コードの RMSNorm はそこから成分の平均を引く計算を省いた簡略版です。LayerNorm そのものの説明は、この項の末尾の注記にまとめています。
140def rmsnorm(x):
141 """RMSNorm: 二乗平均の平方根で正規化。LayerNormの簡略版(平均を引かない)"""
142 ms = sum(xi * xi for xi in x) / len(x)
143 scale = (ms + 1e-5) ** -0.5
144 return [xi * scale for xi in x]
行ごとの意味
L140
def rmsnorm(x):—Valueのリスト(ベクトル)を受け取る。L141 ドキュストリング — RMS 正規化(平均を引かない)の説明。LayerNorm との対比はこの項の注記参照。
L142
ms = sum(xi * xi for xi in x) / len(x)— 二乗の平均。L143
scale = (ms + 1e-5) ** -0.5— 小さな定数1e-5で除算を安定化し、RMS の逆数を得る。L144 各成分に
scaleを掛けて返す。
最小例([Value(3.0), Value(4.0)]):
ms = (3² + 4²) / 2 = 12.5
scale = (12.5 + 1e-5)^(-1/2) ≈ 1 / √12.5
出力[i] = x[i] * scale
各成分に 1/RMS を掛けるので正規化後の二乗平均がおおよそ 1 になり、大きすぎる活性を抑えます。xi * scale で各成分がグラフに入ります。
注釈
LayerNorm とは
Layer Normalization(レイヤー正規化)は、ニューラルネットのある一層の出力ベクトルに対して、そのベクトルの成分だけを材料に「位置をそろえてから、幅をそろえる」処理です。Transformer では、トークン位置ごとに隠れベクトル(次元 n_embd)がひとつずつあり、そのベクトルに対して正規化をかけます。
ざっくりしたイメージは次のとおりです。
中心化 — そのベクトルの成分の平均を求め、各成分から引く(ベクトルの重心を原点付近へ)。
スケーリング — 成分の分散(または標準偏差)で割る(ばらつきの大きさをおおよそ 1 に)。
(実装によっては)学習可能な倍率とバイアスを掛けて、モデルが正規化の強さを調整できるようにする。
目的は、層を重ねても中間の値が極端に大きくなったり小さくなりすぎたりしにくくし、勾配が安定して流れるようにすることです。
RMSNorm は、このうち「中心化(平均を引く)」を省略し、二乗平均の平方根(RMS)だけで割ってスケールを揃えます。式は単純になり、計算も軽く、GPT-2 に近い設定の本コードでも採用されています。microgpt.py のモデル部コメント(GPT-2 との相違点として LayerNorm→RMSNorm と書かれている箇所)が指すのも、この違いです。