25. microgpt.py を Mojo で書き直す

25.1. この章で学ぶこと

  • microgpt.pyB1〜B7 ブロックが Mojo でどう書かれているか、ソースコードを読みながら理解する

  • Mojo の所有権・テープ(計算グラフ)・構造体を、動くコードの上で確認する

  • Python の Value クラスを Mojo の Tape(インデックスベース)に置き換えた設計上の理由を知る

  • main.mojo で B1〜B7 を組み合わせた学習・推論の全体像を把握する


25.2. 実行例

src/part3/microgpt_mojo/ で次のように実行します。

cd src/part3/microgpt_mojo
mojo run main.mojo

出力(学習ステップは途中を省略):

num docs: 32033
vocab size: 27
num params: 4192
step    1 / 1000 | loss 3.2589016467713185
step    2 / 1000 | loss 3.4757539661271393
step    3 / 1000 | loss 3.2850563256139123
step    4 / 1000 | loss 3.2385487413715050
step    5 / 1000 | loss 3.4347922579358574
        ...
step  996 / 1000 | loss 1.9243386681471330
step  997 / 1000 | loss 2.2576022798634730
step  998 / 1000 | loss 1.6500781354343423
step  999 / 1000 | loss 1.9544453781129398
step 1000 / 1000 | loss 2.5348443763481185

--- inference ---
sample  1 : ala
sample  2 : nayna
sample  3 : anare
sample  4 : alania
sample  5 : odelen
sample  6 : sona
sample  7 : adila
sample  8 : kais
sample  9 : tayna
sample 10 : adelel
sample 11 : atena
sample 12 : koncy
sample 13 : maneza
sample 14 : aleran
sample 15 : amelin
sample 16 : tirha
sample 17 : saria
sample 18 : tanin
sample 19 : jaian
sample 20 : astyan

loss が 3.2 前後から 1.6〜2.5 程度まで下がり、推論では名前らしい文字列(“ala”, “sona”, “adelel” など)が生成されていることが確認できます。


25.3. 全体構造

25.3.1. B1〜B7 の依存関係

各ブロックは以下のように依存しています。矢印は「使う→使われる」の方向です。

diagram

B3(Tape)がすべての土台です。演算のたびにノードをテープに追記し、逆伝播で勾配を計算します。B4〜B7 はテープ上のノード ID(Int)を受け渡しながら動きます。

25.3.2. ファイル構成

ファイル

対応する Python の概念

主な内容

b01_dataset.mojo

docs リスト構築

テキストから文書リストを作る

b02_tokenizer.mojo

uchars / BOS / トークン列

文字→IDの変換

b03_value.mojo

Value クラス / backward

自動微分テープ

b04_state_dict.mojo

state_dict / params

モデル重みの初期化・管理

b05_ops.mojo

linear / rmsnorm / softmax

テープ上のテンソル演算

b05_gpt.mojo

gpt() 関数

GPT 1ステップの順伝播

b06_train.mojo

学習ループの核

損失計算・Adam 更新

b07_infer.mojo

推論ループの核

温度付きサンプリング

main.mojo

スクリプト本体

全体を組み合わせた実行ファイル

25.3.3. データとコードの流れ

diagram

25.4. B1 — データセット(b01_dataset.mojo)

25.4.1. ソースコード

# B1: Build non-empty stripped lines from multi-line text (like input.txt).


def load_docs_from_text(text: String) -> List[String]:
    var docs = List[String]()
    var lines = text.splitlines()
    for i in range(len(lines)):
        var line = String(String(lines[i]).strip())
        if len(line) > 0:
            docs.append(line)
    return docs^


def load_sample_names() -> List[String]:
    # Tiny fixed corpus for tests (ASCII names, one per line).
    return load_docs_from_text(String("ada\nbob\n\nada\n"))

25.4.2. 解説

load_docs_from_textinput.txt の中身(文字列)を受け取り、空行を除いた行のリストを返します。

var lines = text.splitlines()          # 改行で分割
var line = String(String(lines[i]).strip())  # 前後の空白を除去
if len(line) > 0:                      # 空行をスキップ
    docs.append(line)

String(String(lines[i]).strip()) と 2 重に String() が現れているのは、型変換が 2 段階必要なためです。

説明

lines[i]

StringSlice

splitlines() の要素(元テキストへの参照)

String(lines[i])

String

所有権のあるコピーに変換

String(lines[i]).strip()

StringSlice

前後の空白を除いた部分への参照(所有権なし)

String(String(lines[i]).strip())

String

所有権のあるコピーに変換 ← これを append

strip() の戻り値が StringSlice(参照)であることが核心です。List[String]append するには所有権のある String が必要なので、外側の String(...) でもう一度コピーを作っています。

関数の末尾で return docs^ と書いているのは、docs をコピーせず転送(move)して返すためです。^ は「所有権をここで手放す」という演算子で、不要なコピーを避けます。

load_sample_names はテスト用の小さなデータセット(3件の名前)を固定文字列から作る便利関数です。


25.5. B2 — トークナイザー(b02_tokenizer.mojo)

25.5.1. ソースコード

# B2: Char-level tokenizer (sorted unique chars + BOS), matching microgpt.py.

from std.builtin.sort import sort


struct Tokenizer:
    var uchars: List[String]
    var bos: Int

    def __init__(out self, uchars: List[String], bos: Int):
        self.uchars = uchars.copy()
        self.bos = bos

    def vocab_size(self) -> Int:
        return len(self.uchars) + 1


def _char_at(doc: String, j: Int) -> String:
    return chr(Int(doc.as_bytes()[j]))


def _find_char(uchars: List[String], ch: String) -> Int:
    for i in range(len(uchars)):
        if uchars[i] == ch:
            return i
    return -1


def build_tokenizer(docs: List[String]) -> Tokenizer:
    var chars = List[String]()
    for di in range(len(docs)):
        var d = docs[di]
        for j in range(len(d)):
            var ch = _char_at(d, j)
            if _find_char(chars, ch) < 0:
                chars.append(ch)
    sort(chars)
    var bos = len(chars)
    return Tokenizer(chars^, bos)


def encode_doc(tok: Tokenizer, doc: String) -> List[Int]:
    var out = List[Int]()
    out.append(tok.bos)
    for j in range(len(doc)):
        var ch = _char_at(doc, j)
        var ix = _find_char(tok.uchars, ch)
        out.append(ix)
    out.append(tok.bos)
    return out^

25.5.2. 解説

25.5.2.1. Tokenizer 構造体

struct Tokenizer:
    var uchars: List[String]   # 語彙: ソート済みユニーク文字の一覧
    var bos: Int               # BOS トークンの ID(= len(uchars))

uchars['a', 'b', ..., 'z'] のような文字リストで、インデックスがそのままトークン ID になります。BOS(Beginning of Sequence) は文の先頭・末尾を示す特殊トークンで、文字数より 1 大きい ID を持ちます。

25.5.2.2. トークン化の例

文書 "ada" をトークン化すると次のようになります。

diagram

encode_doc[BOS, id('a'), id('d'), id('a'), BOS] というリストを作ります。学習時は位置 t のトークンから位置 t+1 のトークンを予測するタスクになります。

25.5.2.3. 文字の取り出し

def _char_at(doc: String, j: Int) -> String:
    return chr(Int(doc.as_bytes()[j]))

Python では doc[j:j+1] でスライスできましたが、Mojo 0.26 では String のスライスが直接 String を返さないため、as_bytes() でバイト列を取り出し chr() で文字に変換しています(ASCII 文字専用の教材向けの実装です)。

25.5.2.4. バブルソート

Python の sorted(set(...)) に相当する処理は from std.builtin.sort import sort でインポートした sort() で行います。sort(chars) と呼ぶだけで List[String] をインプレースにソートできます。


25.6. B3 — 自動微分テープ(b03_value.mojo)

25.6.1. ソースコード

# B3: Scalar autograd tape (microgpt.py Value), topology-order backward.
# Stored as structure-of-arrays so nodes stay in List[...] without custom copy/move.

from std.math import exp, log


comptime NodeId = Int


struct Tape:
    var data: List[Float64]
    var grad: List[Float64]
    var children: List[List[Int]]
    var local_grads: List[List[Float64]]

    def __init__(out self):
        self.data = List[Float64]()
        self.grad = List[Float64]()
        self.children = List[List[Int]]()
        self.local_grads = List[List[Float64]]()

    def _push_node(
        mut self, value: Float64, ch: List[Int], lg: List[Float64]
    ) -> NodeId:
        self.data.append(value)
        self.grad.append(0.0)
        self.children.append(ch.copy())
        self.local_grads.append(lg.copy())
        return len(self.data) - 1

    def node_data(self, i: NodeId) -> Float64:
        return self.data[i]

    def grad_at(self, i: NodeId) -> Float64:
        return self.grad[i]

    def set_data(mut self, i: NodeId, v: Float64):
        self.data[i] = v

    def set_grad(mut self, i: NodeId, g: Float64):
        self.grad[i] = g

    def add_grad(mut self, i: NodeId, d: Float64):
        self.grad[i] += d

    def clear_grads(mut self):
        for i in range(len(self.grad)):
            self.grad[i] = 0.0

    def leaf(mut self, value: Float64) -> NodeId:
        var ch = List[Int]()
        var lg = List[Float64]()
        return self._push_node(value, ch^, lg^)

    def add(mut self, a: NodeId, b: NodeId) -> NodeId:
        var ch = List[Int]()
        ch.append(a)
        ch.append(b)
        var lg = List[Float64]()
        lg.append(1.0)
        lg.append(1.0)
        var s = self.data[a] + self.data[b]
        return self._push_node(s, ch^, lg^)

    def mul(mut self, a: NodeId, b: NodeId) -> NodeId:
        var ch = List[Int]()
        ch.append(a)
        ch.append(b)
        var lg = List[Float64]()
        lg.append(self.data[b])
        lg.append(self.data[a])
        var p = self.data[a] * self.data[b]
        return self._push_node(p, ch^, lg^)

    def neg(mut self, a: NodeId) -> NodeId:
        var ch = List[Int]()
        ch.append(a)
        var lg = List[Float64]()
        lg.append(-1.0)
        return self._push_node(-self.data[a], ch^, lg^)

    def sub(mut self, a: NodeId, b: NodeId) -> NodeId:
        return self.add(a, self.neg(b))

    def pow_const(mut self, a: NodeId, expv: Float64) -> NodeId:
        var ad = self.data[a]
        var y = ad**expv
        var lg = expv * ad ** (expv - 1.0)
        var ch = List[Int]()
        ch.append(a)
        var lgs = List[Float64]()
        lgs.append(lg)
        return self._push_node(y, ch^, lgs^)

    def div(mut self, a: NodeId, b: NodeId) -> NodeId:
        var invb = self.pow_const(b, -1.0)
        return self.mul(a, invb)

    def log_(mut self, a: NodeId) -> NodeId:
        var ad = self.data[a]
        var ch = List[Int]()
        ch.append(a)
        var lg = List[Float64]()
        lg.append(1.0 / ad)
        return self._push_node(log(ad), ch^, lg^)

    def exp_(mut self, a: NodeId) -> NodeId:
        var ad = self.data[a]
        var y = exp(ad)
        var ch = List[Int]()
        ch.append(a)
        var lg = List[Float64]()
        lg.append(y)
        return self._push_node(y, ch^, lg^)

    def relu(mut self, a: NodeId) -> NodeId:
        var ad = self.data[a]
        var y = ad if ad > 0.0 else 0.0
        var g = 1.0 if ad > 0.0 else 0.0
        var ch = List[Int]()
        ch.append(a)
        var lg = List[Float64]()
        lg.append(g)
        return self._push_node(y, ch^, lg^)

    def sum_nodes(mut self, xs: List[Int]) -> NodeId:
        if len(xs) == 0:
            return self.leaf(0.0)
        var s = xs[0]
        for i in range(1, len(xs)):
            s = self.add(s, xs[i])
        return s

    def scale(mut self, a: NodeId, s: Float64) -> NodeId:
        return self.mul(a, self.leaf(s))


def _build_topo(mut tape: Tape, v: NodeId, mut topo: List[Int], mut visited: List[Bool]):
    if visited[v]:
        return
    visited[v] = True
    var ch = tape.children[v].copy()
    for i in range(len(ch)):
        _build_topo(tape, ch[i], topo, visited)
    topo.append(v)


def backward(mut tape: Tape, root: NodeId):
    tape.clear_grads()
    var visited = List[Bool]()
    for _ in range(len(tape.data)):
        visited.append(False)
    var topo = List[Int]()
    _build_topo(tape, root, topo, visited)
    tape.set_grad(root, 1.0)
    var ti = len(topo)
    while ti > 0:
        ti -= 1
        var v = topo[ti]
        var gv = tape.grad[v]
        var ch = tape.children[v].copy()
        var lgs = tape.local_grads[v].copy()
        for k in range(len(ch)):
            var c = ch[k]
            var lg = lgs[k]
            tape.add_grad(c, lg * gv)

25.6.2. Tape の仕組み

これが最も重要なモジュールです。GPT の学習には「どのパラメータをどの方向に動かすか」を知るための勾配計算が必要です。Tape はその計算グラフを管理します。

注釈

「テープ」という名前の由来

「テープ」は自動微分の歴史的な用語です。1964 年に R. E. Wengert が提案した「演算を順番に書き留めたリスト(Wengert list)」が起源で、テープレコーダーが音声を順に記録するように、順伝播中のすべての演算を記録していくことからこう呼ばれます。この方式を「テープベース自動微分(tape-based AD)」といいます。

TensorFlow には同じ概念の tf.GradientTape という API があり、PyTorch 内部でも “gradient tape” という用語が使われています。

Python 版の Value が「微分可能な値 1 個」を表す名前なのに対し、Mojo 版の Tape は「演算の記録装置全体」を表す名前です。Mojo の所有権制約で自己参照構造体が作れないため、全ノードを 1 つの構造体に集約する設計になり、その結果としてテープという概念と自然に一致しました。

演算(加算・乗算など)を呼ぶたびに、Tape は新しいノードを 4 つの配列に追記します。

配列

内容

data[i]

ノード i の計算値

grad[i]

ノード i の勾配(逆伝播後)

children[i]

子ノードの ID 列

local_grads[i]

各子ノードへの局所勾配

25.6.2.1. 具体例: c = a + b

var a = t.leaf(2.0)   # ノード 0: data=2.0
var b = t.leaf(3.0)   # ノード 1: data=3.0
var c = t.add(a, b)   # ノード 2: data=5.0, children=[0,1], local_grads=[1.0,1.0]

テープの内部状態はこうなります。

index

data

grad

children

local_grads

0

2.0

0.0

[]

[]

1

3.0

0.0

[]

[]

2

5.0

0.0

[0, 1]

[1.0, 1.0]

加算の局所勾配はどちらも 1.0 です(∂(a+b)/∂a = 1∂(a+b)/∂b = 1)。

乗算 c = a * b の場合は local_grads = [b.data, a.data]∂(a*b)/∂a = b∂(a*b)/∂b = a)になります。

25.6.2.2. 逆伝播アルゴリズム

backward(tape, root) はルートノードから勾配を伝播します。

diagram
  1. clear_grads() で全勾配をゼロにリセット

  2. トポロジカルソートで「子より親が後」の順に並べる

  3. ルートの勾配を 1.0 にセット(∂loss/∂loss = 1

  4. 逆順に辿りながら child.grad += local_grad × node.grad を累積

連鎖律(chain rule)の実装そのものです。

25.6.2.3. sum_nodes で sum() を使わない理由

sum_nodesxs: List[Int] を受け取りますが、この Int数値ではなくノード IDです。

注釈

xs の中身はたとえば [42, 57, 63] のような整数ですが、これは「テープの 42 番目・57 番目・63 番目のノード」を指すインデックスです。 sum(xs) を使うと 42 + 57 + 63 = 162 というインデックスの和が計算されてしまい、まったく意味が違います。

self.add(a, b) を繰り返す理由は、単に数値を足すのではなく、テープに加算ノードを追記するためです。

xs = [42, 57, 63]  (ノード ID)

step1: s = self.add(42, 57)  → テープにノード 100 を追記(children=[42,57])
step2: s = self.add(100, 63) → テープにノード 101 を追記(children=[100,63])
return 101                    (総和を表すノードの ID)

このノードの連鎖があることで、backward() 時に勾配が各子ノードへ正しく伝播できます。sum() で単純に足すと計算グラフが作られず、逆伝播できません。


25.7. B4 — ハイパーパラメータと重み(b04_state_dict.mojo)

25.7.1. ソースコード

# B4: Model hyperparameters, weight matrices of Tape nodes, flat param list.

from std.random import randn_float64

from b03_value import NodeId, Tape


struct HyperParams:
    var n_layer: Int
    var n_embd: Int
    var block_size: Int
    var n_head: Int

    def __init__(out self, n_layer: Int, n_embd: Int, block_size: Int, n_head: Int):
        self.n_layer = n_layer
        self.n_embd = n_embd
        self.block_size = block_size
        self.n_head = n_head

    def head_dim(self) -> Int:
        return self.n_embd // self.n_head


struct StateDict:
    var wte: List[List[Int]]
    var wpe: List[List[Int]]
    var lm_head: List[List[Int]]
    var attn_wq: List[List[List[Int]]]
    var attn_wk: List[List[List[Int]]]
    var attn_wv: List[List[List[Int]]]
    var attn_wo: List[List[List[Int]]]
    var mlp_fc1: List[List[List[Int]]]
    var mlp_fc2: List[List[List[Int]]]

    def __init__(
        out self,
        wte: List[List[Int]],
        wpe: List[List[Int]],
        lm_head: List[List[Int]],
        attn_wq: List[List[List[Int]]],
        attn_wk: List[List[List[Int]]],
        attn_wv: List[List[List[Int]]],
        attn_wo: List[List[List[Int]]],
        mlp_fc1: List[List[List[Int]]],
        mlp_fc2: List[List[List[Int]]],
    ):
        self.wte = wte.copy()
        self.wpe = wpe.copy()
        self.lm_head = lm_head.copy()
        self.attn_wq = attn_wq.copy()
        self.attn_wk = attn_wk.copy()
        self.attn_wv = attn_wv.copy()
        self.attn_wo = attn_wo.copy()
        self.mlp_fc1 = mlp_fc1.copy()
        self.mlp_fc2 = mlp_fc2.copy()


def matrix_rand(mut t: Tape, nout: Int, nin: Int, std: Float64) -> List[List[Int]]:
    var m = List[List[Int]]()
    for _ in range(nout):
        var row = List[Int]()
        for _ in range(nin):
            row.append(t.leaf(randn_float64(0.0, std)))
        m.append(row^)
    return m^


def init_state_dict(mut t: Tape, vocab_size: Int, hp: HyperParams, std: Float64) -> StateDict:
    var wte = matrix_rand(t, vocab_size, hp.n_embd, std)
    var wpe = matrix_rand(t, hp.block_size, hp.n_embd, std)
    var lm_head = matrix_rand(t, vocab_size, hp.n_embd, std)
    var attn_wq = List[List[List[Int]]]()
    var attn_wk = List[List[List[Int]]]()
    var attn_wv = List[List[List[Int]]]()
    var attn_wo = List[List[List[Int]]]()
    var mlp_fc1 = List[List[List[Int]]]()
    var mlp_fc2 = List[List[List[Int]]]()
    for _ in range(hp.n_layer):
        attn_wq.append(matrix_rand(t, hp.n_embd, hp.n_embd, std))
        attn_wk.append(matrix_rand(t, hp.n_embd, hp.n_embd, std))
        attn_wv.append(matrix_rand(t, hp.n_embd, hp.n_embd, std))
        attn_wo.append(matrix_rand(t, hp.n_embd, hp.n_embd, std))
        mlp_fc1.append(matrix_rand(t, 4 * hp.n_embd, hp.n_embd, std))
        mlp_fc2.append(matrix_rand(t, hp.n_embd, 4 * hp.n_embd, std))
    return StateDict(
        wte^,
        wpe^,
        lm_head^,
        attn_wq^,
        attn_wk^,
        attn_wv^,
        attn_wo^,
        mlp_fc1^,
        mlp_fc2^,
    )


def _append_matrix_flat(mut ps: List[NodeId], mat: List[List[Int]]):
    for i in range(len(mat)):
        var ncol = len(mat[i])
        for j in range(ncol):
            ps.append(mat[i][j])


def flatten_params(sd: StateDict) -> List[NodeId]:
    var ps = List[NodeId]()
    _append_matrix_flat(ps, sd.wte)
    _append_matrix_flat(ps, sd.wpe)
    _append_matrix_flat(ps, sd.lm_head)
    for li in range(len(sd.attn_wq)):
        _append_matrix_flat(ps, sd.attn_wq[li])
        _append_matrix_flat(ps, sd.attn_wk[li])
        _append_matrix_flat(ps, sd.attn_wv[li])
        _append_matrix_flat(ps, sd.attn_wo[li])
        _append_matrix_flat(ps, sd.mlp_fc1[li])
        _append_matrix_flat(ps, sd.mlp_fc2[li])
    return ps^

25.7.2. 解説

25.7.2.1. HyperParams

var hp = HyperParams(1, 16, 16, 4)
# n_layer=1, n_embd=16, block_size=16, n_head=4

GPT の形状を決めるパラメータです。head_dim() メソッドは n_embd ÷ n_head(各アテンションヘッドの次元数)を返します。

25.7.2.2. StateDict — 重みの束

重み行列の各要素はテープ上のノード IDInt)として保存されています。つまり wte[0][3] は「単語 0 の埋め込みベクトルの 3 番目の要素に対応するノードの ID」です。

diagram

25.7.2.3. matrix_rand — 重みの初期化

row.append(t.leaf(randn_float64(0.0, std)))

from std.random import randn_float64 でインポートした randn_float64(mean, std) を使い、Python 版の random.gauss(0, std) と同等の正規分布乱数で重みを初期化しています。

25.7.2.4. flatten_params

Adam 更新ループで全パラメータを一括処理するため、行列の入れ子をフラットな List[NodeId] に展開します。


25.8. B5a — 演算ユーティリティ(b05_ops.mojo)

25.8.1. ソースコード

# B5a: linear, rmsnorm, softmax on Tape nodes (microgpt.py helpers).

from std.math import sqrt

from b03_value import Tape, NodeId


def linear(mut t: Tape, x: List[Int], w: List[List[Int]]) -> List[Int]:
    var out = List[Int]()
    for r in range(len(w)):
        var acc = t.mul(x[0], w[r][0])
        for c in range(1, len(x)):
            acc = t.add(acc, t.mul(x[c], w[r][c]))
        out.append(acc)
    return out^


def dot(mut t: Tape, a: List[Int], b: List[Int]) -> NodeId:
    var acc = t.mul(a[0], b[0])
    for i in range(1, len(a)):
        acc = t.add(acc, t.mul(a[i], b[i]))
    return acc


def rmsnorm(mut t: Tape, x: List[Int]) -> List[Int]:
    var sq = List[Int]()
    for i in range(len(x)):
        sq.append(t.mul(x[i], x[i]))
    var ms = t.scale(t.sum_nodes(sq), 1.0 / Float64(len(x)))
    var inv_rms = t.pow_const(t.add(ms, t.leaf(1e-5)), -0.5)
    var out = List[Int]()
    for i in range(len(x)):
        out.append(t.mul(x[i], inv_rms))
    return out^


def softmax(mut t: Tape, logits: List[Int]) -> List[Int]:
    var m = t.node_data(logits[0])
    for i in range(1, len(logits)):
        var v = t.node_data(logits[i])
        m = m if m > v else v
    var exps = List[Int]()
    for i in range(len(logits)):
        exps.append(t.exp_(t.sub(logits[i], t.leaf(m))))
    var s = t.sum_nodes(exps)
    var out = List[Int]()
    for i in range(len(exps)):
        out.append(t.div(exps[i], s))
    return out^


def scaled_attention_logits(
    mut t: Tape, q_h: List[Int], k_rows: List[List[Int]], head_dim: Int
) -> List[Int]:
    var logits = List[Int]()
    var scale = 1.0 / sqrt(Float64(head_dim))
    for tk in range(len(k_rows)):
        logits.append(t.scale(dot(t, q_h, k_rows[tk]), scale))
    return logits^

25.8.2. 解説

すべての関数は Tape 上に演算ノードを追加しながら、ノード ID のリストを入出力します。

25.8.2.1. linear(線形変換)

\text{out}[r] = \sum_{c=0}^{C-1} W[r][c] \cdot x[c]

out[r] = x[0]*w[r][0] + x[1]*w[r][1] + ... + x[C-1]*w[r][C-1]

行列 w(形状 n_\text{out} \times n_\text{in})の r 行目と入力ベクトル x(長さ C = n_\text{in})の内積です。行列とベクトルの積 \mathbf{out} = W\mathbf{x} を1行ずつ計算しています。Python の sum(wi*xi ...) に相当します。

25.8.2.2. rmsnorm(RMS 正規化)

LayerNorm の簡略版です。各要素を「二乗平均の平方根」で割って正規化します。

\text{rms} = \sqrt{\frac{1}{n}\sum_{i=0}^{n-1} x[i]^2 + \varepsilon}, \qquad \text{out}[i] = \frac{x[i]}{\text{rms}}

rms    = sqrt( (x[0]² + x[1]² + ... + x[n-1]²) / n  +  ε )
out[i] = x[i] / rms

pow_const(node, -0.5)(\text{rms}^2)^{-0.5} = 1/\text{rms} を計算し、これを全要素に掛けています。

25.8.2.3. softmax

logits を確率分布に変換します。数値安定化のため最大値 m を引いてから exp を取ります(値が大きすぎて exp がオーバーフローするのを防ぐ)。

p[i] = \frac{\exp(\text{logits}[i] - m)}{\sum_j \exp(\text{logits}[j] - m)}, \qquad m = \max_j \text{logits}[j]

m    = max(logits[0], logits[1], ..., logits[V-1])
e[i] = exp(logits[i] - m)
p[i] = e[i] / (e[0] + e[1] + ... + e[V-1])
var m = t.node_data(logits[0])         # 現在の値(.data)を取り出し
for i in range(1, len(logits)):
    var v = t.node_data(logits[i])
    m = m if m > v else v              # 最大値を求める

ここで t.node_data(...) を呼んでいるのは、最大値は定数扱いで微分不要なためです(ソフトマックスの安定化はグラフに乗せない)。

25.8.2.4. scaled_attention_logits

Query と Key の内積にスケール 1/\sqrt{d_k} を掛けます。

\text{logits}[t] = \frac{\mathbf{q} \cdot \mathbf{k}_t}{\sqrt{d_k}}

logits[t] = (q[0]*k_t[0] + q[1]*k_t[1] + ... + q[d-1]*k_t[d-1]) / sqrt(d_k)

スケールしないと内積が d_k に比例して大きくなり、softmax 後の分布が一点に集中してしまいます。


25.9. B5b — GPT 順伝播(b05_gpt.mojo)

25.9.1. ソースコード

# B5b: One-step GPT forward with KV cache (microgpt.py gpt).

from b03_value import Tape
from b04_state_dict import HyperParams, StateDict
from b05_ops import linear, rmsnorm, softmax, scaled_attention_logits


def row_embedding(mat: List[List[Int]], idx: Int) -> List[Int]:
    var out = List[Int]()
    var n = len(mat[idx])
    for i in range(n):
        out.append(mat[idx][i])
    return out^


def slice_vec(xs: List[Int], start: Int, width: Int) -> List[Int]:
    var out = List[Int]()
    for j in range(width):
        out.append(xs[start + j])
    return out^


def head_slice_from_cache(cache: List[List[Int]], tix: Int, head_start: Int, head_dim: Int) -> List[Int]:
    var out = List[Int]()
    for j in range(head_dim):
        out.append(cache[tix][head_start + j])
    return out^


def attn_head(
    mut t: Tape,
    q_h: List[Int],
    k_rows: List[List[Int]],
    v_rows: List[List[Int]],
    head_dim: Int,
) -> List[Int]:
    var logits = scaled_attention_logits(t, q_h, k_rows, head_dim)
    var w = softmax(t, logits^)
    var head_out = List[Int]()
    for j in range(head_dim):
        var acc = t.mul(w[0], v_rows[0][j])
        for tk in range(1, len(v_rows)):
            acc = t.add(acc, t.mul(w[tk], v_rows[tk][j]))
        head_out.append(acc)
    return head_out^


def concat_heads(heads: List[List[Int]]) -> List[Int]:
    var out = List[Int]()
    for hi in range(len(heads)):
        var nh = len(heads[hi])
        for j in range(nh):
            out.append(heads[hi][j])
    return out^


def append_kv_row(mut cache: List[List[Int]], row: List[Int]):
    cache.append(row.copy())


def gpt_forward(
    mut t: Tape,
    sd: StateDict,
    hp: HyperParams,
    token_id: Int,
    pos_id: Int,
    mut keys: List[List[List[Int]]],
    mut vals: List[List[List[Int]]],
) -> List[Int]:
    var hd = hp.head_dim()
    var tok = row_embedding(sd.wte, token_id)
    var pos = row_embedding(sd.wpe, pos_id)
    var x = List[Int]()
    for i in range(len(tok)):
        x.append(t.add(tok[i], pos[i]))
    x = rmsnorm(t, x^)

    for li in range(hp.n_layer):
        var x_res = List[Int]()
        for j in range(len(x)):
            x_res.append(x[j])

        var x_ln = rmsnorm(t, x^)
        var q = linear(t, x_ln, sd.attn_wq[li])
        var k = linear(t, x_ln, sd.attn_wk[li])
        var v = linear(t, x_ln, sd.attn_wv[li])
        append_kv_row(keys[li], k^)
        append_kv_row(vals[li], v^)

        var heads = List[List[Int]]()
        for h in range(hp.n_head):
            var hs = h * hd
            var q_h = slice_vec(q, hs, hd)
            var k_h = List[List[Int]]()
            var v_h = List[List[Int]]()
            for tix in range(len(keys[li])):
                k_h.append(head_slice_from_cache(keys[li], tix, hs, hd))
                v_h.append(head_slice_from_cache(vals[li], tix, hs, hd))
            var hpiece = attn_head(t, q_h, k_h, v_h, hd)
            heads.append(hpiece^)

        var merged = concat_heads(heads)
        var x_attn = linear(t, merged, sd.attn_wo[li])
        var x2 = List[Int]()
        for j in range(len(x_attn)):
            x2.append(t.add(x_attn[j], x_res[j]))
        x = x2^

        var xr2 = List[Int]()
        for j in range(len(x)):
            xr2.append(x[j])
        var x_mlp_in = rmsnorm(t, x^)
        var h1 = linear(t, x_mlp_in, sd.mlp_fc1[li])
        var h2 = List[Int]()
        for j in range(len(h1)):
            h2.append(t.relu(h1[j]))
        var x_mlp = linear(t, h2, sd.mlp_fc2[li])
        var x3 = List[Int]()
        for j in range(len(x_mlp)):
            x3.append(t.add(x_mlp[j], xr2[j]))
        x = x3^

    return linear(t, x, sd.lm_head)


def gpt_forward_embed_only(mut t: Tape, sd: StateDict, token_id: Int, pos_id: Int) -> List[Int]:
    var tok = row_embedding(sd.wte, token_id)
    var pos = row_embedding(sd.wpe, pos_id)
    var x = List[Int]()
    for i in range(len(tok)):
        x.append(t.add(tok[i], pos[i]))
    x = rmsnorm(t, x^)
    return linear(t, x, sd.lm_head)

25.9.2. 解説

gpt_forward が1ステップの GPT 計算です。1トークン分の入力を受け取り、次トークンの確率を示す logits ベクトルを返します。

25.9.2.1. 埋め込み層

var tok = row_embedding(sd.wte, token_id)   # トークン埋め込み
var pos = row_embedding(sd.wpe, pos_id)     # 位置埋め込み
var x = ...  # tok + pos の要素和
x = rmsnorm(t, x^)                          # 初期正規化

row_embedding は行列の 1 行(= ノード ID の列)をコピーして返します。mat[idx] を直接返すのではなく mat[idx][i]要素ごとに読むのは、行リストのムーブによる行列破壊を防ぐためです(後述)。

25.9.2.2. Transformer 層

diagram

25.9.2.3. KV キャッシュ

推論時も学習時も、過去のトークンで計算した Key/Value をリストに蓄積します。位置 t での Attention は位置 0〜t の K/V すべてに注目できます(因果的 Attention)。

append_kv_row(keys[li], k^)   # この位置の Key をキャッシュに追加
append_kv_row(vals[li], v^)   # この位置の Value をキャッシュに追加
for tix in range(len(keys[li])):
    k_h.append(head_slice_from_cache(keys[li], tix, hs, hd))
    ...

head_slice_from_cachecache[tix] を丸ごと渡さず cache[tix][j + off] と要素アクセスしているのは、行リストのムーブを避けるためです。

25.9.2.4. マルチヘッドアテンション

for h in range(hp.n_head):
    var hs = h * hd                          # このヘッドの開始インデックス
    var q_h = slice_vec(q, hs, hd)           # Q をヘッドごとに切り出す
    # K/V も同様に head_slice_from_cache で切り出す
    var hpiece = attn_head(t, q_h, k_h, v_h, hd)
    heads.append(hpiece^)
var merged = concat_heads(heads)             # ヘッドを連結して n_embd 次元に戻す

n_head 個のヘッドを独立に計算して最後に連結します。


25.10. B6 — 学習ステップ(b06_train.mojo)

25.10.1. ソースコード

# B6: One document training step — forward, cross-entropy loss, backward, Adam.

from std.math import sqrt

from b03_value import Tape, backward
from b04_state_dict import HyperParams, StateDict
from b05_gpt import gpt_forward
from b05_ops import softmax


def min_int(a: Int, b: Int) -> Int:
    return a if a < b else b


def init_kv_cache(n_layer: Int, mut keys: List[List[List[Int]]], mut vals: List[List[List[Int]]]):
    for _ in range(n_layer):
        keys.append(List[List[Int]]())
        vals.append(List[List[Int]]())


def loss_on_document(
    mut t: Tape,
    sd: StateDict,
    hp: HyperParams,
    tokens: List[Int],
    mut keys: List[List[List[Int]]],
    mut vals: List[List[List[Int]]],
) -> Int:
    var n = min_int(hp.block_size, len(tokens) - 1)
    var losses = List[Int]()
    for pos_id in range(n):
        var tid = tokens[pos_id]
        var target = tokens[pos_id + 1]
        var logits = gpt_forward(t, sd, hp, tid, pos_id, keys, vals)
        var probs = softmax(t, logits)
        var nll = t.neg(t.log_(probs[target]))
        losses.append(nll)
    return t.scale(t.sum_nodes(losses), 1.0 / Float64(n))


def adam_update(
    mut tape: Tape,
    params: List[Int],
    mut m: List[Float64],
    mut v: List[Float64],
    step_idx: Int,
    total_steps: Int,
):
    var lr = 0.01
    var beta1 = 0.85
    var beta2 = 0.99
    var eps = 1e-8
    var lr_t = lr * (1.0 - Float64(step_idx) / Float64(total_steps))
    var tstep = step_idx + 1
    var b1_corr = 1.0
    var b2_corr = 1.0
    for _ in range(tstep):
        b1_corr *= beta1
        b2_corr *= beta2
    for i in range(len(params)):
        var pid = params[i]
        var g = tape.grad_at(pid)
        m[i] = beta1 * m[i] + (1.0 - beta1) * g
        v[i] = beta2 * v[i] + (1.0 - beta2) * g * g
        var m_hat = m[i] / (1.0 - b1_corr)
        var v_hat = v[i] / (1.0 - b2_corr)
        var newv = tape.node_data(pid) - lr_t * m_hat / (sqrt(v_hat) + eps)
        tape.set_data(pid, newv)
        tape.set_grad(pid, 0.0)

25.10.2. 解説

25.10.2.1. loss_on_document — 1文書の損失

1文書分のトークン列に対して損失を計算します。

for pos_id in range(n):
    var tid    = tokens[pos_id]       # 現在のトークン
    var target = tokens[pos_id + 1]   # 次のトークン(正解ラベル)
    var logits = gpt_forward(...)
    var probs  = softmax(t, logits)
    var nll    = t.neg(t.log_(probs[target]))  # 負の対数尤度
    losses.append(nll)
return t.scale(t.sum_nodes(losses), 1.0 / Float64(n))  # 平均損失

正解トークン target の確率 probs[target] を取り出し、その対数の符号反転(NLL)を損失とします。予測が正確なほど確率は 1 に近づき、NLL は 0 に近づきます。

25.10.2.2. adam_update — Adam 最適化

var lr_t = lr * (1.0 - Float64(step_idx) / Float64(total_steps))  # 線形学習率減衰
m[i] = beta1 * m[i] + (1 - beta1) * g        # 第1モーメント(勾配の移動平均)
v[i] = beta2 * v[i] + (1 - beta2) * g * g    # 第2モーメント(勾配二乗の移動平均)
var m_hat = m[i] / (1 - b1_corr)             # バイアス補正
var v_hat = v[i] / (1 - b2_corr)
var newv = tape.node_data(pid) - lr_t * m_hat / (sqrt(v_hat) + eps)
tape.set_data(pid, newv)    # 重みの値を更新
tape.set_grad(pid, 0.0)     # 勾配をリセット

Adam はモメンタム(第1モーメント)と RMSProp(第2モーメント)を組み合わせた最適化手法です。初期ステップのバイアスを 1 - beta^t で補正しています。


25.11. B7 — 推論ヘルパ(b07_infer.mojo)

25.11.1. ソースコード

# B7: Temperature-scaled softmax and greedy token pick (inference helpers).

from std.random import random_float64

from b03_value import Tape
from b05_ops import softmax


def logits_with_temperature(mut t: Tape, logits: List[Int], temperature: Float64) -> List[Int]:
    var scaled = List[Int]()
    var inv_t = 1.0 / temperature
    for i in range(len(logits)):
        scaled.append(t.scale(logits[i], inv_t))
    return softmax(t, scaled)


def greedy_argmax(probs: List[Int], tape: Tape) -> Int:
    var best = 0
    var bestv = tape.node_data(probs[0])
    for i in range(1, len(probs)):
        var v = tape.node_data(probs[i])
        if v > bestv:
            bestv = v
            best = i
    return best


def sample_from_probs(probs: List[Int], tape: Tape) -> Int:
    var r = random_float64()
    var cumsum = 0.0
    for i in range(len(probs)):
        cumsum += tape.node_data(probs[i])
        if r < cumsum:
            return i
    return len(probs) - 1


def should_stop(token_id: Int, bos_id: Int) -> Bool:
    return token_id == bos_id

25.11.2. 解説

25.11.2.1. logits_with_temperature — 温度スケーリング

var inv_t = 1.0 / temperature
for i in range(len(logits)):
    scaled.append(t.scale(logits[i], inv_t))
return softmax(t, scaled)

temperature が低い(例 0.5)ほど logits の差が拡大し、確率分布が尖って確定的な出力になります。temperature=1.0 では logits をそのまま使います。

diagram

25.11.2.2. sample_from_probs — 累積分布サンプリング

var r = random_float64()          # [0, 1) の一様乱数
var cumsum = 0.0
for i in range(len(probs)):
    cumsum += tape.node_data(probs[i])
    if r < cumsum:                # 累積確率が乱数を超えた時点で選択
        return i

確率に比例した頻度でトークンを選びます。Python の random.choices(range(n), weights=probs) と同等です。

diagram

25.12. main.mojo — 全体を組み合わせる

25.12.1. ソースコード

# microgpt.py と同等の処理を Mojo で実装したメインファイル。
# B1〜B7 モジュールを組み合わせて学習と推論を行う。

from std.random import random_float64, seed

from b01_dataset import load_docs_from_text
from b02_tokenizer import build_tokenizer, encode_doc
from b03_value import Tape, backward
from b04_state_dict import HyperParams, init_state_dict, flatten_params
from b05_gpt import gpt_forward
from b06_train import init_kv_cache, loss_on_document, adam_update
from b07_infer import logits_with_temperature, sample_from_probs, should_stop


def main() raises:
    seed(42)  # 再現性のため乱数シードを固定(Python の random.seed(42) 相当)

    # -------------------------------------------------------------------------
    # データセット
    # -------------------------------------------------------------------------
    var f = open("input.txt", "r")
    var text = f.read()
    f.close()
    var docs = load_docs_from_text(text)
    # Fisher-Yates シャッフル
    for i in range(len(docs) - 1, 0, -1):
        var j = Int(random_float64() * Float64(i + 1))
        var tmp = docs[i]
        docs[i] = docs[j]
        docs[j] = tmp
    print("num docs:", len(docs))

    # -------------------------------------------------------------------------
    # トークナイザー
    # -------------------------------------------------------------------------
    var tok = build_tokenizer(docs)
    print("vocab size:", tok.vocab_size())

    # -------------------------------------------------------------------------
    # ハイパーパラメータ
    # -------------------------------------------------------------------------
    # n_layer=1, n_embd=16, block_size=16, n_head=4
    var hp = HyperParams(1, 16, 16, 4)

    # -------------------------------------------------------------------------
    # テープとパラメータの初期化
    # -------------------------------------------------------------------------
    var t = Tape()
    var sd = init_state_dict(t, tok.vocab_size(), hp, 0.08)
    var params = flatten_params(sd)
    print("num params:", len(params))

    # -------------------------------------------------------------------------
    # Adam バッファ
    # -------------------------------------------------------------------------
    var m_buf = List[Float64]()
    var v_buf = List[Float64]()
    for _ in range(len(params)):
        m_buf.append(0.0)
        v_buf.append(0.0)

    # -------------------------------------------------------------------------
    # 学習ループ
    # -------------------------------------------------------------------------
    var num_steps = 1000
    for step in range(num_steps):
        var doc = docs[step % len(docs)]

        var keys = List[List[List[Int]]]()
        var vals = List[List[List[Int]]]()
        init_kv_cache(hp.n_layer, keys, vals)

        var tokens = encode_doc(tok, doc)
        var loss_node = loss_on_document(t, sd, hp, tokens, keys, vals)
        backward(t, loss_node)
        adam_update(t, params, m_buf, v_buf, step, num_steps)
        print("step", step + 1, "/", num_steps, "| loss", t.node_data(loss_node))

    # -------------------------------------------------------------------------
    # 推論
    # -------------------------------------------------------------------------
    print("\n--- inference ---")
    for si in range(20):
        var infer_keys = List[List[List[Int]]]()
        var infer_vals = List[List[List[Int]]]()
        init_kv_cache(hp.n_layer, infer_keys, infer_vals)

        var token_id = tok.bos
        var result = String("")
        for pos_id in range(hp.block_size):
            var logits = gpt_forward(t, sd, hp, token_id, pos_id, infer_keys, infer_vals)
            var probs = logits_with_temperature(t, logits^, 0.5)
            token_id = sample_from_probs(probs, t)
            if should_stop(token_id, tok.bos):
                break
            result += tok.uchars[token_id]
        print("sample", si + 1, ":", result)

25.12.2. 解説

25.12.2.1. データ読み込みとシャッフル

var f = open("input.txt", "r")
var text = f.read()
f.close()
var docs = load_docs_from_text(text)
# Fisher-Yates シャッフル
for i in range(len(docs) - 1, 0, -1):
    var j = Int(random_float64() * Float64(i + 1))
    var tmp = docs[i]
    docs[i] = docs[j]
    docs[j] = tmp

Fisher-Yates シャッフルは末尾から先頭に向かって走査し、各要素をランダムな位置と入れ替えます。学習順序をランダム化することで、特定の文書の偏りを防ぎます。

25.12.2.2. 学習ループ

for step in range(num_steps):
    var doc = docs[step % len(docs)]        # docs を順番に使い回す

    var keys = List[List[List[Int]]]()
    var vals = List[List[List[Int]]]()
    init_kv_cache(hp.n_layer, keys, vals)   # KV キャッシュをリセット

    var tokens = encode_doc(tok, doc)
    var loss_node = loss_on_document(t, sd, hp, tokens, keys, vals)
    backward(t, loss_node)                  # 勾配を計算
    adam_update(t, params, m_buf, v_buf, step, num_steps)  # 重みを更新

各ステップで KV キャッシュを初期化するのが重要です。KV キャッシュは1文書内の過去トークンを蓄積するもので、ステップをまたいで持ち越してはいけません。

25.12.2.3. 推論ループ

var token_id = tok.bos               # BOS から生成スタート
for pos_id in range(hp.block_size):
    var logits = gpt_forward(...)
    var probs = logits_with_temperature(t, logits^, 0.5)
    token_id = sample_from_probs(probs, t)
    if should_stop(token_id, tok.bos):   # BOS が出たら終了
        break
    result += tok.uchars[token_id]

BOS トークンを起点に1文字ずつ生成し、再び BOS が出たら(= 文の末尾を予測したら)停止します。block_size は生成の上限長になります。


25.13. Python 版との違いと実装上の工夫

25.13.1. なぜ Value 型ではなく Tape 型なのか

Python 版の Value はクラスのインスタンスで、参照によって計算グラフを形成します。

# Python: オブジェクト参照でグラフを作る
class Value:
    def __init__(self, data, children=(), local_grads=()):
        self._children = children      # 別の Value オブジェクトへの参照

Mojo でこれを直接再現しようとすると問題が起きます。

struct Value:
    var _children: List[Value]  # ❌ Value が Value を含む → 無限サイズ

Mojo の struct は固定サイズでなければならないため、自己参照型は定義できません。また、所有権ルールにより List[Value] 内の要素を mut 参照で変更することも困難です。

解決策が テープ(Tape)パターンです。

struct Tape:
    var data:        List[Float64]     # 全ノードの値
    var grad:        List[Float64]     # 全ノードの勾配
    var children:    List[List[Int]]   # 各ノードの子インデックス
    var local_grads: List[List[Float64]]

ノード同士はインデックス(Int = NodeId)で参照します。Tape は複数の配列の束(SoA: Structure of Arrays)として表現されているので、サイズが固定でなくとも問題ありません。Python で言えば「グラフをオブジェクト参照で持つ」代わりに「グラフをインデックステーブルで持つ」イメージです。

25.13.2. inout/owned の廃止への対応

Mojo 0.26 では inoutowned キーワードが完全に削除されました。

旧構文

新構文

用途

fn __init__(inout self, ...)

def __init__(out self, ...)

コンストラクタ

fn method(inout self, ...)

def method(mut self, ...)

変更ありメソッド

fn func(inout x: T, ...)

def func(mut x: T, ...)

変更あり引数

fn func(owned x: T)

def func(x: T) + 呼び側で x^

所有権の転送

owned の代替として、呼び出し側で x^ を使いムーブ渡しします。

return Tokenizer(chars^, bos)   # chars の所有権をコンストラクタに渡す

25.13.3. List[T] の所有権

List[T]コピー不可能ImplicitlyCopyable でない)です。別の変数に代入すると所有権が移動します。

var row = mat[0]       # ❌ mat から行をコピーしようとしてエラー
var row = mat[0][j]    # ✅ 要素(Int)はコピー可能

ネストしたリスト操作では要素インデックスで個別アクセスするのが基本方針です。children のような List[List[Int]] を取り出すときは .copy() が必要です。

var ch = tape.children[v].copy()    # .copy() で明示的にコピー

25.13.4. 転送演算子 ^ と変数の再代入

x = rmsnorm(t, x^)    # x を rmsnorm にムーブし、結果を x に代入

x^ は「x の所有権をここで放棄する」という宣言です。これ以降 x は未初期化になりますが、右辺の戻り値が即座に x に代入されるため問題ありません。

ただし同じスコープで x を複数の関数に渡す場合はムーブできないため、rmsnorm の結果を新しい変数名で受けています。

var x_ln = rmsnorm(t, x^)   # x はムーブ後に無効化
# x_res(別途保存した残差)を使う

25.13.5. 初期化方法の違い

項目

Python 版

Mojo 版

重みの初期化

random.gauss(0, std)

randn_float64(0.0, std)

シード

random.seed(42) で固定

seed(42) で固定

docs シャッフル

random.shuffle(docs)

Fisher-Yates + random_float64

input.txt の取得

自動ダウンロード

手動で配置(要事前準備)

seed(42)main.mojo の冒頭で呼ぶことで、重みの初期化・シャッフル・推論サンプリングすべてに同じシードが適用され、Python 版と同様に再現性が確保されます。


25.14. まとめ

  • B1〜B7 は src/part3/microgpt_mojo/b0*.mojo に対応し、b*_test.mojo でブロック単位に検証できます

  • Tape パターンは Mojo の所有権制約のもとで計算グラフを安全に扱うための鍵です

  • Python の参照ベースの設計を、Mojo ではインデックスベースの SoA に置き換えることで、型安全性を保ちながら同等の自動微分を実現しています

  • inout/owned の廃止、List[T] の所有権制約、^ 転送演算子など、Mojo 固有のルールに沿った書き方が全体を通して現れています

25.15. 次に読む章

26 章MAX を導入する意味)へ進みます。