20. microgpt の主要データ構造

20.1. この章で学ぶこと

  • 自動微分のための Value と計算グラフ

  • state_dictparams、行列の初期化

  • 18 章B3B4 に対応するコード

20.2. Value の直感(履歴つきのスカラー)

Value は自動微分の最小単位です。 スカラー 1 個に計算履歴を持たせ、backward() で勾配を流します。data(値)、grad(勾配)、_children(入力ノード)、_local_grads(局所勾配)をまとめて持ち、この数字がどこから来たかを覚えています。順伝播の計算を行うと同時にその履歴を記録し、backward() を呼ぶことで末端から根へ向かって勾配を足し戻すことができます。

loss.backward() は、計算の履歴をトポロジカル順に逆順にたどる処理です。build_topo() は依存を崩さない順に並べるための準備にすぎません。

20.2.1. 自動微分のための Value

Value は、このファイルの土台です。PyTorch の Tensor のごく小さな対応物で、スカラー(1 つの実数)だけを扱います。ベクトルや行列は「Value のリストのリスト」として表現します。

53class Value:
54    __slots__ = ('data', 'grad', '_children', '_local_grads')  # メモリ最適化
55
56    def __init__(self, data, children=(), local_grads=()):
57        self.data = data                # 順伝播で計算されたスカラー値
58        self.grad = 0                   # 損失に対するこのノードの勾配(逆伝播で計算)
59        self._children = children       # 計算グラフ上の子ノード
60        self._local_grads = local_grads # 子ノードに対する局所勾配(連鎖律用)
61
62    def __add__(self, other):
63        # 加算: d(a+b)/da=1, d(a+b)/db=1
64        other = other if isinstance(other, Value) else Value(other)
65        return Value(self.data + other.data, (self, other), (1, 1))
66
67    def __mul__(self, other):
68        # 乗算: d(a*b)/da=b, d(a*b)/db=a
69        other = other if isinstance(other, Value) else Value(other)
70        return Value(self.data * other.data, (self, other), (other.data, self.data))
71
72    def __pow__(self, other): return Value(self.data**other, (self,), (other * self.data**(other-1),))
73    def log(self): return Value(math.log(self.data), (self,), (1/self.data,))
74    def exp(self): return Value(math.exp(self.data), (self,), (math.exp(self.data),))
75    def relu(self): return Value(max(0, self.data), (self,), (float(self.data > 0),))  # ReLU: max(0,x)
76    def __neg__(self): return self * -1
77    def __radd__(self, other): return self + other
78    def __sub__(self, other): return self + (-other)
79    def __rsub__(self, other): return other + (-self)
80    def __rmul__(self, other): return self * other
81    def __truediv__(self, other): return self * other**-1
82    def __rtruediv__(self, other): return other * self**-1
83
84    def backward(self):
85        # トポロジカルソートで計算グラフを逆順に辿り、連鎖律で勾配を伝播
86        topo = []
87        visited = set()
88        def build_topo(v):
89            if v not in visited:
90                visited.add(v)
91                for child in v._children:
92                    build_topo(child)
93                topo.append(v)
94        build_topo(self)
95        self.grad = 1  # 損失自身の勾配は1(dL/dL=1)
96        for v in reversed(topo):
97            for child, local_grad in zip(v._children, v._local_grads):
98                child.grad += local_grad * v.grad  # 連鎖律: ∂L/∂child += (∂v/∂child) * (∂L/∂v)

行ごとの意味

  • L53 class Value: — 自動微分可能なスカラーを表すクラス。

  • L54 __slots__ = (...) — インスタンス属性を固定し、インスタンスごとの __dict__ を持たせない(大量ノード時の省メモリ)。

  • L56–L60 __init__data に順伝播の値、grad は 0 から、children に入力ノードのタプル、local_grads に各入力に対する ∂(自分)/∂(入力)。

  • L62–L65 __add__ — 相手を Value に揃え、和の値と局所微分 (1, 1) を記録した新ノードを返す。

  • L67–L70 __mul__ — 積の値と、乗法の微分に対応する局所微分 (other.data, self.data) を記録。

  • L72 __pow__ — 累乗と、その入力に対する微分係数を 1 行で定義。

  • L73–L75 log / exp / relu — 各活性化・非線形の順伝播値と、入力への局所微分(log1/xrelu0 または 1)。

  • L76–L82 符号反転・減算・除算など — 既存の +* などに還元する演算子オーバーロード。

  • L84–L98 backwardbuild_topo で子→親の順に並べたあと逆順に辿り、self.grad = 1 から child.grad += local_grad * v.grad で連鎖律を適用。

要約すると、__add____mul__ は新しい Value を返すと同時に親子リンクと局所微分を保存し、backward()末端から根へ勾配を足し戻します(同じ Value が複数経路に現れる場合は += で合算)。数式では次の 1 行に相当します。

child.grad += local_grad * v.grad

普通の float は値しか持ちませんが、Value計算グラフを保持するので、loss までつながったグラフから各パラメータの勾配を求められます。

20.2.1.1. 最小例で動きを追う

式だけだと抽象的なので、スカラーが 2 個と乗算 1 回だけの例で、順伝播と逆伝播を数で対応させます。上に掲げた Value クラス(microgpt.py と同じ定義)を読み込んだ前提の試用例です。

a = Value(2.0)
b = Value(3.0)
c = a * b       # 順伝播: c.data = 2 * 3 = 6
c.backward()    # 「c を最後のスカラー」とみなし、∂c/∂a と ∂c/∂b を求める(学習時は c の代わりに loss)
# backward 後: c.grad == 1      # 実装では末端で d(末端)/d(末端)=1 を立てる
# backward 後: a.grad == 3.0    # ∂c/∂a = b.data
# backward 後: b.grad == 2.0    # ∂c/∂b = a.data

乗算では _local_grads(b.data, a.data)、つまり (3, 2) です。backward() の先頭で c.grad = 1 としたうえで、逆順に子へ流すと、

  • a.grad += 3 * c.grad = 3(∂c/∂a = b と一致)

  • b.grad += 2 * c.grad = 2(∂c/∂b = a と一致)

となり、手で計算した偏微分と一致します。学習では「最後のスカラー」が損失 loss になり、loss.backward()∂loss/∂(各パラメータ) が同じ仕組みで state_dict 内の各 Value.grad に溜まります。

加算だけの例も対比用に置きます。

x = Value(2.0)
y = Value(3.0)
s = x + y       # 順伝播: s.data = 5。局所微分は (1, 1)
s.backward()
# backward 後: s.grad == 1
# backward 後: x.grad == 1.0    # ∂s/∂x = 1
# backward 後: y.grad == 1.0    # ∂s/∂y = 1

つまり 順伝播で data を組み立て、逆伝播で grad を末端から根に向かって足し戻すのが Value の動きです。モデル全体はこの繰り返しが長いだけです。

20.2.2. 重み

次は重みです。実装では辞書 state_dict にまとめられ、各要素はすべて Value のスカラーです。行列は「行のリストのリスト」として表現し、matrix(nout, nin)nout × nin 個の Value をガウス乱数で初期化しています。

103n_layer = 1     # Transformerの層数(深さ)
104n_embd = 16     # 埋め込み次元(ネットワークの幅)
105block_size = 16 # コンテキスト長の上限(注意窓の最大長。最長名前は15文字)
106n_head = 4      # マルチヘッドアテンションのヘッド数
107head_dim = n_embd // n_head  # 各ヘッドの次元(n_embdをn_headで分割)
108matrix = lambda nout, nin, std=0.08: [[Value(random.gauss(0, std)) for _ in range(nin)] for _ in range(nout)]
109# パラメータ辞書: wte=トークン埋め込み, wpe=位置埋め込み, lm_head=言語モデル出力層
110state_dict = {'wte': matrix(vocab_size, n_embd), 'wpe': matrix(block_size, n_embd), 'lm_head': matrix(vocab_size, n_embd)}
111for i in range(n_layer):
112    # Attention: Q/K/V/O の4つの線形変換(Query, Key, Value, Output)
113    state_dict[f'layer{i}.attn_wq'] = matrix(n_embd, n_embd)
114    state_dict[f'layer{i}.attn_wk'] = matrix(n_embd, n_embd)
115    state_dict[f'layer{i}.attn_wv'] = matrix(n_embd, n_embd)
116    state_dict[f'layer{i}.attn_wo'] = matrix(n_embd, n_embd)
117    # MLP: 2層の全結合(中間層は4倍に拡張)
118    state_dict[f'layer{i}.mlp_fc1'] = matrix(4 * n_embd, n_embd)
119    state_dict[f'layer{i}.mlp_fc2'] = matrix(n_embd, 4 * n_embd)
120params = [p for mat in state_dict.values() for row in mat for p in row]  # 全パラメータを1次元リストに平坦化
121print(f"num params: {len(params)}")

行ごとの意味

  • L103–L106 ハイパーパラメータ — 層数・埋め込み次元・最大文脈長・ヘッド数を定義する。

  • L107 head_dim = n_embd // n_head — 埋め込み次元をヘッド数で割り、各ヘッドの部分空間の次元にする。

  • L108 matrix = lambda ...noutnin 列の Value 行列を、標準偏差 std のガウス乱数で初期化するヘルパー。

  • L110 state_dict = { ... } — まず wte(語彙×埋め込み)、wpe(最大位置×埋め込み)、lm_head(語彙×埋め込み)の 3 つを登録する。

  • L111–L119 for i in range(n_layer) — 各層に Attention 用 4 行列(Q/K/V/O)と MLP 用 2 行列を追加。Attention は n_embd × n_embd、MLP は中間を 4 倍に広げる 4*n_embd × n_embd と、その逆形状の mlp_fc2

  • L120 params = [...]state_dict 内のすべての行列を走査し、含まれる Value一次元リストに平坦化(Adam がこの順で更新する)。

  • L121 パラメータ数の表示。

意味づけの対応は次のとおりです。

  • wte: 語彙サイズ vocab_size 行、n_embd 列。トークン ID から埋め込みベクトルを引く。

  • wpe: 位置 0 block_size-1 まで n_embd 次元。文の先頭から何文字目かを表す。

  • layer{i}.attn_wq / attn_wk / attn_wv / attn_wo: Query・Key・Value・Attention 出力への線形変換(いずれも n_embd × n_embd)。

  • layer{i}.mlp_fc1 / mlp_fc2: MLP の 2 層(中間 4 倍)。

  • lm_head: 最終表現から語彙サイズ次元の logits へ写す vocab_size × n_embd

学習で更新するのは params に並んだ Value だけです。Adam はこのリストと同じ順で勾配を読み、Value.data を更新します。

最初は小さな乱数で始めます。ここで大事なのは、モデルは最初から賢いわけではないということです。学習しながら少しずつ文字のつながり方を覚えていきます。