30. microgpt を MLX で書き直す

30.1. この章で学ぶこと

  • Apple Silicon 専用の MLX フレームワークで microgpt.py を書き直す方法

  • MLX の Lazy Evaluation(遅延評価)mx.eval() の役割を理解する

  • 同じロジックを Python 版Mojo 版で書き比べ、MLX 固有の interop パターンを習得する


30.2. なぜ MLX か

MLX が本書で最速(3.0 秒)を出した理由は 2 つです。

  • Unified Memory — Apple Silicon では CPU と GPU がメモリを共有し、データコピーが不要。PyTorch MPS は CPU-GPU 間コピーが発生するがMLX は不要。

  • Lazy Evaluation(遅延評価) — すべての演算がデフォルトで遅延実行され、mx.eval() で確定するまでグラフ全体を最適化できる。

MLX は Apple が開発した、Apple Silicon(M1/M2/M3)専用の機械学習フレームワークです。

特徴

PyTorch(MPS)

MLX

CPU/GPU メモリ

別々(コピーが発生)

統合(コピー不要)

API スタイル

オブジェクト指向

NumPy に近い関数型

遅延評価

一部のみ

全演算がデフォルトで遅延

対応ハードウェア

汎用 + MPS

Apple Silicon 専用


30.3. microgpt.py との対応

microgpt.py

MLX 版

class Value

mx.array(autograd 内蔵なので不要)

linear(x, w)

nn.Linear(in, out, bias=False)

rmsnorm(x)

カスタム RMSNorm(nn.Module)

softmax(logits)

mx.softmax(x, axis=-1)

gpt(token_id, ...)

MicroGPT.__call__(tokens)

loss.backward()

nn.value_and_grad() で loss と勾配を同時計算

Adam バッファ手書き

optim.Adam

手書きの学習率減衰

optimizer.learning_rate = lr_t


30.4. ディレクトリ構成

30.4.1. Python 版(microgpt_mlx.py

単一ファイルにすべてを記述したシンプルな実装です。

src/part3/
└── microgpt_mlx.py   ← データ・モデル・学習・推論をすべて含む

30.4.2. Mojo 版(microgpt_mlx_mojo/

Python と Mojo の役割を分離した構成です。

src/part3/microgpt_mlx_mojo/
├── dataset.py   ← データ読み込み・トークナイザー
├── model.py     ← MLX nn.Module・loss 関数・パラメータカウント
└── main.mojo    ← 設定・学習ループ・推論(Mojo の型安全を活かす)

30.5. Python 版ソースコード

"""
microgpt.py を MLX で書き直した版。
学習と推論を MLX の autograd + Apple Silicon Unified Memory で行う。

microgpt.py との対応:
  class Value             → mx.array(不要:autograd 内蔵)
  linear(), rmsnorm()     → nn.Linear, nn.RMSNorm
  gpt() 関数              → MicroGPT(nn.Module).__call__()
  手書き Adam             → optim.Adam
  1 文書ずつ学習          → 1 文書(バッチサイズ 1)のまま(microgpt.py と比較しやすくするため)
"""

import os
import math
import random
import numpy as np
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten

random.seed(42)
mx.random.seed(42)

# -----------------------------------------------------------------------------
# データセット(microgpt.py と同じ)
# -----------------------------------------------------------------------------
if not os.path.exists("input.txt"):
    import urllib.request
    names_url = "https://raw.githubusercontent.com/karpathy/makemore/988aa59/names.txt"
    urllib.request.urlretrieve(names_url, "input.txt")
docs = [line.strip() for line in open("input.txt") if line.strip()]
random.shuffle(docs)
print(f"num docs: {len(docs)}")

# -----------------------------------------------------------------------------
# トークナイザー(microgpt.py と同じ)
# -----------------------------------------------------------------------------
uchars = sorted(set("".join(docs)))
BOS = len(uchars)
vocab_size = len(uchars) + 1
print(f"vocab size: {vocab_size}")


def encode(doc: str) -> list[int]:
    return [BOS] + [uchars.index(ch) for ch in doc] + [BOS]


# -----------------------------------------------------------------------------
# モデル定義
# -----------------------------------------------------------------------------

class RMSNorm(nn.Module):
    """microgpt.py の rmsnorm() に対応。nn.RMSNorm はスケールなし版を自前実装。"""

    def __init__(self, n_embd: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps

    def __call__(self, x: mx.array) -> mx.array:
        # x: [..., n_embd]
        ms = mx.mean(x * x, axis=-1, keepdims=True)
        return x * mx.rsqrt(ms + self.eps)


class TransformerBlock(nn.Module):
    """1 層の Attention + MLP ブロック。microgpt.py の gpt() 内ループ 1 回分に対応。"""

    def __init__(self, n_embd: int, n_head: int):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head

        self.ln1 = RMSNorm(n_embd)
        self.wq  = nn.Linear(n_embd, n_embd, bias=False)
        self.wk  = nn.Linear(n_embd, n_embd, bias=False)
        self.wv  = nn.Linear(n_embd, n_embd, bias=False)
        self.wo  = nn.Linear(n_embd, n_embd, bias=False)

        self.ln2 = RMSNorm(n_embd)
        self.fc1 = nn.Linear(n_embd, 4 * n_embd, bias=False)
        self.fc2 = nn.Linear(4 * n_embd, n_embd, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        # x: [batch, seq, n_embd]
        B, T, C = x.shape
        H, D = self.n_head, self.head_dim

        # Attention(因果マスク付き)
        xn = self.ln1(x)
        q = self.wq(xn).reshape(B, T, H, D).transpose(0, 2, 1, 3)   # [B, H, T, D]
        k = self.wk(xn).reshape(B, T, H, D).transpose(0, 2, 1, 3)
        v = self.wv(xn).reshape(B, T, H, D).transpose(0, 2, 1, 3)

        scale = 1.0 / math.sqrt(D)
        attn = (q @ k.transpose(0, 1, 3, 2)) * scale   # [B, H, T, T]
        # 因果マスク: 未来のトークンを参照しないように
        mask = mx.triu(mx.ones((T, T)), k=1).astype(mx.bool_)
        attn = mx.where(mask, mx.array(float("-inf")), attn)
        attn = mx.softmax(attn, axis=-1)

        out = (attn @ v).transpose(0, 2, 1, 3).reshape(B, T, C)   # [B, T, C]
        out = self.wo(out)
        x = x + out   # 残差接続

        # MLP
        h = nn.relu(self.fc1(self.ln2(x)))
        x = x + self.fc2(h)   # 残差接続
        return x


class MicroGPT(nn.Module):
    """microgpt.py の gpt() 関数 + state_dict に対応する nn.Module。"""

    def __init__(self, vocab_size: int, n_embd: int, n_head: int, n_layer: int, block_size: int):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.wpe = nn.Embedding(block_size, n_embd)
        self.ln0 = RMSNorm(n_embd)
        self.layers = [TransformerBlock(n_embd, n_head) for _ in range(n_layer)]
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

    def __call__(self, token_ids: mx.array) -> mx.array:
        # token_ids: [batch, seq]  →  logits: [batch, seq, vocab]
        B, T = token_ids.shape
        pos_ids = mx.arange(T)
        x = self.wte(token_ids) + self.wpe(pos_ids)   # [B, T, n_embd]
        x = self.ln0(x)
        for layer in self.layers:
            x = layer(x)
        return self.lm_head(x)   # [B, T, vocab]


# -----------------------------------------------------------------------------
# モデル初期化
# -----------------------------------------------------------------------------
n_layer, n_embd, block_size, n_head = 1, 16, 16, 4
model = MicroGPT(vocab_size, n_embd, n_head, n_layer, block_size)
mx.eval(model.parameters())
num_params = sum(p.size for _, p in tree_flatten(model.parameters()))
print(f"num params: {num_params}")

# -----------------------------------------------------------------------------
# 学習ループ
# -----------------------------------------------------------------------------
# microgpt.py と同じハイパーパラメータ
learning_rate, beta1, beta2, eps_adam = 0.01, 0.85, 0.99, 1e-8
optimizer = optim.Adam(learning_rate=learning_rate, betas=(beta1, beta2), eps=eps_adam)


def loss_fn(model, tokens):
    """tokens: [1, seq+1] の整数配列"""
    logits = model(tokens[:, :-1])                        # [1, seq, vocab]
    targets = tokens[:, 1:]                               # [1, seq]
    # cross entropy: logits [N, vocab], targets [N]
    return mx.mean(nn.losses.cross_entropy(
        logits.reshape(-1, vocab_size),
        targets.reshape(-1),
    ))


# nn.value_and_grad: loss と勾配を同時に計算する MLX のパターン
loss_and_grad = nn.value_and_grad(model, loss_fn)

num_steps = 1000
for step in range(num_steps):
    doc = docs[step % len(docs)]
    tokens_list = encode(doc)
    n = min(block_size, len(tokens_list) - 1)

    tokens = mx.array(tokens_list[:n + 1], dtype=mx.int32)[None, :]   # [1, n+1]

    # 順伝播 + 逆伝播(MLX は lazy evaluation → mx.eval で確定)
    loss, grads = loss_and_grad(model, tokens)

    # 学習率の線形減衰(microgpt.py と同じ)
    lr_t = learning_rate * (1.0 - step / num_steps)
    optimizer.learning_rate = lr_t

    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state, loss)   # lazy eval を確定

    print(f"step {step+1:4d} / {num_steps:4d} | loss {loss.item():.4f}")

# -----------------------------------------------------------------------------
# 推論
# -----------------------------------------------------------------------------
temperature = 0.5
print("\n--- inference (MLX) ---")
for sample_idx in range(20):
    token_id = BOS
    result = []
    generated = [BOS]
    for pos_id in range(block_size):
        tokens_in = mx.array([generated], dtype=mx.int32)
        logits = model(tokens_in)                         # [1, seq, vocab]
        next_logit = logits[0, -1, :] / temperature       # 最後の位置のlogits
        probs = mx.softmax(next_logit, axis=-1)
        mx.eval(probs)
        probs_np = np.array(probs.tolist(), dtype=np.float64)
        probs_np = probs_np / probs_np.sum()   # 数値誤差を正規化
        token_id = int(np.random.choice(len(probs_np), p=probs_np))
        if token_id == BOS:
            break
        generated.append(token_id)
        result.append(uchars[token_id])
    print(f"sample {sample_idx+1:2d}: {''.join(result)}")

30.6. Mojo 版ソースコード

30.6.1. dataset.py(データ処理)

"""
データセットとトークナイザー。
main.mojo から Python interop 経由で呼ばれる。

microgpt_mlx.py のトップレベルコードを関数に切り出した版。
"""

import os
import random


def load_docs(seed: int = 42) -> list[str]:
    """input.txt を読み込み、シャッフルして返す。"""
    if not os.path.exists("input.txt"):
        import urllib.request
        urllib.request.urlretrieve(
            "https://raw.githubusercontent.com/karpathy/makemore/988aa59/names.txt",
            "input.txt",
        )
    docs = [line.strip() for line in open("input.txt") if line.strip()]
    random.seed(seed)
    random.shuffle(docs)
    return docs


def make_uchars(docs: list[str]) -> list[str]:
    """ユニーク文字のソート済みリストを返す。"""
    return sorted(set("".join(docs)))


def encode(doc: str, uchars: list[str], bos: int) -> list[int]:
    """文書を BOS 付きトークン列に変換する。"""
    return [bos] + [uchars.index(ch) for ch in doc] + [bos]

30.6.2. model.py(MLX モデル定義)

"""
MLX モデル定義。main.mojo から Python interop 経由で使われる。

microgpt_mlx.py の nn.Module クラスと loss 関数を切り出した版。
トップレベルの学習ループ・推論は main.mojo 側で行う。
"""

import math
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten


class RMSNorm(nn.Module):
    """microgpt.py の rmsnorm() に対応。スケールパラメータなし版。"""

    def __init__(self, n_embd: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps

    def __call__(self, x: mx.array) -> mx.array:
        ms = mx.mean(x * x, axis=-1, keepdims=True)
        return x * mx.rsqrt(ms + self.eps)


class TransformerBlock(nn.Module):
    """1 層の Attention + MLP ブロック。"""

    def __init__(self, n_embd: int, n_head: int):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head   = n_head
        self.head_dim = n_embd // n_head

        self.ln1 = RMSNorm(n_embd)
        self.wq  = nn.Linear(n_embd, n_embd, bias=False)
        self.wk  = nn.Linear(n_embd, n_embd, bias=False)
        self.wv  = nn.Linear(n_embd, n_embd, bias=False)
        self.wo  = nn.Linear(n_embd, n_embd, bias=False)

        self.ln2 = RMSNorm(n_embd)
        self.fc1 = nn.Linear(n_embd, 4 * n_embd, bias=False)
        self.fc2 = nn.Linear(4 * n_embd, n_embd, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        B, T, C = x.shape
        H, D = self.n_head, self.head_dim

        xn = self.ln1(x)
        q = self.wq(xn).reshape(B, T, H, D).transpose(0, 2, 1, 3)
        k = self.wk(xn).reshape(B, T, H, D).transpose(0, 2, 1, 3)
        v = self.wv(xn).reshape(B, T, H, D).transpose(0, 2, 1, 3)

        scale = 1.0 / math.sqrt(D)
        attn = (q @ k.transpose(0, 1, 3, 2)) * scale
        mask = mx.triu(mx.ones((T, T)), k=1).astype(mx.bool_)
        attn = mx.where(mask, mx.array(float("-inf")), attn)
        attn = mx.softmax(attn, axis=-1)

        out = (attn @ v).transpose(0, 2, 1, 3).reshape(B, T, C)
        out = self.wo(out)
        x = x + out
        h = nn.relu(self.fc1(self.ln2(x)))
        x = x + self.fc2(h)
        return x


class MicroGPT(nn.Module):
    """GPT モデル本体。main.mojo から Python interop 経由でインスタンス化される。"""

    def __init__(
        self,
        vocab_size: int,
        n_embd: int,
        n_head: int,
        n_layer: int,
        block_size: int,
    ):
        super().__init__()
        self.wte     = nn.Embedding(vocab_size, n_embd)
        self.wpe     = nn.Embedding(block_size, n_embd)
        self.ln0     = RMSNorm(n_embd)
        self.layers  = [TransformerBlock(n_embd, n_head) for _ in range(n_layer)]
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

    def __call__(self, token_ids: mx.array) -> mx.array:
        B, T = token_ids.shape
        pos_ids = mx.arange(T)
        x = self.wte(token_ids) + self.wpe(pos_ids)
        x = self.ln0(x)
        for layer in self.layers:
            x = layer(x)
        return self.lm_head(x)


def make_loss_fn(vocab_size: int):
    """loss 関数のクロージャを返す。main.mojo から nn.value_and_grad() に渡す。"""
    def loss_fn(model, tokens):
        logits = model(tokens[:, :-1])              # [1, seq, vocab]
        return mx.mean(nn.losses.cross_entropy(
            logits.reshape(-1, vocab_size),
            tokens[:, 1:].reshape(-1),
        ))
    return loss_fn


def count_params(model) -> int:
    """モデルの総パラメータ数を返す。"""
    return sum(p.size for _, p in tree_flatten(model.parameters()))

30.6.3. main.mojo(Mojo 版メインループ)

# microgpt_mlx.py を Mojo で書き直した版。
#
# Python 側(dataset.py / model.py)との分担:
#   dataset.py  … データ読み込み・トークナイザー(文字列処理は Python が得意)
#   model.py    … nn.Module クラス定義・loss 関数・パラメータカウント
#   main.mojo   … 設定・デバイス選択・学習ループ・推論(Mojo の型安全を活かす)
#
# microgpt_mlx.py との主な違い:
#   GPTConfig struct     ← Python にはない型付き設定管理
#   var device: String   ← 明示的な型アノテーション(MLX は CPU/GPU を自動選択)
#   Int / Float64        ← Mojo の整数・浮動小数点型
#   .__bool__()          ← PythonObject → Mojo Bool への明示的な変換
#   mx.eval() の扱い     ← Python interop 経由で Lazy Evaluation を確定

from std.python import Python, PythonObject


# ─────────────────────────────────────────────────────────────────────────────
# GPT ハイパーパラメータ(Mojo struct で型付き管理)
# ─────────────────────────────────────────────────────────────────────────────
struct GPTConfig:
    var vocab_size: Int
    var n_embd: Int
    var n_head: Int
    var n_layer: Int
    var block_size: Int

    def __init__(
        out self,
        vocab_size: Int,
        n_embd: Int = 16,
        n_head: Int = 4,
        n_layer: Int = 1,
        block_size: Int = 16,
    ):
        self.vocab_size = vocab_size
        self.n_embd     = n_embd
        self.n_head     = n_head
        self.n_layer    = n_layer
        self.block_size = block_size


def main() raises:
    # ── Python モジュールを import ──────────────────────────────────────────
    var sys    = Python.import_module("sys")
    var random = Python.import_module("random")
    var mx     = Python.import_module("mlx.core")
    var mlx_nn = Python.import_module("mlx.nn")
    var optim  = Python.import_module("mlx.optimizers")
    var np     = Python.import_module("numpy")
    sys.path.insert(0, ".")
    var dataset   = Python.import_module("dataset")
    var model_mod = Python.import_module("model")

    random.seed(42)
    mx.random.seed(42)

    # ── データセット・トークナイザー ────────────────────────────────────────
    var docs   = dataset.load_docs(42)
    var uchars = dataset.make_uchars(docs)
    var bos: Int        = len(uchars)
    var vocab_size: Int = bos + 1
    print("num docs:", len(docs))
    print("vocab size:", vocab_size)

    # ── ハイパーパラメータ(Mojo struct) ────────────────────────────────────
    var cfg = GPTConfig(vocab_size=vocab_size)

    # ── モデル初期化(MLX nn.Module を Python interop 経由で) ─────────────
    var model = model_mod.MicroGPT(
        cfg.vocab_size, cfg.n_embd, cfg.n_head, cfg.n_layer, cfg.block_size,
    )
    # MLX の Lazy Evaluation: パラメータを確定させる
    mx.eval(model.parameters())
    print("num params:", model_mod.count_params(model))

    # ── 損失関数と value_and_grad(MLX 固有のパターン) ──────────────────────
    # loss_fn は vocab_size を捕捉したクロージャ(model.py で定義)
    # microgpt_mlx.py: loss_and_grad = nn.value_and_grad(model, loss_fn)
    var loss_fn       = model_mod.make_loss_fn(cfg.vocab_size)
    var loss_and_grad = mlx_nn.value_and_grad(model, loss_fn)

    # ── Adam オプティマイザー ──────────────────────────────────────────────
    var learning_rate: Float64 = 0.01
    var betas = Python.evaluate("(0.85, 0.99)")
    var optimizer = optim.Adam(
        learning_rate=learning_rate, betas=betas, eps=1e-8,
    )

    # ── 学習ループ(Mojo の for + 型付き変数) ────────────────────────────
    var num_steps: Int = 1000

    for step in range(num_steps):
        var doc      = docs[step % len(docs)]
        var tok_list = dataset.encode(doc, uchars, bos)
        var n: Int   = min(cfg.block_size, len(tok_list) - 1)

        # tokens: [1, n+1] の MLX int32 配列
        var tokens = mx.array(tok_list[0 : n + 1], dtype=mx.int32).__getitem__(
            Python.evaluate("(None, slice(None))")
        )  # unsqueeze(0) に相当: [1, n+1]

        # 順伝播 + 逆伝播(MLX は lazy evaluation)
        var result = loss_and_grad(model, tokens)
        var loss   = result[0]
        var grads  = result[1]

        # 学習率の線形減衰
        var lr_t = learning_rate * (1.0 - Float64(step) / Float64(num_steps))
        optimizer.learning_rate = lr_t

        optimizer.update(model, grads)
        # mx.eval() で Lazy Evaluation を確定(microgpt_mlx.py と同じ)
        mx.eval(model.parameters(), optimizer.state, loss)

        print("step", step + 1, "/", num_steps, "| loss", loss.item())

    # ── 推論 ──────────────────────────────────────────────────────────────
    var temperature: Float64 = 0.5
    print("\n--- inference (MLX / Mojo) ---")

    for si in range(20):
        var generated = Python.list()
        generated.append(bos)
        var result_chars = Python.list()

        for _ in range(cfg.block_size):
            var nested = Python.list()
            nested.append(generated)
            var tok_in = mx.array(nested, dtype=mx.int32)

            var logits     = model(tok_in)
            # logits[0][-1] で最後のトークンの logits を取得
            var next_logit = logits[0][-1] / temperature
            var probs      = mx.softmax(next_logit, axis=-1)
            mx.eval(probs)

            # numpy 経由で確率的サンプリング
            var probs_np = np.array(probs.tolist(), dtype="float64")
            probs_np = probs_np / probs_np.sum()
            var token_id = np.random.choice(len(probs_np), p=probs_np)
            if token_id == bos:
                break
            generated.append(token_id)
            result_chars.append(uchars[token_id])

        print("sample", si + 1, ":", Python.str("").join(result_chars))

30.7. Python 版 vs Mojo 版:差分解説

30.7.1. 1. nn.value_and_grad の扱い

Python 版:

def loss_fn(model, tokens):
    logits = model(tokens[:, :-1])
    return mx.mean(nn.losses.cross_entropy(...))

loss_and_grad = nn.value_and_grad(model, loss_fn)

# 学習ループ内
loss, grads = loss_and_grad(model, tokens)

Mojo 版(model.py):

def make_loss_fn(vocab_size: int):
    """vocab_size を捕捉したクロージャを返す。Mojo から渡しやすくする。"""
    def loss_fn(model, tokens):
        ...
    return loss_fn

Mojo 版(main.mojo):

var loss_fn       = model_mod.make_loss_fn(cfg.vocab_size)
var loss_and_grad = mlx_nn.value_and_grad(model, loss_fn)

# 学習ループ内
var result = loss_and_grad(model, tokens)
var loss   = result[0]
var grads  = result[1]

Python 版では loss_fn を同じスコープで定義して直接渡せますが、Mojo では Python 関数を直接定義できません。make_loss_fn() でクロージャを生成して PythonObject として受け取り、mlx_nn.value_and_grad() に渡します。

Python 版の loss, grads = loss_and_grad(...) はアンパック代入ですが、Mojo ではタプルのアンパックができないため result[0], result[1] で個別に取得します。


30.7.2. 2. mx.eval() の呼び出し

Python 版:

optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state, loss)

Mojo 版:

optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state, loss)

この部分は Python 版と Mojo 版でほぼ同じです。mx.eval() は可変長引数を受け取る Python 関数で、Mojo からも同様に呼べます。MLX の Lazy Evaluation(遅延評価)を「確定」させる重要なステップです。


30.7.3. 3. tokens テンソルの生成

Python 版:

tokens = mx.array(tokens_list[:n + 1], dtype=mx.int32)[None, :]

Mojo 版:

var tokens = mx.array(tok_list[0 : n + 1], dtype=mx.int32).__getitem__(
    Python.evaluate("(None, slice(None))")
)

Python の [None, :] は「先頭に次元を追加して全要素を取る」スライスです。Mojo ではこのスライス記法が使えないため、Python.evaluate() でスライスオブジェクトを生成して __getitem__() に渡します。


30.7.4. 4. 学習率の更新

Python 版:

optimizer.learning_rate = lr_t

Mojo 版:

var lr_t = learning_rate * (1.0 - Float64(step) / Float64(num_steps))
optimizer.learning_rate = lr_t

属性への代入は Python 版と同じ構文で書けます。ただし lr_t は Mojo の Float64 であり、MLX の Python コードは Mojo の Float64 を Python の float として受け取ります。


30.7.5. 5. 乱数サンプリング

Python 版:

probs_np = np.array(probs.tolist(), dtype=np.float64)
probs_np = probs_np / probs_np.sum()
token_id = int(np.random.choice(len(probs_np), p=probs_np))

Mojo 版:

var probs_np = np.array(probs.tolist(), dtype="float64")
probs_np = probs_np / probs_np.sum()
var token_id = np.random.choice(len(probs_np), p=probs_np)
if token_id == bos:
    break

Mojo 版では token_idPythonObject のまま扱い、bos(Mojo Int)との比較も Python の __eq__ 経由で行います。Python 版の int(...) による明示的な変換は不要です。


30.8. 実行方法

# Python 版
uv run python src/part3/microgpt_mlx.py

# Mojo 版(ディレクトリに移動してから実行)
cd src/part3/microgpt_mlx_mojo
uv run mojo run main.mojo

30.9. 全実装の比較まとめ

本書で実装してきた microgpt の全バリエーションを整理します。

実装

学習エンジン

推論エンジン

Apple Silicon 対応

主な学び

microgpt.py

Python Value(スカラー)

Python Value

なし

autograd の仕組み

microgpt_mojo

Mojo Tape(スカラー)

Mojo Tape

なし

Mojo の型安全

microgpt_mojo_max

Mojo Tape

MAX Graph

CPU のみ

Mojo + Python interop

microgpt_torch.py

PyTorch autograd

PyTorch

MPS(GPU)

フレームワーク活用

microgpt_torch_mojo

PyTorch autograd

PyTorch

MPS(GPU)

Mojo から PyTorch を操る

microgpt_mlx.py

MLX autograd

MLX

Unified Memory

Lazy Evaluation

microgpt_mlx_mojo

MLX autograd

MLX

Unified Memory

Mojo から MLX を操る

Mojo 版が Python 版に対して示すもの:

  • struct GPTConfig による型付きハイパーパラメータ管理

  • var device: String などの明示的な型アノテーション

  • .__bool__(), Float64(step) による明示的な型変換

  • Python.evaluate() による Python オブジェクトの生成

  • Python 関数(nn.Moduleloss_fn)は Python 側に残し、ループ制御は Mojo で行う役割分離


30.10. まとめ

  • MLX の遅延評価パターン(loss_and_grad + mx.eval())は、Mojo から PythonObject 経由でそのまま使える

  • Python のアンパック代入(a, b = fn())は Mojo では使えず、result[0], result[1] で代替する

  • Mojo の struct を使うとハイパーパラメータを型安全に管理でき、コンパイル時にミスを検出できる

  • Python が得意な部分(文字列処理、nn.Module 定義)は Python に残し、Mojo はループ制御と型管理を担う