29. microgpt を PyTorch で書き直す

29.1. この章で学ぶこと

  • PyTorch の autograd + nn.Module を使って microgpt.py を書き直す方法

  • Apple Silicon の MPS(GPU)バックエンドで学習・推論を高速化する方法

  • 同じロジックを Python 版Mojo 版で書き比べ、Python interop の境界を理解する


29.2. なぜ PyTorch か

前章(MAX Graph)は推論専用でした。学習も含めて高速化したい場合、 PyTorch の autograd エンジンが有力な選択肢です。

実装

学習

推論

Apple Silicon

microgpt.py

Python スカラーループ

Python スカラーループ

非対応

microgpt_mojo

Mojo スカラーループ

Mojo スカラーループ

非対応

microgpt_mojo_max

Mojo Tape

MAX Graph

CPU のみ

microgpt_torch.py

PyTorch + MPS

PyTorch + MPS

MPS(GPU)対応

microgpt_torch_mojo

PyTorch + MPS

PyTorch + MPS

MPS(GPU)対応


29.3. microgpt.py との対応

microgpt.py

PyTorch 版

class Value

torch.Tensor(autograd 内蔵なので不要)

linear(x, w)

nn.Linear(in, out, bias=False)

rmsnorm(x)

カスタム RMSNorm(nn.Module)

softmax(logits)

F.softmax(x, dim=-1)

gpt(token_id, ...)

MicroGPT.forward(tokens)

loss.backward()

loss.backward()

Adam バッファ手書き

torch.optim.Adam

手書きの学習率減衰

pg["lr"] = lr_t で同じロジックを再現


29.4. ディレクトリ構成

29.4.1. Python 版(microgpt_torch.py

単一ファイルにすべてを記述したシンプルな実装です。

src/part3/
└── microgpt_torch.py   ← データ・モデル・学習・推論をすべて含む

29.4.2. Mojo 版(microgpt_torch_mojo/

Python と Mojo の役割を分離した構成です。

src/part3/microgpt_torch_mojo/
├── dataset.py   ← データ読み込み・トークナイザー(Python が得意な文字列処理)
├── model.py     ← nn.Module クラス定義(PyTorch の型システムをそのまま活かす)
└── main.mojo    ← 設定・学習ループ・推論(Mojo の型安全を活かす)

29.5. Python 版ソースコード

"""
microgpt.py を PyTorch で書き直した版。
学習と推論を PyTorch の autograd + MPS バックエンドで行う。

microgpt.py との対応:
  class Value             → torch.Tensor(不要:autograd 内蔵)
  linear(), rmsnorm()     → nn.Linear, RMSNorm(nn.Module)
  gpt() 関数              → MicroGPT(nn.Module).forward()
  手書き Adam             → torch.optim.Adam
  1 文書ずつ学習          → 1 文書(バッチサイズ 1)のまま(microgpt.py と比較しやすくするため)
"""

import os
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

random.seed(42)
torch.manual_seed(42)

# デバイス選択: Apple Silicon なら MPS(GPU)、それ以外は CPU
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"device: {device}")

# -----------------------------------------------------------------------------
# データセット(microgpt.py と同じ)
# -----------------------------------------------------------------------------
if not os.path.exists("input.txt"):
    import urllib.request
    names_url = "https://raw.githubusercontent.com/karpathy/makemore/988aa59/names.txt"
    urllib.request.urlretrieve(names_url, "input.txt")
docs = [line.strip() for line in open("input.txt") if line.strip()]
random.shuffle(docs)
print(f"num docs: {len(docs)}")

# -----------------------------------------------------------------------------
# トークナイザー(microgpt.py と同じ)
# -----------------------------------------------------------------------------
uchars = sorted(set("".join(docs)))
BOS = len(uchars)
vocab_size = len(uchars) + 1
print(f"vocab size: {vocab_size}")


def encode(doc: str) -> list[int]:
    return [BOS] + [uchars.index(ch) for ch in doc] + [BOS]


# -----------------------------------------------------------------------------
# モデル定義
# -----------------------------------------------------------------------------

class RMSNorm(nn.Module):
    """microgpt.py の rmsnorm() に対応。スケールパラメータなし版。"""

    def __init__(self, n_embd: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [..., n_embd]
        ms = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(ms + self.eps)


class TransformerBlock(nn.Module):
    """1 層の Attention + MLP ブロック。microgpt.py の gpt() 内ループ 1 回分に対応。"""

    def __init__(self, n_embd: int, n_head: int):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head

        self.ln1 = RMSNorm(n_embd)
        self.wq  = nn.Linear(n_embd, n_embd, bias=False)
        self.wk  = nn.Linear(n_embd, n_embd, bias=False)
        self.wv  = nn.Linear(n_embd, n_embd, bias=False)
        self.wo  = nn.Linear(n_embd, n_embd, bias=False)

        self.ln2  = RMSNorm(n_embd)
        self.mlp  = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd, bias=False),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd, bias=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, seq, n_embd]
        B, T, C = x.shape
        H, D = self.n_head, self.head_dim

        # Attention(因果マスク付き)
        xn = self.ln1(x)
        q = self.wq(xn).reshape(B, T, H, D).transpose(1, 2)   # [B, H, T, D]
        k = self.wk(xn).reshape(B, T, H, D).transpose(1, 2)
        v = self.wv(xn).reshape(B, T, H, D).transpose(1, 2)

        scale = 1.0 / math.sqrt(D)
        attn = (q @ k.transpose(-2, -1)) * scale   # [B, H, T, T]
        # 因果マスク: 未来のトークンを参照しないように
        mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
        attn = attn.masked_fill(mask, float("-inf"))
        attn = F.softmax(attn, dim=-1)

        out = (attn @ v).transpose(1, 2).reshape(B, T, C)   # [B, T, C]
        out = self.wo(out)
        x = x + out   # 残差接続

        # MLP
        x = x + self.mlp(self.ln2(x))   # 残差接続
        return x


class MicroGPT(nn.Module):
    """microgpt.py の gpt() 関数 + state_dict に対応する nn.Module。"""

    def __init__(self, vocab_size: int, n_embd: int, n_head: int, n_layer: int, block_size: int):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.wpe = nn.Embedding(block_size, n_embd)
        self.ln0  = RMSNorm(n_embd)
        self.layers = nn.ModuleList([TransformerBlock(n_embd, n_head) for _ in range(n_layer)])
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        # token_ids: [batch, seq]  →  logits: [batch, seq, vocab]
        B, T = token_ids.shape
        pos_ids = torch.arange(T, device=token_ids.device)
        x = self.wte(token_ids) + self.wpe(pos_ids)   # [B, T, n_embd]
        x = self.ln0(x)
        for layer in self.layers:
            x = layer(x)
        return self.lm_head(x)   # [B, T, vocab]


# -----------------------------------------------------------------------------
# モデル初期化
# -----------------------------------------------------------------------------
n_layer, n_embd, block_size, n_head = 1, 16, 16, 4
model = MicroGPT(vocab_size, n_embd, n_head, n_layer, block_size).to(device)
print(f"num params: {sum(p.numel() for p in model.parameters())}")

# -----------------------------------------------------------------------------
# 学習ループ
# -----------------------------------------------------------------------------
# microgpt.py と同じハイパーパラメータ
learning_rate, beta1, beta2, eps_adam = 0.01, 0.85, 0.99, 1e-8
optimizer = Adam(model.parameters(), lr=learning_rate, betas=(beta1, beta2), eps=eps_adam)

num_steps = 1000
model.train()
for step in range(num_steps):
    doc = docs[step % len(docs)]
    tokens_list = encode(doc)
    n = min(block_size, len(tokens_list) - 1)

    # [1, n+1] テンソルに変換してデバイスへ
    tokens = torch.tensor(tokens_list[:n + 1], dtype=torch.long, device=device).unsqueeze(0)

    # 順伝播
    logits = model(tokens[:, :-1])                        # [1, n, vocab]
    targets = tokens[:, 1:]                               # [1, n]
    loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))

    # 逆伝播 + Adam 更新
    # 学習率の線形減衰(microgpt.py と同じ)
    lr_t = learning_rate * (1.0 - step / num_steps)
    for pg in optimizer.param_groups:
        pg["lr"] = lr_t

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"step {step+1:4d} / {num_steps:4d} | loss {loss.item():.4f}")

# -----------------------------------------------------------------------------
# 推論
# -----------------------------------------------------------------------------
temperature = 0.5
print("\n--- inference (PyTorch MPS) ---")
model.eval()
with torch.no_grad():
    for sample_idx in range(20):
        token_id = BOS
        result = []
        # KV キャッシュなし版: 毎ステップ先頭から再計算(microgpt.py と同じ動作)
        generated = [BOS]
        for pos_id in range(block_size):
            tokens_in = torch.tensor([generated], dtype=torch.long, device=device)
            logits = model(tokens_in)                     # [1, seq, vocab]
            next_logit = logits[0, -1, :] / temperature  # 最後の位置のlogits
            probs = F.softmax(next_logit, dim=-1)
            token_id = torch.multinomial(probs, 1).item()
            if token_id == BOS:
                break
            generated.append(token_id)
            result.append(uchars[token_id])
        print(f"sample {sample_idx+1:2d}: {''.join(result)}")

29.6. Mojo 版ソースコード

29.6.1. dataset.py(データ処理)

"""
データセットとトークナイザー。
main.mojo から Python interop 経由で呼ばれる。

microgpt_torch.py のトップレベルコードを関数に切り出した版。
"""

import os
import random


def load_docs(seed: int = 42) -> list[str]:
    """input.txt を読み込み、シャッフルして返す。"""
    if not os.path.exists("input.txt"):
        import urllib.request
        urllib.request.urlretrieve(
            "https://raw.githubusercontent.com/karpathy/makemore/988aa59/names.txt",
            "input.txt",
        )
    docs = [line.strip() for line in open("input.txt") if line.strip()]
    random.seed(seed)
    random.shuffle(docs)
    return docs


def make_uchars(docs: list[str]) -> list[str]:
    """ユニーク文字のソート済みリストを返す。"""
    return sorted(set("".join(docs)))


def encode(doc: str, uchars: list[str], bos: int) -> list[int]:
    """文書を BOS 付きトークン列に変換する。"""
    return [bos] + [uchars.index(ch) for ch in doc] + [bos]

29.6.2. model.py(PyTorch モデル定義)

"""
PyTorch モデル定義。main.mojo から Python interop 経由で使われる。

microgpt_torch.py の nn.Module クラスを切り出した版。
トップレベルの学習ループ・推論は main.mojo 側で行う。
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    """microgpt.py の rmsnorm() に対応。スケールパラメータなし版。"""

    def __init__(self, n_embd: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ms = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(ms + self.eps)


class TransformerBlock(nn.Module):
    """1 層の Attention + MLP ブロック。"""

    def __init__(self, n_embd: int, n_head: int):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head

        self.ln1 = RMSNorm(n_embd)
        self.wq  = nn.Linear(n_embd, n_embd, bias=False)
        self.wk  = nn.Linear(n_embd, n_embd, bias=False)
        self.wv  = nn.Linear(n_embd, n_embd, bias=False)
        self.wo  = nn.Linear(n_embd, n_embd, bias=False)

        self.ln2 = RMSNorm(n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd, bias=False),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd, bias=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        H, D = self.n_head, self.head_dim

        xn = self.ln1(x)
        q = self.wq(xn).reshape(B, T, H, D).transpose(1, 2)
        k = self.wk(xn).reshape(B, T, H, D).transpose(1, 2)
        v = self.wv(xn).reshape(B, T, H, D).transpose(1, 2)

        scale = 1.0 / math.sqrt(D)
        attn = (q @ k.transpose(-2, -1)) * scale
        mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
        attn = attn.masked_fill(mask, float("-inf"))
        attn = F.softmax(attn, dim=-1)

        out = (attn @ v).transpose(1, 2).reshape(B, T, C)
        out = self.wo(out)
        x = x + out
        x = x + self.mlp(self.ln2(x))
        return x


class MicroGPT(nn.Module):
    """GPT モデル本体。main.mojo から Python interop 経由でインスタンス化される。"""

    def __init__(
        self,
        vocab_size: int,
        n_embd: int,
        n_head: int,
        n_layer: int,
        block_size: int,
    ):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.wpe = nn.Embedding(block_size, n_embd)
        self.ln0 = RMSNorm(n_embd)
        self.layers = nn.ModuleList(
            [TransformerBlock(n_embd, n_head) for _ in range(n_layer)]
        )
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        B, T = token_ids.shape
        pos_ids = torch.arange(T, device=token_ids.device)
        x = self.wte(token_ids) + self.wpe(pos_ids)
        x = self.ln0(x)
        for layer in self.layers:
            x = layer(x)
        return self.lm_head(x)

29.6.3. main.mojo(Mojo 版メインループ)

# microgpt_torch.py を Mojo で書き直した版。
#
# Python 側(dataset.py / model.py)との分担:
#   dataset.py  … データ読み込み・トークナイザー(文字列処理は Python が得意)
#   model.py    … nn.Module クラス定義(PyTorch の型システムを活かす)
#   main.mojo   … 設定・デバイス選択・学習ループ・推論(Mojo の型安全を活かす)
#
# microgpt_torch.py との主な違い:
#   GPTConfig struct   ← Python にはない型付き設定管理
#   var device: String ← 型推論なしの明示的な String 型
#   Int / Float64      ← Mojo の整数・浮動小数点型(Python の int/float とは別)
#   .__bool__()        ← PythonObject → Mojo Bool への明示的な変換

from std.python import Python, PythonObject


# ─────────────────────────────────────────────────────────────────────────────
# GPT ハイパーパラメータ(Mojo struct で型付き管理)
# microgpt_torch.py ではモジュールレベルの変数として定義していた部分
# ─────────────────────────────────────────────────────────────────────────────
struct GPTConfig:
    var vocab_size: Int
    var n_embd: Int
    var n_head: Int
    var n_layer: Int
    var block_size: Int

    def __init__(
        out self,
        vocab_size: Int,
        n_embd: Int = 16,
        n_head: Int = 4,
        n_layer: Int = 1,
        block_size: Int = 16,
    ):
        self.vocab_size = vocab_size
        self.n_embd     = n_embd
        self.n_head     = n_head
        self.n_layer    = n_layer
        self.block_size = block_size


def main() raises:
    # ── Python モジュールを import ──────────────────────────────────────────
    var sys     = Python.import_module("sys")
    var random  = Python.import_module("random")
    var torch   = Python.import_module("torch")
    var F       = Python.import_module("torch.nn.functional")
    var t_optim = Python.import_module("torch.optim")
    # dataset.py / model.py はカレントディレクトリから import
    sys.path.insert(0, ".")
    var dataset   = Python.import_module("dataset")
    var model_mod = Python.import_module("model")

    random.seed(42)
    torch.manual_seed(42)

    # ── デバイス選択(Mojo の String 変数に格納) ─────────────────────────
    # microgpt_torch.py: device = "mps" if torch.backends.mps.is_available() else "cpu"
    # Mojo では PythonObject の bool を .__bool__() で Mojo の Bool に変換する
    var device: String = "cpu"
    if torch.backends.mps.is_available().__bool__():
        device = "mps"
    print("device:", device)

    # ── データセット・トークナイザー ────────────────────────────────────────
    var docs   = dataset.load_docs(42)     # PythonObject(Python list)
    var uchars = dataset.make_uchars(docs) # PythonObject(Python list)
    # Mojo の len() は Python コレクションに対して Mojo Int を返す
    var bos: Int        = len(uchars)
    var vocab_size: Int = bos + 1
    print("num docs:", len(docs))
    print("vocab size:", vocab_size)

    # ── ハイパーパラメータ(Mojo struct) ────────────────────────────────────
    # microgpt_torch.py: n_layer, n_embd, block_size, n_head = 1, 16, 16, 4
    var cfg = GPTConfig(vocab_size=vocab_size)

    # ── モデル初期化(PyTorch nn.Module を Python interop 経由で) ──────────
    var model = model_mod.MicroGPT(
        cfg.vocab_size, cfg.n_embd, cfg.n_head, cfg.n_layer, cfg.block_size,
    )
    _ = model.to(device)

    # num_params は Python int で累積(PythonObject と Int の混算を避ける)
    var num_params = Python.evaluate("0")
    for p in model.parameters():
        num_params = num_params + p.numel()
    print("num params:", num_params)

    # ── Adam オプティマイザー ──────────────────────────────────────────────
    var learning_rate: Float64 = 0.01
    # Mojo のタプルリテラル (0.85, 0.99) は Python タプルではないため
    # Python.evaluate() で Python タプルを生成する
    var betas = Python.evaluate("(0.85, 0.99)")
    var optimizer = t_optim.Adam(
        model.parameters(),
        lr=learning_rate,
        betas=betas,
        eps=1e-8,
    )

    # ── 学習ループ(Mojo の for + 型付き変数) ────────────────────────────
    var num_steps: Int = 1000
    _ = model.train()

    for step in range(num_steps):
        # step は Mojo の Int(Python の int とは別の型)
        var doc      = docs[step % len(docs)]
        var tok_list = dataset.encode(doc, uchars, bos)
        var n: Int   = min(cfg.block_size, len(tok_list) - 1)

        var tokens = torch.tensor(
            tok_list[0 : n + 1], dtype=torch.long, device=device,
        ).unsqueeze(0)                             # [1, n+1]

        var logits  = model(tokens[:, :-1])        # [1, n, vocab]
        var targets = tokens[:, 1:]               # [1, n]
        var loss    = F.cross_entropy(
            logits.reshape(-1, cfg.vocab_size), targets.reshape(-1),
        )

        # 学習率の線形減衰
        # Float64(step) で Mojo Int → Float64 に変換してから演算
        var lr_t = learning_rate * (1.0 - Float64(step) / Float64(num_steps))
        for pg in optimizer.param_groups:
            # Python dict への書き込み: pg["lr"] = lr_t に相当
            pg.__setitem__("lr", value=lr_t)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print("step", step + 1, "/", num_steps, "| loss", loss.item())

    # ── 推論(KV キャッシュなし) ────────────────────────────────────────
    var temperature: Float64 = 0.5
    print("\n--- inference (PyTorch / Mojo) ---")
    _ = model.eval()
    # 勾配計算を無効化(torch.no_grad() コンテキストの代替)
    _ = torch.set_grad_enabled(False)

    for si in range(20):
        var generated = Python.list()
        generated.append(bos)
        var result = Python.list()

        for _ in range(cfg.block_size):
            # [[token_ids...]] の形で tensor を生成
            var nested = Python.list()
            nested.append(generated)
            var tok_in = torch.tensor(nested, dtype=torch.long, device=device)

            var logits = model(tok_in)
            # logits[0][-1] で最後のトークンの logits を取得
            # (Mojo から logits[0, -1, :] のマルチインデックススライスは
            #   直接書けないため、段階的にインデックスする)
            var next_logit = logits[0][-1] / temperature
            var probs      = F.softmax(next_logit, dim=-1)
            # token_id は PythonObject のまま扱う
            var token_id   = torch.multinomial(probs, 1).item()
            if token_id == bos:
                break
            generated.append(token_id)
            result.append(uchars[token_id])

        print("sample", si + 1, ":", Python.str("").join(result))

29.7. Python 版 vs Mojo 版:差分解説

29.7.1. 1. デバイス選択

Python 版:

device = "mps" if torch.backends.mps.is_available() else "cpu"

Mojo 版:

var device: String = "cpu"
if torch.backends.mps.is_available().__bool__():
    device = "mps"

Mojo では PythonObject(PyTorch が返す Python の bool)を Mojo の Bool に変換するために .__bool__() が必要です。Python では型変換が暗黙的に行われますが、Mojo では明示的に記述します。


29.7.2. 2〜6. その他の主な差分

項目

Python 版

Mojo 版

ポイント

ハイパーパラメータ

変数バラ持ち

struct GPTConfig

型付きで渡し忘れをコンパイル時に検出

ループ変数

Python int

Mojo Int

浮動小数との混算に Float64(step) が必要

Python タプル

(0.85, 0.99)

Python.evaluate("(0.85, 0.99)")

Mojo タプルは Python タプルと別型

dict への書き込み

pg["lr"] = lr_t

pg.__setitem__("lr", value=lr_t)

[] 代入は Mojo では使えない

マルチインデックス

logits[0, -1, :]

logits[0][-1]

段階的にインデックスして代替


29.8. 実行方法

# Python 版
uv run python src/part3/microgpt_torch.py

# Mojo 版(ディレクトリに移動してから実行)
cd src/part3/microgpt_torch_mojo
uv run mojo run main.mojo

どちらも同じ乱数シード(seed=42)を使うため、学習の loss 推移は一致します。


29.9. まとめ

比較項目

Python 版

Mojo 版

ファイル構成

単一ファイル

役割分離(dataset/model/main)

ハイパーパラメータ

モジュール変数(動的型)

struct GPTConfig(静的型)

デバイス選択

暗黙の型変換

.__bool__() で明示変換

ループ変数

Python int

Mojo Int

Python タプル

タプルリテラル

Python.evaluate() で生成

nn.Module の定義

Python 内

Python 内(interop で呼び出す)

次章では、同じ書き直しを Apple Silicon 専用フレームワーク MLX で行います。