28. microgpt Mojo 版を MAX で高速化する

28.1. この章で学ぶこと

  • Mojo 版 microgpt(Tape ベース)の学習はそのまま維持し、推論だけ MAX に切り替える方法

  • Mojo の Tape ノード値を Python interop で numpy 配列に変換する方法

  • MAX Graph で GPT 推論ループを再構築する方法

  • microgpt_mojomicrogpt_mojo_max を比較して結果を確認する方法


28.2. 方針:学習と推論を分離する

Mojo の Tape 実装はスカラーループのため推論が遅いですが、自動微分(学習)には最適です。 MAX Engine には学習(勾配計算)API がないため、学習は Tape のままにして推論だけ MAX に切り替えます。

diagram

フェーズ

担当

microgpt_mojo との違い

学習

Mojo Tape + Adam(B1〜B6)

なし

重みの抽出

b08_max_infer.mojo(新規)

新規

推論

max_infer_helper.py(MAX Graph)

新規(高速)


28.3. ディレクトリ構成

src/part3/microgpt_mojo_max/
├── b01_dataset.mojo      ← microgpt_mojo からコピー(変更なし)
├── b02_tokenizer.mojo    ← 同上
├── b03_value.mojo        ← 同上
├── b04_state_dict.mojo   ← 同上
├── b05_ops.mojo          ← 同上
├── b05_gpt.mojo          ← 同上
├── b06_train.mojo        ← 同上
├── b08_max_infer.mojo    ← 新規:重み抽出 + MAX 呼び出し
├── max_infer_helper.py   ← 新規:MAX Graph 定義・推論ループ
├── main.mojo             ← 変更:推論部分を b08 に差し替え
└── input.txt             ← microgpt_mojo からシンボリックリンク

28.4. B8: 重みの抽出と MAX 呼び出し(b08_max_infer.mojo)

b08_max_infer.mojo は Mojo と MAX(Python)の橋渡しをする唯一の新規 Mojo ファイルです。

# B8: MAX Graph を使った高速推論(Python interop 経由)。
# 学習済みの Tape ノード値を numpy 配列に変換し、
# max_infer_helper.py の MAX Graph 推論に渡す。

from std.python import Python, PythonObject

from b03_value import Tape
from b04_state_dict import HyperParams, StateDict


def _matrix_to_numpy(
    t: Tape,
    mat: List[List[Int]],
    np: PythonObject,
) raises -> PythonObject:
    """List[List[NodeId]] を float32 ndarray に変換する。
    mat[i][j] は Tape 上のノード ID(Int)。t.node_data() で値を取り出す。"""
    var rows = len(mat)
    var cols = len(mat[0])
    var flat = Python.list()
    for i in range(rows):
        for j in range(cols):
            flat.append(t.node_data(mat[i][j]))
    return np.array(flat, dtype="float32").reshape(rows, cols)


def run_max_inference(
    t: Tape,
    sd: StateDict,
    hp: HyperParams,
    uchars: List[String],
    bos: Int,
    n_samples: Int,
    temperature: Float64,
) raises:
    """学習済み重みを MAX Graph に渡して推論を実行する。
    学習に使った Tape は変更しない(推論専用)。"""
    var np = Python.import_module("numpy")

    # ── Tape 上の重みノードを numpy 配列に変換 ──────────────────────
    var weights = Python.dict()
    weights["wte"]     = _matrix_to_numpy(t, sd.wte, np)
    weights["wpe"]     = _matrix_to_numpy(t, sd.wpe, np)
    weights["lm_head"] = _matrix_to_numpy(t, sd.lm_head, np)
    for li in range(hp.n_layer):
        var s = String(li)
        weights["attn_wq_" + s] = _matrix_to_numpy(t, sd.attn_wq[li], np)
        weights["attn_wk_" + s] = _matrix_to_numpy(t, sd.attn_wk[li], np)
        weights["attn_wv_" + s] = _matrix_to_numpy(t, sd.attn_wv[li], np)
        weights["attn_wo_" + s] = _matrix_to_numpy(t, sd.attn_wo[li], np)
        weights["mlp_fc1_" + s] = _matrix_to_numpy(t, sd.mlp_fc1[li], np)
        weights["mlp_fc2_" + s] = _matrix_to_numpy(t, sd.mlp_fc2[li], np)

    # ── ハイパーパラメータを Python dict に ─────────────────────────
    var hp_dict = Python.dict()
    hp_dict["n_layer"]    = hp.n_layer
    hp_dict["n_embd"]     = hp.n_embd
    hp_dict["n_head"]     = hp.n_head
    hp_dict["block_size"] = hp.block_size

    # ── uchars (List[String]) を Python list に ──────────────────────
    var uchars_py = Python.list()
    for i in range(len(uchars)):
        uchars_py.append(uchars[i])

    # ── MAX 推論ヘルパー(Python)を呼び出す ─────────────────────────
    # max_infer_helper.py は同じディレクトリに置く
    var sys = Python.import_module("sys")
    sys.path.insert(0, ".")
    var helper = Python.import_module("max_infer_helper")
    helper.run_inference(weights, hp_dict, uchars_py, bos, n_samples, temperature)

ポイント:

  • _matrix_to_numpy() では t.node_data(mat[i][j]) で Tape ノードの値(Float64)を取り出し、 Python.list() に積んで np.array(...).reshape(rows, cols) で ndarray に変換する

  • キー名は "attn_wq_0", "mlp_fc1_0" のように 名前_層番号 で統一する

  • sys.path.insert(0, ".") で実行ディレクトリを Python パスに追加してから max_infer_helper をインポートする


28.5. MAX Graph の定義(max_infer_helper.py)

MAX Graph のコード(EmbedLayerAttentionLayerMLPLayerbuild_gpt_graphrun_inference)は 純粋な Python ファイルにまとめます。

"""
MAX Graph を使った GPT 推論ヘルパー。
b08_max_infer.mojo から Python interop 経由で呼ばれる。

インストール済み MAX のバージョンに合わせた API:
  - Module クラスは使わない(この版には存在しない)
  - TensorType(DType, shape, DeviceRef.CPU()) — device 引数が必須
  - Weight(name, DType, shape, device)
  - Graph('name', input_types=[...]) + g.inputs[i]
  - InferenceSession(devices=[CPU()])
  - session.load(g, weights_registry={name: ndarray})
  - 出力は Buffer オブジェクト → np.from_dlpack() で変換
  - ops.gather のインデックスは [1] 形状(スカラー不可)
"""

import math
import numpy as np
from max.graph import Graph, TensorType, Weight, DeviceRef, SymbolicDim
from max.graph import ops
from max.dtype import DType
from max import engine
from max.driver import CPU


_CPU_DEV = None  # DeviceRef(グラフ定義用)
_CPU_RT  = None  # CPU driver(実行用)


def _get_cpu():
    global _CPU_DEV, _CPU_RT
    if _CPU_DEV is None:
        _CPU_DEV = DeviceRef.CPU()
        _CPU_RT  = CPU()
    return _CPU_DEV, _CPU_RT


# ─────────────────────────────────────────────────────────────────────────────
# グラフ内ヘルパー関数
# ─────────────────────────────────────────────────────────────────────────────

def _rmsnorm(x, eps: float = 1e-5):
    """x: [n_embd] → [n_embd]"""
    sq  = ops.mul(x, x)
    ms  = ops.mean(sq, axis=0)
    inv = ops.rsqrt(ops.add(ms, eps))
    return ops.mul(x, inv)


def _embed(wte: Weight, wpe: Weight, token_id, pos_id, n_embd: int):
    """gather 埋め込み + 位置埋め込み。インデックスは [1] 形状で渡す。"""
    tok = ops.reshape(ops.gather(wte, token_id, axis=0), [n_embd])
    pos = ops.reshape(ops.gather(wpe, pos_id,   axis=0), [n_embd])
    return _rmsnorm(ops.add(tok, pos))


def _attn_layer(weights, x, k_cache, v_cache, n_head, head_dim, n_embd, li):
    """1層分のマルチヘッドアテンション(KV キャッシュあり)"""
    wq = weights[f"wq_{li}"]
    wk = weights[f"wk_{li}"]
    wv = weights[f"wv_{li}"]
    wo = weights[f"wo_{li}"]

    q = ops.reshape(ops.matmul(wq, ops.reshape(x, [n_embd, 1])), [n_embd])
    k = ops.reshape(ops.matmul(wk, ops.reshape(x, [n_embd, 1])), [n_embd])
    v = ops.reshape(ops.matmul(wv, ops.reshape(x, [n_embd, 1])), [n_embd])

    # KV キャッシュに追記
    k_cache = ops.concat([k_cache, ops.reshape(k, [1, n_embd])], axis=0)
    v_cache = ops.concat([v_cache, ops.reshape(v, [1, n_embd])], axis=0)

    scale = 1.0 / math.sqrt(head_dim)
    heads = []
    for h in range(n_head):
        s = h * head_dim
        # Python スライス構文でヘッドごとに分割
        q_h = q[s : s + head_dim]
        k_h = k_cache[:, s : s + head_dim]   # [seq, head_dim]
        v_h = v_cache[:, s : s + head_dim]   # [seq, head_dim]

        scores = ops.mul(
            ops.reshape(ops.matmul(k_h, ops.reshape(q_h, [head_dim, 1])), [-1]),
            scale,
        )
        w = ops.softmax(scores, axis=0)
        head_out = ops.reshape(
            ops.matmul(ops.reshape(w, [1, -1]), v_h), [head_dim]
        )
        heads.append(head_out)

    merged = ops.concat(heads, axis=0)
    out = ops.reshape(ops.matmul(wo, ops.reshape(merged, [n_embd, 1])), [n_embd])
    return ops.add(out, x), k_cache, v_cache   # 残差接続


def _mlp_layer(weights, x, n_embd, li):
    """2層 MLP(ReLU)。hidden = 4 × n_embd"""
    fc1 = weights[f"fc1_{li}"]
    fc2 = weights[f"fc2_{li}"]
    h   = ops.relu(ops.reshape(ops.matmul(fc1, ops.reshape(x, [-1, 1])), [-1]))
    out = ops.reshape(ops.matmul(fc2, ops.reshape(h, [-1, 1])), [-1])
    return ops.add(out, x)   # 残差接続


# ─────────────────────────────────────────────────────────────────────────────
# グラフ構築
# ─────────────────────────────────────────────────────────────────────────────

def build_gpt_graph(weight_arrays: dict, hp: dict):
    """学習済み ndarray と hp から MAX Graph を返す。

    入力テンソル:
        token_id : int32 [1]
        pos_id   : int32 [1]
        k_cache_0, v_cache_0, ... : float32 [seq, n_embd]
    出力テンソル:
        logits          : float32 [vocab_size]
        new_k_cache_0, new_v_cache_0, ...
    """
    cpu, _ = _get_cpu()

    n_embd      = int(hp["n_embd"])
    n_head      = int(hp["n_head"])
    head_dim    = n_embd // n_head
    n_layer     = int(hp["n_layer"])
    vocab       = int(weight_arrays["wte"].shape[0])
    block_size  = int(weight_arrays["wpe"].shape[0])

    # ── 入力型リスト ─────────────────────────────────────────────
    # KV キャッシュのシーケンス長は可変なので SymbolicDim を使う
    input_types = [
        TensorType(DType.int32,   [1],         cpu),   # token_id
        TensorType(DType.int32,   [1],         cpu),   # pos_id
    ]
    for li in range(n_layer):
        # k と v は同じシーケンス長なので同じ SymbolicDim を使う
        seq = SymbolicDim(f"seq{li}")
        input_types.append(TensorType(DType.float32, [seq, n_embd], cpu))
        input_types.append(TensorType(DType.float32, [seq, n_embd], cpu))

    with Graph("microgpt_infer", input_types=input_types) as g:
        token_id = g.inputs[0]
        pos_id   = g.inputs[1]
        k_caches = [g.inputs[2 + li * 2]     for li in range(n_layer)]
        v_caches = [g.inputs[2 + li * 2 + 1] for li in range(n_layer)]

        # ── Weight ノード(graph 内に固定値として埋め込む) ────────
        wte = Weight("wte",     DType.float32, [vocab,      n_embd], cpu)
        wpe = Weight("wpe",     DType.float32, [block_size, n_embd], cpu)
        lm  = Weight("lm_head", DType.float32, [vocab,      n_embd], cpu)

        w_graph = {}   # str → Weight
        for li in range(n_layer):
            for nm in ["wq", "wk", "wv", "wo"]:
                w_graph[f"{nm}_{li}"] = Weight(
                    f"{nm}_{li}", DType.float32, [n_embd, n_embd], cpu
                )
            # fc1: [4*n_embd, n_embd], fc2: [n_embd, 4*n_embd]
            hidden = weight_arrays[f"mlp_fc1_{li}"].shape[0]
            w_graph[f"fc1_{li}"] = Weight(f"fc1_{li}", DType.float32, [hidden, n_embd], cpu)
            w_graph[f"fc2_{li}"] = Weight(f"fc2_{li}", DType.float32, [n_embd, hidden], cpu)

        # ── 順伝播 ───────────────────────────────────────────────
        x = _embed(wte, wpe, token_id, pos_id, n_embd)

        new_k_caches, new_v_caches = [], []
        for li in range(n_layer):
            x = _rmsnorm(x)
            x, new_k, new_v = _attn_layer(
                w_graph, x, k_caches[li], v_caches[li], n_head, head_dim, n_embd, li
            )
            new_k_caches.append(new_k)
            new_v_caches.append(new_v)

            x = _rmsnorm(x)
            x = _mlp_layer(w_graph, x, n_embd, li)

        logits = ops.reshape(ops.matmul(lm, ops.reshape(x, [n_embd, 1])), [-1])

        g.output(logits, *new_k_caches, *new_v_caches)

    return g


def _build_weights_registry(weight_arrays: dict, hp: dict) -> dict:
    """ndarray dict を weights_registry(名前 → ndarray)に整形する。"""
    n_layer = int(hp["n_layer"])
    reg = {
        "wte":     weight_arrays["wte"],
        "wpe":     weight_arrays["wpe"],
        "lm_head": weight_arrays["lm_head"],
    }
    for li in range(n_layer):
        for nm in ["attn_wq", "attn_wk", "attn_wv", "attn_wo"]:
            short = nm.replace("attn_", "")         # "wq", "wk", ...
            reg[f"{short}_{li}"] = weight_arrays[f"{nm}_{li}"]
        reg[f"fc1_{li}"] = weight_arrays[f"mlp_fc1_{li}"]
        reg[f"fc2_{li}"] = weight_arrays[f"mlp_fc2_{li}"]
    return reg


# ─────────────────────────────────────────────────────────────────────────────
# 推論ループ
# ─────────────────────────────────────────────────────────────────────────────

def run_inference(
    weight_arrays,
    hp,
    uchars,
    bos: int,
    n_samples: int = 20,
    temperature: float = 0.5,
) -> None:
    """MAX Graph をコンパイルして推論サンプルを生成する。

    b08_max_infer.mojo の run_max_inference() から呼ばれる。
    """
    _, cpu_dev = _get_cpu()

    n_layer  = int(hp["n_layer"])
    n_embd   = int(hp["n_embd"])
    block_sz = int(hp["block_size"])

    # weight_arrays は PythonObject (Mojo から渡された Python dict)
    # numpy 配列に変換して使う
    wa = {k: np.asarray(weight_arrays[k]) for k in weight_arrays}

    g        = build_gpt_graph(wa, hp)
    reg      = _build_weights_registry(wa, hp)
    session  = engine.InferenceSession(devices=[cpu_dev])
    model    = session.load(g, weights_registry=reg)

    print("--- inference (MAX) ---")
    for si in range(int(n_samples)):
        k_caches = [np.zeros((0, n_embd), dtype=np.float32) for _ in range(n_layer)]
        v_caches = [np.zeros((0, n_embd), dtype=np.float32) for _ in range(n_layer)]

        token_id = int(bos)
        result   = []

        for pos_id in range(block_sz):
            # token_id / pos_id は [1] 形状で渡す(gather のインデックスが [1] 必須)
            inputs = [
                np.array([token_id], dtype=np.int32),
                np.array([pos_id],   dtype=np.int32),
            ]
            for li in range(n_layer):
                inputs.append(k_caches[li])
                inputs.append(v_caches[li])

            outputs = model.execute(*inputs)

            logits = np.from_dlpack(outputs[0])

            # temperature スケーリング + 確率的サンプリング
            logits_s = logits / float(temperature)
            probs = np.exp(logits_s - logits_s.max())
            probs /= probs.sum()
            token_id = int(np.random.choice(len(probs), p=probs))

            # KV キャッシュを更新
            # 出力順: [logits, k0, k1, ..., v0, v1, ...]
            for li in range(n_layer):
                k_caches[li] = np.from_dlpack(outputs[1 + li])
                v_caches[li] = np.from_dlpack(outputs[1 + n_layer + li])

            if token_id == int(bos):
                break
            result.append(str(uchars[token_id]))

        print(f"sample {si+1:2d}: {''.join(result)}")

b08_max_infer.mojo との対応:

Mojo 側(b08)の変数

Python 側(helper)での使われ方

weights["wte"] など

build_gpt_graph(weights, hp) に渡す

hp_dict

hp["n_embd"] などで参照

uchars_py

uchars[token_id] で文字を逆引き

bos

生成終端の判定に使用


28.6. main.mojo の変更点

学習ループは microgpt_mojo/main.mojoまったく同じです。 違いは末尾の推論部分だけです。

# microgpt_mojo_max — Mojo Tape で学習し、MAX Graph で推論する版。
#
# microgpt_mojo/main.mojo との差分:
#   - b07_infer (Tape ベースの推論) を使わない
#   - 学習後に b08_max_infer.run_max_inference() を呼び、
#     Python interop 経由で MAX Graph を使って推論する
#
# 実行方法(microgpt_mojo_max/ ディレクトリで):
#   mojo run main.mojo

from std.random import random_float64, seed

from b01_dataset import load_docs_from_text
from b02_tokenizer import build_tokenizer, encode_doc
from b03_value import Tape, backward
from b04_state_dict import HyperParams, init_state_dict, flatten_params
from b06_train import init_kv_cache, loss_on_document, adam_update
from b08_max_infer import run_max_inference


def main() raises:
    seed(42)  # 再現性のため乱数シードを固定

    # ── データセット ───────────────────────────────────────────────
    var f = open("input.txt", "r")
    var text = f.read()
    f.close()
    var docs = load_docs_from_text(text)
    # Fisher-Yates シャッフル(microgpt.py の random.shuffle に相当)
    for i in range(len(docs) - 1, 0, -1):
        var j = Int(random_float64() * Float64(i + 1))
        var tmp = docs[i]
        docs[i] = docs[j]
        docs[j] = tmp
    print("num docs:", len(docs))

    # ── トークナイザー ─────────────────────────────────────────────
    var tok = build_tokenizer(docs)
    print("vocab size:", tok.vocab_size())

    # ── ハイパーパラメータ ─────────────────────────────────────────
    # n_layer=1, n_embd=16, block_size=16, n_head=4(microgpt.py と同じ)
    var hp = HyperParams(1, 16, 16, 4)

    # ── テープとパラメータの初期化 ─────────────────────────────────
    var t = Tape()
    var sd = init_state_dict(t, tok.vocab_size(), hp, 0.08)
    var params = flatten_params(sd)
    print("num params:", len(params))

    # ── Adam バッファ ──────────────────────────────────────────────
    var m_buf = List[Float64]()
    var v_buf = List[Float64]()
    for _ in range(len(params)):
        m_buf.append(0.0)
        v_buf.append(0.0)

    # ── 学習ループ(microgpt_mojo と同じ)─────────────────────────
    var num_steps = 1000
    for step in range(num_steps):
        var doc = docs[step % len(docs)]

        var keys = List[List[List[Int]]]()
        var vals = List[List[List[Int]]]()
        init_kv_cache(hp.n_layer, keys, vals)

        var tokens = encode_doc(tok, doc)
        var loss_node = loss_on_document(t, sd, hp, tokens, keys, vals)
        backward(t, loss_node)
        adam_update(t, params, m_buf, v_buf, step, num_steps)
        print("step", step + 1, "/", num_steps, "| loss", t.node_data(loss_node))

    # ── MAX Graph による推論 ───────────────────────────────────────
    # Tape の重みを numpy に変換し、max_infer_helper.py の MAX Graph へ渡す。
    # b07_infer のスカラーループ推論は使わない。
    run_max_inference(t, sd, hp, tok.uchars, tok.bos, 20, 0.5)

差分まとめ:

変更前(microgpt_mojo)

変更後(microgpt_mojo_max)

from b07_infer import ...

from b08_max_infer import run_max_inference

Tape ベースの推論ループ

run_max_inference(t, sd, hp, ...) 1行


28.7. 比較実行

どちらも同じ input.txtseed(42) を使うため、 学習の loss 推移は一致します。

# Tape ベース推論(従来版)
cd src/part3/microgpt_mojo && mojo run main.mojo

# MAX Graph 推論(新版)
cd src/part3/microgpt_mojo_max && mojo run main.mojo

比較項目

結果

学習の loss 推移

同一(同じ Tape 実装・同じ seed)

推論結果

異なりうる(確率的サンプリングのため)

推論速度

MAX 版の方が速い(特に n_embd が大きい場合)


28.8. Tape 方式との速度比較

方式

推論 1 ステップの仕組み

特徴

Mojo Tape

スカラー演算をループで追記・即時実行

遅い、自動微分可能

MAX Graph

グラフをコンパイル後に行列演算として一括実行

速い、SIMD/GPU 対応

MAX Graph では行列演算をハードウェアの SIMD 命令や GPU カーネルにマッピングできるため、 n_embd=16 程度の小さなモデルでも推論のオーバーヘッドが大幅に減ります。 モデルが大きくなるほど(n_embd=256, n_layer=6 など)差が顕著になります。


28.9. まとめ

  • 学習(Tape + Adam、B1〜B6)はそのまま活かし、推論だけ MAX に切り替えるのが段階的な高速化の第一歩

  • 重みの受け渡しは Mojo Python interop で t.node_data() → numpy に変換するだけ

  • MAX Graph の定義(Python)は max_infer_helper.py に分離し、Mojo からは1行で呼ぶ

  • b08_max_infer.mojo が Mojo と MAX の橋渡し役になる

本書で実装した microgpt の知識は、より大きなモデルへのスケールアップや MAX/Mojo による最適化の基礎になります。