21. microgpt のユーティリティ関数

21.1. この章で学ぶこと

  • linear / softmax / rmsnorm の役割と実装の読み方

  • 22 章gpt から何度も呼ばれる部品としての位置づけ

21.1.1. linear

入力ベクトル x と重み行列 w から、行ごとに内積(線形変換)を計算します。出力の各成分は Value の演算でつながります。

129def linear(x, w):
130    """線形変換: y = Wx。入力xと重みwから出力ベクトルを計算"""
131    return [sum(wi * xi for wi, xi in zip(wo, x)) for wo in w]

行ごとの意味

  • L129 def linear(x, w): — 入力ベクトル x と重み w(行が出力次元、列が入力次元のリストのリスト)を受け取る。

  • L130 ドキュストリング — y = Wx の対応関係を説明。

  • L131 戻り値 — 各行 wo について sum(wi * xi)Value のまま計算し、出力ベクトルを返す。

最小例で Value のつながりを追う

入力が 2 次元で、出力が 1 ニューロンだけのとき、重みは 1 行 2 列の行列です。たとえば x = [Value(1.0), Value(2.0)]w = [[Value(3.0), Value(4.0)]] とすると、linear(x, w) の唯一の成分は

y[0] = 3×1 + 4×2 = 11

となります(実際の .data はこの値)。式の中身はすべて Value*+ なので、y[0]xw の各要素を子に持つ計算グラフの根になります。損失から backward() すると、この内積に関わる重み・入力方向へ勾配が流れます。次元を増やしても「行ごとに同じ内積」が並ぶだけです。

21.1.2. softmax

ここでいう logits(ロジット) は、「まだ確率になっていない、各候補に対する生のスコア」の列です。語彙サイズが N なら長さ N のベクトルで、成分が大きいほど「そのトークンらしい」という強さを表します。スコアは負でもよく、足して 1 になる必要もありません。学習では lm_head などの線形層が、隠れ状態からこのスコア列を計算します。softmax は、その logits を 0 以上で、すべて足すと 1 になる確率に変換する役割です。

この関数は logits のベクトルを確率分布に変換します。最大値を引いてから指数するので、オーバーフローを抑えた実装です。

133def softmax(logits):
134    """ソフトマックス: logitsを確率分布に変換。数値安定性のためmaxで引いてからexp"""
135    max_val = max(val.data for val in logits)
136    exps = [(val - max_val).exp() for val in logits]
137    total = sum(exps)
138    return [e / total for e in exps]

行ごとの意味

  • L133 def softmax(logits):Value のリストを受け取る。

  • L134 ドキュストリング — logits を確率化し、max で引いてから exp する方針。

  • L135 max_val = max(val.data for val in logits) — 全 logits の最大(スカラー)を取る。

  • L136 exps = [...] — 各 logit から max_val を引いてから expValue の演算)。

  • L137 total = sum(exps) — 正規化の分母。

  • L138 各指数を total で割り、和が 1 になる確率ベクトルを返す。

最小例([Value(1.0), Value(2.0)]):

exps[0] ∝ exp(1 - 2) = exp(-1),   exps[1] ∝ exp(2 - 2) = exp(0) = 1
total = exps[0] + exps[1]
出力[i] = exps[i] / total   (和は 1)

2 の方が大きいので確率もより大きくなります。exp / + / /Value の演算なので確率ベクトルが損失へつながります。

21.1.3. rmsnorm

ベクトルの RMS(二乗平均の平方根)でスケールし、ベクトルの大きさを揃えます。同じ「層の中でベクトルを正規化する」系に LayerNorm がありますが、本コードの RMSNorm はそこから成分の平均を引く計算を省いた簡略版です。LayerNorm そのものの説明は、この項の末尾の注記にまとめています。

140def rmsnorm(x):
141    """RMSNorm: 二乗平均の平方根で正規化。LayerNormの簡略版(平均を引かない)"""
142    ms = sum(xi * xi for xi in x) / len(x)
143    scale = (ms + 1e-5) ** -0.5
144    return [xi * scale for xi in x]

行ごとの意味

  • L140 def rmsnorm(x):Value のリスト(ベクトル)を受け取る。

  • L141 ドキュストリング — RMS 正規化(平均を引かない)の説明。LayerNorm との対比はこの項の注記参照。

  • L142 ms = sum(xi * xi for xi in x) / len(x) — 二乗の平均。

  • L143 scale = (ms + 1e-5) ** -0.5 — 小さな定数 1e-5 で除算を安定化し、RMS の逆数を得る。

  • L144 各成分に scale を掛けて返す。

最小例([Value(3.0), Value(4.0)]):

ms = (3² + 4²) / 2 = 12.5
scale = (12.5 + 1e-5)^(-1/2) ≈ 1 / √12.5
出力[i] = x[i] * scale

各成分に 1/RMS を掛けるので正規化後の二乗平均がおおよそ 1 になり、大きすぎる活性を抑えます。xi * scale で各成分がグラフに入ります。

注釈

LayerNorm とは

Layer Normalization(レイヤー正規化)は、ニューラルネットのある一層の出力ベクトルに対して、そのベクトルの成分だけを材料に「位置をそろえてから、幅をそろえる」処理です。Transformer では、トークン位置ごとに隠れベクトル(次元 n_embd)がひとつずつあり、そのベクトルに対して正規化をかけます。

ざっくりしたイメージは次のとおりです。

  1. 中心化 — そのベクトルの成分の平均を求め、各成分から引く(ベクトルの重心を原点付近へ)。

  2. スケーリング — 成分の分散(または標準偏差)で割る(ばらつきの大きさをおおよそ 1 に)。

  3. (実装によっては)学習可能な倍率とバイアスを掛けて、モデルが正規化の強さを調整できるようにする。

目的は、層を重ねても中間の値が極端に大きくなったり小さくなりすぎたりしにくくし、勾配が安定して流れるようにすることです。

RMSNorm は、このうち「中心化(平均を引く)」を省略し、二乗平均の平方根(RMS)だけで割ってスケールを揃えます。式は単純になり、計算も軽く、GPT-2 に近い設定の本コードでも採用されています。microgpt.py のモデル部コメント(GPT-2 との相違点として LayerNormRMSNorm と書かれている箇所)が指すのも、この違いです。