29. microgpt を PyTorch で書き直す
29.1. この章で学ぶこと
PyTorch の
autograd+nn.Moduleを使って microgpt.py を書き直す方法Apple Silicon の MPS(GPU)バックエンドで学習・推論を高速化する方法
同じロジックを Python 版と Mojo 版で書き比べ、Python interop の境界を理解する
29.2. なぜ PyTorch か
前章(MAX Graph)は推論専用でした。学習も含めて高速化したい場合、
PyTorch の autograd エンジンが有力な選択肢です。
実装 |
学習 |
推論 |
Apple Silicon |
|---|---|---|---|
microgpt.py |
Python スカラーループ |
Python スカラーループ |
非対応 |
microgpt_mojo |
Mojo スカラーループ |
Mojo スカラーループ |
非対応 |
microgpt_mojo_max |
Mojo Tape |
MAX Graph |
CPU のみ |
microgpt_torch.py |
PyTorch + MPS |
PyTorch + MPS |
MPS(GPU)対応 |
microgpt_torch_mojo |
PyTorch + MPS |
PyTorch + MPS |
MPS(GPU)対応 |
29.3. microgpt.py との対応
microgpt.py |
PyTorch 版 |
|---|---|
|
|
|
|
|
カスタム |
|
|
|
|
|
|
Adam バッファ手書き |
|
手書きの学習率減衰 |
|
29.4. ディレクトリ構成
29.4.1. Python 版(microgpt_torch.py)
単一ファイルにすべてを記述したシンプルな実装です。
src/part3/
└── microgpt_torch.py ← データ・モデル・学習・推論をすべて含む
29.4.2. Mojo 版(microgpt_torch_mojo/)
Python と Mojo の役割を分離した構成です。
src/part3/microgpt_torch_mojo/
├── dataset.py ← データ読み込み・トークナイザー(Python が得意な文字列処理)
├── model.py ← nn.Module クラス定義(PyTorch の型システムをそのまま活かす)
└── main.mojo ← 設定・学習ループ・推論(Mojo の型安全を活かす)
29.5. Python 版ソースコード
"""
microgpt.py を PyTorch で書き直した版。
学習と推論を PyTorch の autograd + MPS バックエンドで行う。
microgpt.py との対応:
class Value → torch.Tensor(不要:autograd 内蔵)
linear(), rmsnorm() → nn.Linear, RMSNorm(nn.Module)
gpt() 関数 → MicroGPT(nn.Module).forward()
手書き Adam → torch.optim.Adam
1 文書ずつ学習 → 1 文書(バッチサイズ 1)のまま(microgpt.py と比較しやすくするため)
"""
import os
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
random.seed(42)
torch.manual_seed(42)
# デバイス選択: Apple Silicon なら MPS(GPU)、それ以外は CPU
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"device: {device}")
# -----------------------------------------------------------------------------
# データセット(microgpt.py と同じ)
# -----------------------------------------------------------------------------
if not os.path.exists("input.txt"):
import urllib.request
names_url = "https://raw.githubusercontent.com/karpathy/makemore/988aa59/names.txt"
urllib.request.urlretrieve(names_url, "input.txt")
docs = [line.strip() for line in open("input.txt") if line.strip()]
random.shuffle(docs)
print(f"num docs: {len(docs)}")
# -----------------------------------------------------------------------------
# トークナイザー(microgpt.py と同じ)
# -----------------------------------------------------------------------------
uchars = sorted(set("".join(docs)))
BOS = len(uchars)
vocab_size = len(uchars) + 1
print(f"vocab size: {vocab_size}")
def encode(doc: str) -> list[int]:
return [BOS] + [uchars.index(ch) for ch in doc] + [BOS]
# -----------------------------------------------------------------------------
# モデル定義
# -----------------------------------------------------------------------------
class RMSNorm(nn.Module):
"""microgpt.py の rmsnorm() に対応。スケールパラメータなし版。"""
def __init__(self, n_embd: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: [..., n_embd]
ms = x.pow(2).mean(dim=-1, keepdim=True)
return x * torch.rsqrt(ms + self.eps)
class TransformerBlock(nn.Module):
"""1 層の Attention + MLP ブロック。microgpt.py の gpt() 内ループ 1 回分に対応。"""
def __init__(self, n_embd: int, n_head: int):
super().__init__()
assert n_embd % n_head == 0
self.n_head = n_head
self.head_dim = n_embd // n_head
self.ln1 = RMSNorm(n_embd)
self.wq = nn.Linear(n_embd, n_embd, bias=False)
self.wk = nn.Linear(n_embd, n_embd, bias=False)
self.wv = nn.Linear(n_embd, n_embd, bias=False)
self.wo = nn.Linear(n_embd, n_embd, bias=False)
self.ln2 = RMSNorm(n_embd)
self.mlp = nn.Sequential(
nn.Linear(n_embd, 4 * n_embd, bias=False),
nn.ReLU(),
nn.Linear(4 * n_embd, n_embd, bias=False),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: [batch, seq, n_embd]
B, T, C = x.shape
H, D = self.n_head, self.head_dim
# Attention(因果マスク付き)
xn = self.ln1(x)
q = self.wq(xn).reshape(B, T, H, D).transpose(1, 2) # [B, H, T, D]
k = self.wk(xn).reshape(B, T, H, D).transpose(1, 2)
v = self.wv(xn).reshape(B, T, H, D).transpose(1, 2)
scale = 1.0 / math.sqrt(D)
attn = (q @ k.transpose(-2, -1)) * scale # [B, H, T, T]
# 因果マスク: 未来のトークンを参照しないように
mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
attn = attn.masked_fill(mask, float("-inf"))
attn = F.softmax(attn, dim=-1)
out = (attn @ v).transpose(1, 2).reshape(B, T, C) # [B, T, C]
out = self.wo(out)
x = x + out # 残差接続
# MLP
x = x + self.mlp(self.ln2(x)) # 残差接続
return x
class MicroGPT(nn.Module):
"""microgpt.py の gpt() 関数 + state_dict に対応する nn.Module。"""
def __init__(self, vocab_size: int, n_embd: int, n_head: int, n_layer: int, block_size: int):
super().__init__()
self.wte = nn.Embedding(vocab_size, n_embd)
self.wpe = nn.Embedding(block_size, n_embd)
self.ln0 = RMSNorm(n_embd)
self.layers = nn.ModuleList([TransformerBlock(n_embd, n_head) for _ in range(n_layer)])
self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
# token_ids: [batch, seq] → logits: [batch, seq, vocab]
B, T = token_ids.shape
pos_ids = torch.arange(T, device=token_ids.device)
x = self.wte(token_ids) + self.wpe(pos_ids) # [B, T, n_embd]
x = self.ln0(x)
for layer in self.layers:
x = layer(x)
return self.lm_head(x) # [B, T, vocab]
# -----------------------------------------------------------------------------
# モデル初期化
# -----------------------------------------------------------------------------
n_layer, n_embd, block_size, n_head = 1, 16, 16, 4
model = MicroGPT(vocab_size, n_embd, n_head, n_layer, block_size).to(device)
print(f"num params: {sum(p.numel() for p in model.parameters())}")
# -----------------------------------------------------------------------------
# 学習ループ
# -----------------------------------------------------------------------------
# microgpt.py と同じハイパーパラメータ
learning_rate, beta1, beta2, eps_adam = 0.01, 0.85, 0.99, 1e-8
optimizer = Adam(model.parameters(), lr=learning_rate, betas=(beta1, beta2), eps=eps_adam)
num_steps = 1000
model.train()
for step in range(num_steps):
doc = docs[step % len(docs)]
tokens_list = encode(doc)
n = min(block_size, len(tokens_list) - 1)
# [1, n+1] テンソルに変換してデバイスへ
tokens = torch.tensor(tokens_list[:n + 1], dtype=torch.long, device=device).unsqueeze(0)
# 順伝播
logits = model(tokens[:, :-1]) # [1, n, vocab]
targets = tokens[:, 1:] # [1, n]
loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
# 逆伝播 + Adam 更新
# 学習率の線形減衰(microgpt.py と同じ)
lr_t = learning_rate * (1.0 - step / num_steps)
for pg in optimizer.param_groups:
pg["lr"] = lr_t
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"step {step+1:4d} / {num_steps:4d} | loss {loss.item():.4f}")
# -----------------------------------------------------------------------------
# 推論
# -----------------------------------------------------------------------------
temperature = 0.5
print("\n--- inference (PyTorch MPS) ---")
model.eval()
with torch.no_grad():
for sample_idx in range(20):
token_id = BOS
result = []
# KV キャッシュなし版: 毎ステップ先頭から再計算(microgpt.py と同じ動作)
generated = [BOS]
for pos_id in range(block_size):
tokens_in = torch.tensor([generated], dtype=torch.long, device=device)
logits = model(tokens_in) # [1, seq, vocab]
next_logit = logits[0, -1, :] / temperature # 最後の位置のlogits
probs = F.softmax(next_logit, dim=-1)
token_id = torch.multinomial(probs, 1).item()
if token_id == BOS:
break
generated.append(token_id)
result.append(uchars[token_id])
print(f"sample {sample_idx+1:2d}: {''.join(result)}")
29.6. Mojo 版ソースコード
29.6.1. dataset.py(データ処理)
"""
データセットとトークナイザー。
main.mojo から Python interop 経由で呼ばれる。
microgpt_torch.py のトップレベルコードを関数に切り出した版。
"""
import os
import random
def load_docs(seed: int = 42) -> list[str]:
"""input.txt を読み込み、シャッフルして返す。"""
if not os.path.exists("input.txt"):
import urllib.request
urllib.request.urlretrieve(
"https://raw.githubusercontent.com/karpathy/makemore/988aa59/names.txt",
"input.txt",
)
docs = [line.strip() for line in open("input.txt") if line.strip()]
random.seed(seed)
random.shuffle(docs)
return docs
def make_uchars(docs: list[str]) -> list[str]:
"""ユニーク文字のソート済みリストを返す。"""
return sorted(set("".join(docs)))
def encode(doc: str, uchars: list[str], bos: int) -> list[int]:
"""文書を BOS 付きトークン列に変換する。"""
return [bos] + [uchars.index(ch) for ch in doc] + [bos]
29.6.2. model.py(PyTorch モデル定義)
"""
PyTorch モデル定義。main.mojo から Python interop 経由で使われる。
microgpt_torch.py の nn.Module クラスを切り出した版。
トップレベルの学習ループ・推論は main.mojo 側で行う。
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class RMSNorm(nn.Module):
"""microgpt.py の rmsnorm() に対応。スケールパラメータなし版。"""
def __init__(self, n_embd: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
ms = x.pow(2).mean(dim=-1, keepdim=True)
return x * torch.rsqrt(ms + self.eps)
class TransformerBlock(nn.Module):
"""1 層の Attention + MLP ブロック。"""
def __init__(self, n_embd: int, n_head: int):
super().__init__()
assert n_embd % n_head == 0
self.n_head = n_head
self.head_dim = n_embd // n_head
self.ln1 = RMSNorm(n_embd)
self.wq = nn.Linear(n_embd, n_embd, bias=False)
self.wk = nn.Linear(n_embd, n_embd, bias=False)
self.wv = nn.Linear(n_embd, n_embd, bias=False)
self.wo = nn.Linear(n_embd, n_embd, bias=False)
self.ln2 = RMSNorm(n_embd)
self.mlp = nn.Sequential(
nn.Linear(n_embd, 4 * n_embd, bias=False),
nn.ReLU(),
nn.Linear(4 * n_embd, n_embd, bias=False),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, T, C = x.shape
H, D = self.n_head, self.head_dim
xn = self.ln1(x)
q = self.wq(xn).reshape(B, T, H, D).transpose(1, 2)
k = self.wk(xn).reshape(B, T, H, D).transpose(1, 2)
v = self.wv(xn).reshape(B, T, H, D).transpose(1, 2)
scale = 1.0 / math.sqrt(D)
attn = (q @ k.transpose(-2, -1)) * scale
mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
attn = attn.masked_fill(mask, float("-inf"))
attn = F.softmax(attn, dim=-1)
out = (attn @ v).transpose(1, 2).reshape(B, T, C)
out = self.wo(out)
x = x + out
x = x + self.mlp(self.ln2(x))
return x
class MicroGPT(nn.Module):
"""GPT モデル本体。main.mojo から Python interop 経由でインスタンス化される。"""
def __init__(
self,
vocab_size: int,
n_embd: int,
n_head: int,
n_layer: int,
block_size: int,
):
super().__init__()
self.wte = nn.Embedding(vocab_size, n_embd)
self.wpe = nn.Embedding(block_size, n_embd)
self.ln0 = RMSNorm(n_embd)
self.layers = nn.ModuleList(
[TransformerBlock(n_embd, n_head) for _ in range(n_layer)]
)
self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
B, T = token_ids.shape
pos_ids = torch.arange(T, device=token_ids.device)
x = self.wte(token_ids) + self.wpe(pos_ids)
x = self.ln0(x)
for layer in self.layers:
x = layer(x)
return self.lm_head(x)
29.6.3. main.mojo(Mojo 版メインループ)
# microgpt_torch.py を Mojo で書き直した版。
#
# Python 側(dataset.py / model.py)との分担:
# dataset.py … データ読み込み・トークナイザー(文字列処理は Python が得意)
# model.py … nn.Module クラス定義(PyTorch の型システムを活かす)
# main.mojo … 設定・デバイス選択・学習ループ・推論(Mojo の型安全を活かす)
#
# microgpt_torch.py との主な違い:
# GPTConfig struct ← Python にはない型付き設定管理
# var device: String ← 型推論なしの明示的な String 型
# Int / Float64 ← Mojo の整数・浮動小数点型(Python の int/float とは別)
# .__bool__() ← PythonObject → Mojo Bool への明示的な変換
from std.python import Python, PythonObject
# ─────────────────────────────────────────────────────────────────────────────
# GPT ハイパーパラメータ(Mojo struct で型付き管理)
# microgpt_torch.py ではモジュールレベルの変数として定義していた部分
# ─────────────────────────────────────────────────────────────────────────────
struct GPTConfig:
var vocab_size: Int
var n_embd: Int
var n_head: Int
var n_layer: Int
var block_size: Int
def __init__(
out self,
vocab_size: Int,
n_embd: Int = 16,
n_head: Int = 4,
n_layer: Int = 1,
block_size: Int = 16,
):
self.vocab_size = vocab_size
self.n_embd = n_embd
self.n_head = n_head
self.n_layer = n_layer
self.block_size = block_size
def main() raises:
# ── Python モジュールを import ──────────────────────────────────────────
var sys = Python.import_module("sys")
var random = Python.import_module("random")
var torch = Python.import_module("torch")
var F = Python.import_module("torch.nn.functional")
var t_optim = Python.import_module("torch.optim")
# dataset.py / model.py はカレントディレクトリから import
sys.path.insert(0, ".")
var dataset = Python.import_module("dataset")
var model_mod = Python.import_module("model")
random.seed(42)
torch.manual_seed(42)
# ── デバイス選択(Mojo の String 変数に格納) ─────────────────────────
# microgpt_torch.py: device = "mps" if torch.backends.mps.is_available() else "cpu"
# Mojo では PythonObject の bool を .__bool__() で Mojo の Bool に変換する
var device: String = "cpu"
if torch.backends.mps.is_available().__bool__():
device = "mps"
print("device:", device)
# ── データセット・トークナイザー ────────────────────────────────────────
var docs = dataset.load_docs(42) # PythonObject(Python list)
var uchars = dataset.make_uchars(docs) # PythonObject(Python list)
# Mojo の len() は Python コレクションに対して Mojo Int を返す
var bos: Int = len(uchars)
var vocab_size: Int = bos + 1
print("num docs:", len(docs))
print("vocab size:", vocab_size)
# ── ハイパーパラメータ(Mojo struct) ────────────────────────────────────
# microgpt_torch.py: n_layer, n_embd, block_size, n_head = 1, 16, 16, 4
var cfg = GPTConfig(vocab_size=vocab_size)
# ── モデル初期化(PyTorch nn.Module を Python interop 経由で) ──────────
var model = model_mod.MicroGPT(
cfg.vocab_size, cfg.n_embd, cfg.n_head, cfg.n_layer, cfg.block_size,
)
_ = model.to(device)
# num_params は Python int で累積(PythonObject と Int の混算を避ける)
var num_params = Python.evaluate("0")
for p in model.parameters():
num_params = num_params + p.numel()
print("num params:", num_params)
# ── Adam オプティマイザー ──────────────────────────────────────────────
var learning_rate: Float64 = 0.01
# Mojo のタプルリテラル (0.85, 0.99) は Python タプルではないため
# Python.evaluate() で Python タプルを生成する
var betas = Python.evaluate("(0.85, 0.99)")
var optimizer = t_optim.Adam(
model.parameters(),
lr=learning_rate,
betas=betas,
eps=1e-8,
)
# ── 学習ループ(Mojo の for + 型付き変数) ────────────────────────────
var num_steps: Int = 1000
_ = model.train()
for step in range(num_steps):
# step は Mojo の Int(Python の int とは別の型)
var doc = docs[step % len(docs)]
var tok_list = dataset.encode(doc, uchars, bos)
var n: Int = min(cfg.block_size, len(tok_list) - 1)
var tokens = torch.tensor(
tok_list[0 : n + 1], dtype=torch.long, device=device,
).unsqueeze(0) # [1, n+1]
var logits = model(tokens[:, :-1]) # [1, n, vocab]
var targets = tokens[:, 1:] # [1, n]
var loss = F.cross_entropy(
logits.reshape(-1, cfg.vocab_size), targets.reshape(-1),
)
# 学習率の線形減衰
# Float64(step) で Mojo Int → Float64 に変換してから演算
var lr_t = learning_rate * (1.0 - Float64(step) / Float64(num_steps))
for pg in optimizer.param_groups:
# Python dict への書き込み: pg["lr"] = lr_t に相当
pg.__setitem__("lr", value=lr_t)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("step", step + 1, "/", num_steps, "| loss", loss.item())
# ── 推論(KV キャッシュなし) ────────────────────────────────────────
var temperature: Float64 = 0.5
print("\n--- inference (PyTorch / Mojo) ---")
_ = model.eval()
# 勾配計算を無効化(torch.no_grad() コンテキストの代替)
_ = torch.set_grad_enabled(False)
for si in range(20):
var generated = Python.list()
generated.append(bos)
var result = Python.list()
for _ in range(cfg.block_size):
# [[token_ids...]] の形で tensor を生成
var nested = Python.list()
nested.append(generated)
var tok_in = torch.tensor(nested, dtype=torch.long, device=device)
var logits = model(tok_in)
# logits[0][-1] で最後のトークンの logits を取得
# (Mojo から logits[0, -1, :] のマルチインデックススライスは
# 直接書けないため、段階的にインデックスする)
var next_logit = logits[0][-1] / temperature
var probs = F.softmax(next_logit, dim=-1)
# token_id は PythonObject のまま扱う
var token_id = torch.multinomial(probs, 1).item()
if token_id == bos:
break
generated.append(token_id)
result.append(uchars[token_id])
print("sample", si + 1, ":", Python.str("").join(result))
29.7. Python 版 vs Mojo 版:差分解説
29.7.1. 1. デバイス選択
Python 版:
device = "mps" if torch.backends.mps.is_available() else "cpu"
Mojo 版:
var device: String = "cpu"
if torch.backends.mps.is_available().__bool__():
device = "mps"
Mojo では PythonObject(PyTorch が返す Python の bool)を Mojo の Bool に変換するために .__bool__() が必要です。Python では型変換が暗黙的に行われますが、Mojo では明示的に記述します。
29.7.2. 2〜6. その他の主な差分
項目 |
Python 版 |
Mojo 版 |
ポイント |
|---|---|---|---|
ハイパーパラメータ |
変数バラ持ち |
|
型付きで渡し忘れをコンパイル時に検出 |
ループ変数 |
Python |
Mojo |
浮動小数との混算に |
Python タプル |
|
|
Mojo タプルは Python タプルと別型 |
dict への書き込み |
|
|
|
マルチインデックス |
|
|
段階的にインデックスして代替 |
29.8. 実行方法
# Python 版
uv run python src/part3/microgpt_torch.py
# Mojo 版(ディレクトリに移動してから実行)
cd src/part3/microgpt_torch_mojo
uv run mojo run main.mojo
どちらも同じ乱数シード(seed=42)を使うため、学習の loss 推移は一致します。
29.9. まとめ
比較項目 |
Python 版 |
Mojo 版 |
|---|---|---|
ファイル構成 |
単一ファイル |
役割分離(dataset/model/main) |
ハイパーパラメータ |
モジュール変数(動的型) |
|
デバイス選択 |
暗黙の型変換 |
|
ループ変数 |
Python int |
Mojo |
Python タプル |
タプルリテラル |
|
|
Python 内 |
Python 内(interop で呼び出す) |
次章では、同じ書き直しを Apple Silicon 専用フレームワーク MLX で行います。