28. microgpt Mojo 版を MAX で高速化する
28.1. この章で学ぶこと
Mojo 版 microgpt(Tape ベース)の学習はそのまま維持し、推論だけ MAX に切り替える方法
Mojo の Tape ノード値を Python interop で numpy 配列に変換する方法
MAX Graph で GPT 推論ループを再構築する方法
microgpt_mojoとmicrogpt_mojo_maxを比較して結果を確認する方法
28.2. 方針:学習と推論を分離する
Mojo の Tape 実装はスカラーループのため推論が遅いですが、自動微分(学習)には最適です。 MAX Engine には学習(勾配計算)API がないため、学習は Tape のままにして推論だけ MAX に切り替えます。
フェーズ |
担当 |
microgpt_mojo との違い |
|---|---|---|
学習 |
Mojo Tape + Adam(B1〜B6) |
なし |
重みの抽出 |
|
新規 |
推論 |
|
新規(高速) |
28.3. ディレクトリ構成
src/part3/microgpt_mojo_max/
├── b01_dataset.mojo ← microgpt_mojo からコピー(変更なし)
├── b02_tokenizer.mojo ← 同上
├── b03_value.mojo ← 同上
├── b04_state_dict.mojo ← 同上
├── b05_ops.mojo ← 同上
├── b05_gpt.mojo ← 同上
├── b06_train.mojo ← 同上
├── b08_max_infer.mojo ← 新規:重み抽出 + MAX 呼び出し
├── max_infer_helper.py ← 新規:MAX Graph 定義・推論ループ
├── main.mojo ← 変更:推論部分を b08 に差し替え
└── input.txt ← microgpt_mojo からシンボリックリンク
28.4. B8: 重みの抽出と MAX 呼び出し(b08_max_infer.mojo)
b08_max_infer.mojo は Mojo と MAX(Python)の橋渡しをする唯一の新規 Mojo ファイルです。
# B8: MAX Graph を使った高速推論(Python interop 経由)。
# 学習済みの Tape ノード値を numpy 配列に変換し、
# max_infer_helper.py の MAX Graph 推論に渡す。
from std.python import Python, PythonObject
from b03_value import Tape
from b04_state_dict import HyperParams, StateDict
def _matrix_to_numpy(
t: Tape,
mat: List[List[Int]],
np: PythonObject,
) raises -> PythonObject:
"""List[List[NodeId]] を float32 ndarray に変換する。
mat[i][j] は Tape 上のノード ID(Int)。t.node_data() で値を取り出す。"""
var rows = len(mat)
var cols = len(mat[0])
var flat = Python.list()
for i in range(rows):
for j in range(cols):
flat.append(t.node_data(mat[i][j]))
return np.array(flat, dtype="float32").reshape(rows, cols)
def run_max_inference(
t: Tape,
sd: StateDict,
hp: HyperParams,
uchars: List[String],
bos: Int,
n_samples: Int,
temperature: Float64,
) raises:
"""学習済み重みを MAX Graph に渡して推論を実行する。
学習に使った Tape は変更しない(推論専用)。"""
var np = Python.import_module("numpy")
# ── Tape 上の重みノードを numpy 配列に変換 ──────────────────────
var weights = Python.dict()
weights["wte"] = _matrix_to_numpy(t, sd.wte, np)
weights["wpe"] = _matrix_to_numpy(t, sd.wpe, np)
weights["lm_head"] = _matrix_to_numpy(t, sd.lm_head, np)
for li in range(hp.n_layer):
var s = String(li)
weights["attn_wq_" + s] = _matrix_to_numpy(t, sd.attn_wq[li], np)
weights["attn_wk_" + s] = _matrix_to_numpy(t, sd.attn_wk[li], np)
weights["attn_wv_" + s] = _matrix_to_numpy(t, sd.attn_wv[li], np)
weights["attn_wo_" + s] = _matrix_to_numpy(t, sd.attn_wo[li], np)
weights["mlp_fc1_" + s] = _matrix_to_numpy(t, sd.mlp_fc1[li], np)
weights["mlp_fc2_" + s] = _matrix_to_numpy(t, sd.mlp_fc2[li], np)
# ── ハイパーパラメータを Python dict に ─────────────────────────
var hp_dict = Python.dict()
hp_dict["n_layer"] = hp.n_layer
hp_dict["n_embd"] = hp.n_embd
hp_dict["n_head"] = hp.n_head
hp_dict["block_size"] = hp.block_size
# ── uchars (List[String]) を Python list に ──────────────────────
var uchars_py = Python.list()
for i in range(len(uchars)):
uchars_py.append(uchars[i])
# ── MAX 推論ヘルパー(Python)を呼び出す ─────────────────────────
# max_infer_helper.py は同じディレクトリに置く
var sys = Python.import_module("sys")
sys.path.insert(0, ".")
var helper = Python.import_module("max_infer_helper")
helper.run_inference(weights, hp_dict, uchars_py, bos, n_samples, temperature)
ポイント:
_matrix_to_numpy()ではt.node_data(mat[i][j])で Tape ノードの値(Float64)を取り出し、Python.list()に積んでnp.array(...).reshape(rows, cols)で ndarray に変換するキー名は
"attn_wq_0","mlp_fc1_0"のように名前_層番号で統一するsys.path.insert(0, ".")で実行ディレクトリを Python パスに追加してからmax_infer_helperをインポートする
28.5. MAX Graph の定義(max_infer_helper.py)
MAX Graph のコード(EmbedLayer・AttentionLayer・MLPLayer・build_gpt_graph・run_inference)は
純粋な Python ファイルにまとめます。
"""
MAX Graph を使った GPT 推論ヘルパー。
b08_max_infer.mojo から Python interop 経由で呼ばれる。
インストール済み MAX のバージョンに合わせた API:
- Module クラスは使わない(この版には存在しない)
- TensorType(DType, shape, DeviceRef.CPU()) — device 引数が必須
- Weight(name, DType, shape, device)
- Graph('name', input_types=[...]) + g.inputs[i]
- InferenceSession(devices=[CPU()])
- session.load(g, weights_registry={name: ndarray})
- 出力は Buffer オブジェクト → np.from_dlpack() で変換
- ops.gather のインデックスは [1] 形状(スカラー不可)
"""
import math
import numpy as np
from max.graph import Graph, TensorType, Weight, DeviceRef, SymbolicDim
from max.graph import ops
from max.dtype import DType
from max import engine
from max.driver import CPU
_CPU_DEV = None # DeviceRef(グラフ定義用)
_CPU_RT = None # CPU driver(実行用)
def _get_cpu():
global _CPU_DEV, _CPU_RT
if _CPU_DEV is None:
_CPU_DEV = DeviceRef.CPU()
_CPU_RT = CPU()
return _CPU_DEV, _CPU_RT
# ─────────────────────────────────────────────────────────────────────────────
# グラフ内ヘルパー関数
# ─────────────────────────────────────────────────────────────────────────────
def _rmsnorm(x, eps: float = 1e-5):
"""x: [n_embd] → [n_embd]"""
sq = ops.mul(x, x)
ms = ops.mean(sq, axis=0)
inv = ops.rsqrt(ops.add(ms, eps))
return ops.mul(x, inv)
def _embed(wte: Weight, wpe: Weight, token_id, pos_id, n_embd: int):
"""gather 埋め込み + 位置埋め込み。インデックスは [1] 形状で渡す。"""
tok = ops.reshape(ops.gather(wte, token_id, axis=0), [n_embd])
pos = ops.reshape(ops.gather(wpe, pos_id, axis=0), [n_embd])
return _rmsnorm(ops.add(tok, pos))
def _attn_layer(weights, x, k_cache, v_cache, n_head, head_dim, n_embd, li):
"""1層分のマルチヘッドアテンション(KV キャッシュあり)"""
wq = weights[f"wq_{li}"]
wk = weights[f"wk_{li}"]
wv = weights[f"wv_{li}"]
wo = weights[f"wo_{li}"]
q = ops.reshape(ops.matmul(wq, ops.reshape(x, [n_embd, 1])), [n_embd])
k = ops.reshape(ops.matmul(wk, ops.reshape(x, [n_embd, 1])), [n_embd])
v = ops.reshape(ops.matmul(wv, ops.reshape(x, [n_embd, 1])), [n_embd])
# KV キャッシュに追記
k_cache = ops.concat([k_cache, ops.reshape(k, [1, n_embd])], axis=0)
v_cache = ops.concat([v_cache, ops.reshape(v, [1, n_embd])], axis=0)
scale = 1.0 / math.sqrt(head_dim)
heads = []
for h in range(n_head):
s = h * head_dim
# Python スライス構文でヘッドごとに分割
q_h = q[s : s + head_dim]
k_h = k_cache[:, s : s + head_dim] # [seq, head_dim]
v_h = v_cache[:, s : s + head_dim] # [seq, head_dim]
scores = ops.mul(
ops.reshape(ops.matmul(k_h, ops.reshape(q_h, [head_dim, 1])), [-1]),
scale,
)
w = ops.softmax(scores, axis=0)
head_out = ops.reshape(
ops.matmul(ops.reshape(w, [1, -1]), v_h), [head_dim]
)
heads.append(head_out)
merged = ops.concat(heads, axis=0)
out = ops.reshape(ops.matmul(wo, ops.reshape(merged, [n_embd, 1])), [n_embd])
return ops.add(out, x), k_cache, v_cache # 残差接続
def _mlp_layer(weights, x, n_embd, li):
"""2層 MLP(ReLU)。hidden = 4 × n_embd"""
fc1 = weights[f"fc1_{li}"]
fc2 = weights[f"fc2_{li}"]
h = ops.relu(ops.reshape(ops.matmul(fc1, ops.reshape(x, [-1, 1])), [-1]))
out = ops.reshape(ops.matmul(fc2, ops.reshape(h, [-1, 1])), [-1])
return ops.add(out, x) # 残差接続
# ─────────────────────────────────────────────────────────────────────────────
# グラフ構築
# ─────────────────────────────────────────────────────────────────────────────
def build_gpt_graph(weight_arrays: dict, hp: dict):
"""学習済み ndarray と hp から MAX Graph を返す。
入力テンソル:
token_id : int32 [1]
pos_id : int32 [1]
k_cache_0, v_cache_0, ... : float32 [seq, n_embd]
出力テンソル:
logits : float32 [vocab_size]
new_k_cache_0, new_v_cache_0, ...
"""
cpu, _ = _get_cpu()
n_embd = int(hp["n_embd"])
n_head = int(hp["n_head"])
head_dim = n_embd // n_head
n_layer = int(hp["n_layer"])
vocab = int(weight_arrays["wte"].shape[0])
block_size = int(weight_arrays["wpe"].shape[0])
# ── 入力型リスト ─────────────────────────────────────────────
# KV キャッシュのシーケンス長は可変なので SymbolicDim を使う
input_types = [
TensorType(DType.int32, [1], cpu), # token_id
TensorType(DType.int32, [1], cpu), # pos_id
]
for li in range(n_layer):
# k と v は同じシーケンス長なので同じ SymbolicDim を使う
seq = SymbolicDim(f"seq{li}")
input_types.append(TensorType(DType.float32, [seq, n_embd], cpu))
input_types.append(TensorType(DType.float32, [seq, n_embd], cpu))
with Graph("microgpt_infer", input_types=input_types) as g:
token_id = g.inputs[0]
pos_id = g.inputs[1]
k_caches = [g.inputs[2 + li * 2] for li in range(n_layer)]
v_caches = [g.inputs[2 + li * 2 + 1] for li in range(n_layer)]
# ── Weight ノード(graph 内に固定値として埋め込む) ────────
wte = Weight("wte", DType.float32, [vocab, n_embd], cpu)
wpe = Weight("wpe", DType.float32, [block_size, n_embd], cpu)
lm = Weight("lm_head", DType.float32, [vocab, n_embd], cpu)
w_graph = {} # str → Weight
for li in range(n_layer):
for nm in ["wq", "wk", "wv", "wo"]:
w_graph[f"{nm}_{li}"] = Weight(
f"{nm}_{li}", DType.float32, [n_embd, n_embd], cpu
)
# fc1: [4*n_embd, n_embd], fc2: [n_embd, 4*n_embd]
hidden = weight_arrays[f"mlp_fc1_{li}"].shape[0]
w_graph[f"fc1_{li}"] = Weight(f"fc1_{li}", DType.float32, [hidden, n_embd], cpu)
w_graph[f"fc2_{li}"] = Weight(f"fc2_{li}", DType.float32, [n_embd, hidden], cpu)
# ── 順伝播 ───────────────────────────────────────────────
x = _embed(wte, wpe, token_id, pos_id, n_embd)
new_k_caches, new_v_caches = [], []
for li in range(n_layer):
x = _rmsnorm(x)
x, new_k, new_v = _attn_layer(
w_graph, x, k_caches[li], v_caches[li], n_head, head_dim, n_embd, li
)
new_k_caches.append(new_k)
new_v_caches.append(new_v)
x = _rmsnorm(x)
x = _mlp_layer(w_graph, x, n_embd, li)
logits = ops.reshape(ops.matmul(lm, ops.reshape(x, [n_embd, 1])), [-1])
g.output(logits, *new_k_caches, *new_v_caches)
return g
def _build_weights_registry(weight_arrays: dict, hp: dict) -> dict:
"""ndarray dict を weights_registry(名前 → ndarray)に整形する。"""
n_layer = int(hp["n_layer"])
reg = {
"wte": weight_arrays["wte"],
"wpe": weight_arrays["wpe"],
"lm_head": weight_arrays["lm_head"],
}
for li in range(n_layer):
for nm in ["attn_wq", "attn_wk", "attn_wv", "attn_wo"]:
short = nm.replace("attn_", "") # "wq", "wk", ...
reg[f"{short}_{li}"] = weight_arrays[f"{nm}_{li}"]
reg[f"fc1_{li}"] = weight_arrays[f"mlp_fc1_{li}"]
reg[f"fc2_{li}"] = weight_arrays[f"mlp_fc2_{li}"]
return reg
# ─────────────────────────────────────────────────────────────────────────────
# 推論ループ
# ─────────────────────────────────────────────────────────────────────────────
def run_inference(
weight_arrays,
hp,
uchars,
bos: int,
n_samples: int = 20,
temperature: float = 0.5,
) -> None:
"""MAX Graph をコンパイルして推論サンプルを生成する。
b08_max_infer.mojo の run_max_inference() から呼ばれる。
"""
_, cpu_dev = _get_cpu()
n_layer = int(hp["n_layer"])
n_embd = int(hp["n_embd"])
block_sz = int(hp["block_size"])
# weight_arrays は PythonObject (Mojo から渡された Python dict)
# numpy 配列に変換して使う
wa = {k: np.asarray(weight_arrays[k]) for k in weight_arrays}
g = build_gpt_graph(wa, hp)
reg = _build_weights_registry(wa, hp)
session = engine.InferenceSession(devices=[cpu_dev])
model = session.load(g, weights_registry=reg)
print("--- inference (MAX) ---")
for si in range(int(n_samples)):
k_caches = [np.zeros((0, n_embd), dtype=np.float32) for _ in range(n_layer)]
v_caches = [np.zeros((0, n_embd), dtype=np.float32) for _ in range(n_layer)]
token_id = int(bos)
result = []
for pos_id in range(block_sz):
# token_id / pos_id は [1] 形状で渡す(gather のインデックスが [1] 必須)
inputs = [
np.array([token_id], dtype=np.int32),
np.array([pos_id], dtype=np.int32),
]
for li in range(n_layer):
inputs.append(k_caches[li])
inputs.append(v_caches[li])
outputs = model.execute(*inputs)
logits = np.from_dlpack(outputs[0])
# temperature スケーリング + 確率的サンプリング
logits_s = logits / float(temperature)
probs = np.exp(logits_s - logits_s.max())
probs /= probs.sum()
token_id = int(np.random.choice(len(probs), p=probs))
# KV キャッシュを更新
# 出力順: [logits, k0, k1, ..., v0, v1, ...]
for li in range(n_layer):
k_caches[li] = np.from_dlpack(outputs[1 + li])
v_caches[li] = np.from_dlpack(outputs[1 + n_layer + li])
if token_id == int(bos):
break
result.append(str(uchars[token_id]))
print(f"sample {si+1:2d}: {''.join(result)}")
b08_max_infer.mojo との対応:
Mojo 側(b08)の変数 |
Python 側(helper)での使われ方 |
|---|---|
|
|
|
|
|
|
|
生成終端の判定に使用 |
28.6. main.mojo の変更点
学習ループは microgpt_mojo/main.mojo とまったく同じです。
違いは末尾の推論部分だけです。
# microgpt_mojo_max — Mojo Tape で学習し、MAX Graph で推論する版。
#
# microgpt_mojo/main.mojo との差分:
# - b07_infer (Tape ベースの推論) を使わない
# - 学習後に b08_max_infer.run_max_inference() を呼び、
# Python interop 経由で MAX Graph を使って推論する
#
# 実行方法(microgpt_mojo_max/ ディレクトリで):
# mojo run main.mojo
from std.random import random_float64, seed
from b01_dataset import load_docs_from_text
from b02_tokenizer import build_tokenizer, encode_doc
from b03_value import Tape, backward
from b04_state_dict import HyperParams, init_state_dict, flatten_params
from b06_train import init_kv_cache, loss_on_document, adam_update
from b08_max_infer import run_max_inference
def main() raises:
seed(42) # 再現性のため乱数シードを固定
# ── データセット ───────────────────────────────────────────────
var f = open("input.txt", "r")
var text = f.read()
f.close()
var docs = load_docs_from_text(text)
# Fisher-Yates シャッフル(microgpt.py の random.shuffle に相当)
for i in range(len(docs) - 1, 0, -1):
var j = Int(random_float64() * Float64(i + 1))
var tmp = docs[i]
docs[i] = docs[j]
docs[j] = tmp
print("num docs:", len(docs))
# ── トークナイザー ─────────────────────────────────────────────
var tok = build_tokenizer(docs)
print("vocab size:", tok.vocab_size())
# ── ハイパーパラメータ ─────────────────────────────────────────
# n_layer=1, n_embd=16, block_size=16, n_head=4(microgpt.py と同じ)
var hp = HyperParams(1, 16, 16, 4)
# ── テープとパラメータの初期化 ─────────────────────────────────
var t = Tape()
var sd = init_state_dict(t, tok.vocab_size(), hp, 0.08)
var params = flatten_params(sd)
print("num params:", len(params))
# ── Adam バッファ ──────────────────────────────────────────────
var m_buf = List[Float64]()
var v_buf = List[Float64]()
for _ in range(len(params)):
m_buf.append(0.0)
v_buf.append(0.0)
# ── 学習ループ(microgpt_mojo と同じ)─────────────────────────
var num_steps = 1000
for step in range(num_steps):
var doc = docs[step % len(docs)]
var keys = List[List[List[Int]]]()
var vals = List[List[List[Int]]]()
init_kv_cache(hp.n_layer, keys, vals)
var tokens = encode_doc(tok, doc)
var loss_node = loss_on_document(t, sd, hp, tokens, keys, vals)
backward(t, loss_node)
adam_update(t, params, m_buf, v_buf, step, num_steps)
print("step", step + 1, "/", num_steps, "| loss", t.node_data(loss_node))
# ── MAX Graph による推論 ───────────────────────────────────────
# Tape の重みを numpy に変換し、max_infer_helper.py の MAX Graph へ渡す。
# b07_infer のスカラーループ推論は使わない。
run_max_inference(t, sd, hp, tok.uchars, tok.bos, 20, 0.5)
差分まとめ:
変更前(microgpt_mojo) |
変更後(microgpt_mojo_max) |
|---|---|
|
|
Tape ベースの推論ループ |
|
28.7. 比較実行
どちらも同じ input.txt と seed(42) を使うため、
学習の loss 推移は一致します。
# Tape ベース推論(従来版)
cd src/part3/microgpt_mojo && mojo run main.mojo
# MAX Graph 推論(新版)
cd src/part3/microgpt_mojo_max && mojo run main.mojo
比較項目 |
結果 |
|---|---|
学習の loss 推移 |
同一(同じ Tape 実装・同じ seed) |
推論結果 |
異なりうる(確率的サンプリングのため) |
推論速度 |
MAX 版の方が速い(特に |
28.8. Tape 方式との速度比較
方式 |
推論 1 ステップの仕組み |
特徴 |
|---|---|---|
Mojo Tape |
スカラー演算をループで追記・即時実行 |
遅い、自動微分可能 |
MAX Graph |
グラフをコンパイル後に行列演算として一括実行 |
速い、SIMD/GPU 対応 |
MAX Graph では行列演算をハードウェアの SIMD 命令や GPU カーネルにマッピングできるため、
n_embd=16 程度の小さなモデルでも推論のオーバーヘッドが大幅に減ります。
モデルが大きくなるほど(n_embd=256, n_layer=6 など)差が顕著になります。
28.9. まとめ
学習(Tape + Adam、B1〜B6)はそのまま活かし、推論だけ MAX に切り替えるのが段階的な高速化の第一歩
重みの受け渡しは Mojo Python interop で
t.node_data()→ numpy に変換するだけMAX Graph の定義(Python)は
max_infer_helper.pyに分離し、Mojo からは1行で呼ぶb08_max_infer.mojoが Mojo と MAX の橋渡し役になる
本書で実装した microgpt の知識は、より大きなモデルへのスケールアップや MAX/Mojo による最適化の基礎になります。