30. microgpt を MLX で書き直す
30.1. この章で学ぶこと
Apple Silicon 専用の MLX フレームワークで microgpt.py を書き直す方法
MLX の Lazy Evaluation(遅延評価) と
mx.eval()の役割を理解する同じロジックを Python 版と Mojo 版で書き比べ、MLX 固有の interop パターンを習得する
30.2. なぜ MLX か
MLX が本書で最速(3.0 秒)を出した理由は 2 つです。
Unified Memory — Apple Silicon では CPU と GPU がメモリを共有し、データコピーが不要。PyTorch MPS は CPU-GPU 間コピーが発生するがMLX は不要。
Lazy Evaluation(遅延評価) — すべての演算がデフォルトで遅延実行され、
mx.eval()で確定するまでグラフ全体を最適化できる。
MLX は Apple が開発した、Apple Silicon(M1/M2/M3)専用の機械学習フレームワークです。
特徴 |
PyTorch(MPS) |
MLX |
|---|---|---|
CPU/GPU メモリ |
別々(コピーが発生) |
統合(コピー不要) |
API スタイル |
オブジェクト指向 |
NumPy に近い関数型 |
遅延評価 |
一部のみ |
全演算がデフォルトで遅延 |
対応ハードウェア |
汎用 + MPS |
Apple Silicon 専用 |
30.3. microgpt.py との対応
microgpt.py |
MLX 版 |
|---|---|
|
|
|
|
|
カスタム |
|
|
|
|
|
|
Adam バッファ手書き |
|
手書きの学習率減衰 |
|
30.4. ディレクトリ構成
30.4.1. Python 版(microgpt_mlx.py)
単一ファイルにすべてを記述したシンプルな実装です。
src/part3/
└── microgpt_mlx.py ← データ・モデル・学習・推論をすべて含む
30.4.2. Mojo 版(microgpt_mlx_mojo/)
Python と Mojo の役割を分離した構成です。
src/part3/microgpt_mlx_mojo/
├── dataset.py ← データ読み込み・トークナイザー
├── model.py ← MLX nn.Module・loss 関数・パラメータカウント
└── main.mojo ← 設定・学習ループ・推論(Mojo の型安全を活かす)
30.5. Python 版ソースコード
"""
microgpt.py を MLX で書き直した版。
学習と推論を MLX の autograd + Apple Silicon Unified Memory で行う。
microgpt.py との対応:
class Value → mx.array(不要:autograd 内蔵)
linear(), rmsnorm() → nn.Linear, nn.RMSNorm
gpt() 関数 → MicroGPT(nn.Module).__call__()
手書き Adam → optim.Adam
1 文書ずつ学習 → 1 文書(バッチサイズ 1)のまま(microgpt.py と比較しやすくするため)
"""
import os
import math
import random
import numpy as np
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten
random.seed(42)
mx.random.seed(42)
# -----------------------------------------------------------------------------
# データセット(microgpt.py と同じ)
# -----------------------------------------------------------------------------
if not os.path.exists("input.txt"):
import urllib.request
names_url = "https://raw.githubusercontent.com/karpathy/makemore/988aa59/names.txt"
urllib.request.urlretrieve(names_url, "input.txt")
docs = [line.strip() for line in open("input.txt") if line.strip()]
random.shuffle(docs)
print(f"num docs: {len(docs)}")
# -----------------------------------------------------------------------------
# トークナイザー(microgpt.py と同じ)
# -----------------------------------------------------------------------------
uchars = sorted(set("".join(docs)))
BOS = len(uchars)
vocab_size = len(uchars) + 1
print(f"vocab size: {vocab_size}")
def encode(doc: str) -> list[int]:
return [BOS] + [uchars.index(ch) for ch in doc] + [BOS]
# -----------------------------------------------------------------------------
# モデル定義
# -----------------------------------------------------------------------------
class RMSNorm(nn.Module):
"""microgpt.py の rmsnorm() に対応。nn.RMSNorm はスケールなし版を自前実装。"""
def __init__(self, n_embd: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
def __call__(self, x: mx.array) -> mx.array:
# x: [..., n_embd]
ms = mx.mean(x * x, axis=-1, keepdims=True)
return x * mx.rsqrt(ms + self.eps)
class TransformerBlock(nn.Module):
"""1 層の Attention + MLP ブロック。microgpt.py の gpt() 内ループ 1 回分に対応。"""
def __init__(self, n_embd: int, n_head: int):
super().__init__()
assert n_embd % n_head == 0
self.n_head = n_head
self.head_dim = n_embd // n_head
self.ln1 = RMSNorm(n_embd)
self.wq = nn.Linear(n_embd, n_embd, bias=False)
self.wk = nn.Linear(n_embd, n_embd, bias=False)
self.wv = nn.Linear(n_embd, n_embd, bias=False)
self.wo = nn.Linear(n_embd, n_embd, bias=False)
self.ln2 = RMSNorm(n_embd)
self.fc1 = nn.Linear(n_embd, 4 * n_embd, bias=False)
self.fc2 = nn.Linear(4 * n_embd, n_embd, bias=False)
def __call__(self, x: mx.array) -> mx.array:
# x: [batch, seq, n_embd]
B, T, C = x.shape
H, D = self.n_head, self.head_dim
# Attention(因果マスク付き)
xn = self.ln1(x)
q = self.wq(xn).reshape(B, T, H, D).transpose(0, 2, 1, 3) # [B, H, T, D]
k = self.wk(xn).reshape(B, T, H, D).transpose(0, 2, 1, 3)
v = self.wv(xn).reshape(B, T, H, D).transpose(0, 2, 1, 3)
scale = 1.0 / math.sqrt(D)
attn = (q @ k.transpose(0, 1, 3, 2)) * scale # [B, H, T, T]
# 因果マスク: 未来のトークンを参照しないように
mask = mx.triu(mx.ones((T, T)), k=1).astype(mx.bool_)
attn = mx.where(mask, mx.array(float("-inf")), attn)
attn = mx.softmax(attn, axis=-1)
out = (attn @ v).transpose(0, 2, 1, 3).reshape(B, T, C) # [B, T, C]
out = self.wo(out)
x = x + out # 残差接続
# MLP
h = nn.relu(self.fc1(self.ln2(x)))
x = x + self.fc2(h) # 残差接続
return x
class MicroGPT(nn.Module):
"""microgpt.py の gpt() 関数 + state_dict に対応する nn.Module。"""
def __init__(self, vocab_size: int, n_embd: int, n_head: int, n_layer: int, block_size: int):
super().__init__()
self.wte = nn.Embedding(vocab_size, n_embd)
self.wpe = nn.Embedding(block_size, n_embd)
self.ln0 = RMSNorm(n_embd)
self.layers = [TransformerBlock(n_embd, n_head) for _ in range(n_layer)]
self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
def __call__(self, token_ids: mx.array) -> mx.array:
# token_ids: [batch, seq] → logits: [batch, seq, vocab]
B, T = token_ids.shape
pos_ids = mx.arange(T)
x = self.wte(token_ids) + self.wpe(pos_ids) # [B, T, n_embd]
x = self.ln0(x)
for layer in self.layers:
x = layer(x)
return self.lm_head(x) # [B, T, vocab]
# -----------------------------------------------------------------------------
# モデル初期化
# -----------------------------------------------------------------------------
n_layer, n_embd, block_size, n_head = 1, 16, 16, 4
model = MicroGPT(vocab_size, n_embd, n_head, n_layer, block_size)
mx.eval(model.parameters())
num_params = sum(p.size for _, p in tree_flatten(model.parameters()))
print(f"num params: {num_params}")
# -----------------------------------------------------------------------------
# 学習ループ
# -----------------------------------------------------------------------------
# microgpt.py と同じハイパーパラメータ
learning_rate, beta1, beta2, eps_adam = 0.01, 0.85, 0.99, 1e-8
optimizer = optim.Adam(learning_rate=learning_rate, betas=(beta1, beta2), eps=eps_adam)
def loss_fn(model, tokens):
"""tokens: [1, seq+1] の整数配列"""
logits = model(tokens[:, :-1]) # [1, seq, vocab]
targets = tokens[:, 1:] # [1, seq]
# cross entropy: logits [N, vocab], targets [N]
return mx.mean(nn.losses.cross_entropy(
logits.reshape(-1, vocab_size),
targets.reshape(-1),
))
# nn.value_and_grad: loss と勾配を同時に計算する MLX のパターン
loss_and_grad = nn.value_and_grad(model, loss_fn)
num_steps = 1000
for step in range(num_steps):
doc = docs[step % len(docs)]
tokens_list = encode(doc)
n = min(block_size, len(tokens_list) - 1)
tokens = mx.array(tokens_list[:n + 1], dtype=mx.int32)[None, :] # [1, n+1]
# 順伝播 + 逆伝播(MLX は lazy evaluation → mx.eval で確定)
loss, grads = loss_and_grad(model, tokens)
# 学習率の線形減衰(microgpt.py と同じ)
lr_t = learning_rate * (1.0 - step / num_steps)
optimizer.learning_rate = lr_t
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state, loss) # lazy eval を確定
print(f"step {step+1:4d} / {num_steps:4d} | loss {loss.item():.4f}")
# -----------------------------------------------------------------------------
# 推論
# -----------------------------------------------------------------------------
temperature = 0.5
print("\n--- inference (MLX) ---")
for sample_idx in range(20):
token_id = BOS
result = []
generated = [BOS]
for pos_id in range(block_size):
tokens_in = mx.array([generated], dtype=mx.int32)
logits = model(tokens_in) # [1, seq, vocab]
next_logit = logits[0, -1, :] / temperature # 最後の位置のlogits
probs = mx.softmax(next_logit, axis=-1)
mx.eval(probs)
probs_np = np.array(probs.tolist(), dtype=np.float64)
probs_np = probs_np / probs_np.sum() # 数値誤差を正規化
token_id = int(np.random.choice(len(probs_np), p=probs_np))
if token_id == BOS:
break
generated.append(token_id)
result.append(uchars[token_id])
print(f"sample {sample_idx+1:2d}: {''.join(result)}")
30.6. Mojo 版ソースコード
30.6.1. dataset.py(データ処理)
"""
データセットとトークナイザー。
main.mojo から Python interop 経由で呼ばれる。
microgpt_mlx.py のトップレベルコードを関数に切り出した版。
"""
import os
import random
def load_docs(seed: int = 42) -> list[str]:
"""input.txt を読み込み、シャッフルして返す。"""
if not os.path.exists("input.txt"):
import urllib.request
urllib.request.urlretrieve(
"https://raw.githubusercontent.com/karpathy/makemore/988aa59/names.txt",
"input.txt",
)
docs = [line.strip() for line in open("input.txt") if line.strip()]
random.seed(seed)
random.shuffle(docs)
return docs
def make_uchars(docs: list[str]) -> list[str]:
"""ユニーク文字のソート済みリストを返す。"""
return sorted(set("".join(docs)))
def encode(doc: str, uchars: list[str], bos: int) -> list[int]:
"""文書を BOS 付きトークン列に変換する。"""
return [bos] + [uchars.index(ch) for ch in doc] + [bos]
30.6.2. model.py(MLX モデル定義)
"""
MLX モデル定義。main.mojo から Python interop 経由で使われる。
microgpt_mlx.py の nn.Module クラスと loss 関数を切り出した版。
トップレベルの学習ループ・推論は main.mojo 側で行う。
"""
import math
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten
class RMSNorm(nn.Module):
"""microgpt.py の rmsnorm() に対応。スケールパラメータなし版。"""
def __init__(self, n_embd: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
def __call__(self, x: mx.array) -> mx.array:
ms = mx.mean(x * x, axis=-1, keepdims=True)
return x * mx.rsqrt(ms + self.eps)
class TransformerBlock(nn.Module):
"""1 層の Attention + MLP ブロック。"""
def __init__(self, n_embd: int, n_head: int):
super().__init__()
assert n_embd % n_head == 0
self.n_head = n_head
self.head_dim = n_embd // n_head
self.ln1 = RMSNorm(n_embd)
self.wq = nn.Linear(n_embd, n_embd, bias=False)
self.wk = nn.Linear(n_embd, n_embd, bias=False)
self.wv = nn.Linear(n_embd, n_embd, bias=False)
self.wo = nn.Linear(n_embd, n_embd, bias=False)
self.ln2 = RMSNorm(n_embd)
self.fc1 = nn.Linear(n_embd, 4 * n_embd, bias=False)
self.fc2 = nn.Linear(4 * n_embd, n_embd, bias=False)
def __call__(self, x: mx.array) -> mx.array:
B, T, C = x.shape
H, D = self.n_head, self.head_dim
xn = self.ln1(x)
q = self.wq(xn).reshape(B, T, H, D).transpose(0, 2, 1, 3)
k = self.wk(xn).reshape(B, T, H, D).transpose(0, 2, 1, 3)
v = self.wv(xn).reshape(B, T, H, D).transpose(0, 2, 1, 3)
scale = 1.0 / math.sqrt(D)
attn = (q @ k.transpose(0, 1, 3, 2)) * scale
mask = mx.triu(mx.ones((T, T)), k=1).astype(mx.bool_)
attn = mx.where(mask, mx.array(float("-inf")), attn)
attn = mx.softmax(attn, axis=-1)
out = (attn @ v).transpose(0, 2, 1, 3).reshape(B, T, C)
out = self.wo(out)
x = x + out
h = nn.relu(self.fc1(self.ln2(x)))
x = x + self.fc2(h)
return x
class MicroGPT(nn.Module):
"""GPT モデル本体。main.mojo から Python interop 経由でインスタンス化される。"""
def __init__(
self,
vocab_size: int,
n_embd: int,
n_head: int,
n_layer: int,
block_size: int,
):
super().__init__()
self.wte = nn.Embedding(vocab_size, n_embd)
self.wpe = nn.Embedding(block_size, n_embd)
self.ln0 = RMSNorm(n_embd)
self.layers = [TransformerBlock(n_embd, n_head) for _ in range(n_layer)]
self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
def __call__(self, token_ids: mx.array) -> mx.array:
B, T = token_ids.shape
pos_ids = mx.arange(T)
x = self.wte(token_ids) + self.wpe(pos_ids)
x = self.ln0(x)
for layer in self.layers:
x = layer(x)
return self.lm_head(x)
def make_loss_fn(vocab_size: int):
"""loss 関数のクロージャを返す。main.mojo から nn.value_and_grad() に渡す。"""
def loss_fn(model, tokens):
logits = model(tokens[:, :-1]) # [1, seq, vocab]
return mx.mean(nn.losses.cross_entropy(
logits.reshape(-1, vocab_size),
tokens[:, 1:].reshape(-1),
))
return loss_fn
def count_params(model) -> int:
"""モデルの総パラメータ数を返す。"""
return sum(p.size for _, p in tree_flatten(model.parameters()))
30.6.3. main.mojo(Mojo 版メインループ)
# microgpt_mlx.py を Mojo で書き直した版。
#
# Python 側(dataset.py / model.py)との分担:
# dataset.py … データ読み込み・トークナイザー(文字列処理は Python が得意)
# model.py … nn.Module クラス定義・loss 関数・パラメータカウント
# main.mojo … 設定・デバイス選択・学習ループ・推論(Mojo の型安全を活かす)
#
# microgpt_mlx.py との主な違い:
# GPTConfig struct ← Python にはない型付き設定管理
# var device: String ← 明示的な型アノテーション(MLX は CPU/GPU を自動選択)
# Int / Float64 ← Mojo の整数・浮動小数点型
# .__bool__() ← PythonObject → Mojo Bool への明示的な変換
# mx.eval() の扱い ← Python interop 経由で Lazy Evaluation を確定
from std.python import Python, PythonObject
# ─────────────────────────────────────────────────────────────────────────────
# GPT ハイパーパラメータ(Mojo struct で型付き管理)
# ─────────────────────────────────────────────────────────────────────────────
struct GPTConfig:
var vocab_size: Int
var n_embd: Int
var n_head: Int
var n_layer: Int
var block_size: Int
def __init__(
out self,
vocab_size: Int,
n_embd: Int = 16,
n_head: Int = 4,
n_layer: Int = 1,
block_size: Int = 16,
):
self.vocab_size = vocab_size
self.n_embd = n_embd
self.n_head = n_head
self.n_layer = n_layer
self.block_size = block_size
def main() raises:
# ── Python モジュールを import ──────────────────────────────────────────
var sys = Python.import_module("sys")
var random = Python.import_module("random")
var mx = Python.import_module("mlx.core")
var mlx_nn = Python.import_module("mlx.nn")
var optim = Python.import_module("mlx.optimizers")
var np = Python.import_module("numpy")
sys.path.insert(0, ".")
var dataset = Python.import_module("dataset")
var model_mod = Python.import_module("model")
random.seed(42)
mx.random.seed(42)
# ── データセット・トークナイザー ────────────────────────────────────────
var docs = dataset.load_docs(42)
var uchars = dataset.make_uchars(docs)
var bos: Int = len(uchars)
var vocab_size: Int = bos + 1
print("num docs:", len(docs))
print("vocab size:", vocab_size)
# ── ハイパーパラメータ(Mojo struct) ────────────────────────────────────
var cfg = GPTConfig(vocab_size=vocab_size)
# ── モデル初期化(MLX nn.Module を Python interop 経由で) ─────────────
var model = model_mod.MicroGPT(
cfg.vocab_size, cfg.n_embd, cfg.n_head, cfg.n_layer, cfg.block_size,
)
# MLX の Lazy Evaluation: パラメータを確定させる
mx.eval(model.parameters())
print("num params:", model_mod.count_params(model))
# ── 損失関数と value_and_grad(MLX 固有のパターン) ──────────────────────
# loss_fn は vocab_size を捕捉したクロージャ(model.py で定義)
# microgpt_mlx.py: loss_and_grad = nn.value_and_grad(model, loss_fn)
var loss_fn = model_mod.make_loss_fn(cfg.vocab_size)
var loss_and_grad = mlx_nn.value_and_grad(model, loss_fn)
# ── Adam オプティマイザー ──────────────────────────────────────────────
var learning_rate: Float64 = 0.01
var betas = Python.evaluate("(0.85, 0.99)")
var optimizer = optim.Adam(
learning_rate=learning_rate, betas=betas, eps=1e-8,
)
# ── 学習ループ(Mojo の for + 型付き変数) ────────────────────────────
var num_steps: Int = 1000
for step in range(num_steps):
var doc = docs[step % len(docs)]
var tok_list = dataset.encode(doc, uchars, bos)
var n: Int = min(cfg.block_size, len(tok_list) - 1)
# tokens: [1, n+1] の MLX int32 配列
var tokens = mx.array(tok_list[0 : n + 1], dtype=mx.int32).__getitem__(
Python.evaluate("(None, slice(None))")
) # unsqueeze(0) に相当: [1, n+1]
# 順伝播 + 逆伝播(MLX は lazy evaluation)
var result = loss_and_grad(model, tokens)
var loss = result[0]
var grads = result[1]
# 学習率の線形減衰
var lr_t = learning_rate * (1.0 - Float64(step) / Float64(num_steps))
optimizer.learning_rate = lr_t
optimizer.update(model, grads)
# mx.eval() で Lazy Evaluation を確定(microgpt_mlx.py と同じ)
mx.eval(model.parameters(), optimizer.state, loss)
print("step", step + 1, "/", num_steps, "| loss", loss.item())
# ── 推論 ──────────────────────────────────────────────────────────────
var temperature: Float64 = 0.5
print("\n--- inference (MLX / Mojo) ---")
for si in range(20):
var generated = Python.list()
generated.append(bos)
var result_chars = Python.list()
for _ in range(cfg.block_size):
var nested = Python.list()
nested.append(generated)
var tok_in = mx.array(nested, dtype=mx.int32)
var logits = model(tok_in)
# logits[0][-1] で最後のトークンの logits を取得
var next_logit = logits[0][-1] / temperature
var probs = mx.softmax(next_logit, axis=-1)
mx.eval(probs)
# numpy 経由で確率的サンプリング
var probs_np = np.array(probs.tolist(), dtype="float64")
probs_np = probs_np / probs_np.sum()
var token_id = np.random.choice(len(probs_np), p=probs_np)
if token_id == bos:
break
generated.append(token_id)
result_chars.append(uchars[token_id])
print("sample", si + 1, ":", Python.str("").join(result_chars))
30.7. Python 版 vs Mojo 版:差分解説
30.7.1. 1. nn.value_and_grad の扱い
Python 版:
def loss_fn(model, tokens):
logits = model(tokens[:, :-1])
return mx.mean(nn.losses.cross_entropy(...))
loss_and_grad = nn.value_and_grad(model, loss_fn)
# 学習ループ内
loss, grads = loss_and_grad(model, tokens)
Mojo 版(model.py):
def make_loss_fn(vocab_size: int):
"""vocab_size を捕捉したクロージャを返す。Mojo から渡しやすくする。"""
def loss_fn(model, tokens):
...
return loss_fn
Mojo 版(main.mojo):
var loss_fn = model_mod.make_loss_fn(cfg.vocab_size)
var loss_and_grad = mlx_nn.value_and_grad(model, loss_fn)
# 学習ループ内
var result = loss_and_grad(model, tokens)
var loss = result[0]
var grads = result[1]
Python 版では loss_fn を同じスコープで定義して直接渡せますが、Mojo では Python 関数を直接定義できません。make_loss_fn() でクロージャを生成して PythonObject として受け取り、mlx_nn.value_and_grad() に渡します。
Python 版の loss, grads = loss_and_grad(...) はアンパック代入ですが、Mojo ではタプルのアンパックができないため result[0], result[1] で個別に取得します。
30.7.2. 2. mx.eval() の呼び出し
Python 版:
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state, loss)
Mojo 版:
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state, loss)
この部分は Python 版と Mojo 版でほぼ同じです。mx.eval() は可変長引数を受け取る Python 関数で、Mojo からも同様に呼べます。MLX の Lazy Evaluation(遅延評価)を「確定」させる重要なステップです。
30.7.3. 3. tokens テンソルの生成
Python 版:
tokens = mx.array(tokens_list[:n + 1], dtype=mx.int32)[None, :]
Mojo 版:
var tokens = mx.array(tok_list[0 : n + 1], dtype=mx.int32).__getitem__(
Python.evaluate("(None, slice(None))")
)
Python の [None, :] は「先頭に次元を追加して全要素を取る」スライスです。Mojo ではこのスライス記法が使えないため、Python.evaluate() でスライスオブジェクトを生成して __getitem__() に渡します。
30.7.4. 4. 学習率の更新
Python 版:
optimizer.learning_rate = lr_t
Mojo 版:
var lr_t = learning_rate * (1.0 - Float64(step) / Float64(num_steps))
optimizer.learning_rate = lr_t
属性への代入は Python 版と同じ構文で書けます。ただし lr_t は Mojo の Float64 であり、MLX の Python コードは Mojo の Float64 を Python の float として受け取ります。
30.7.5. 5. 乱数サンプリング
Python 版:
probs_np = np.array(probs.tolist(), dtype=np.float64)
probs_np = probs_np / probs_np.sum()
token_id = int(np.random.choice(len(probs_np), p=probs_np))
Mojo 版:
var probs_np = np.array(probs.tolist(), dtype="float64")
probs_np = probs_np / probs_np.sum()
var token_id = np.random.choice(len(probs_np), p=probs_np)
if token_id == bos:
break
Mojo 版では token_id を PythonObject のまま扱い、bos(Mojo Int)との比較も Python の __eq__ 経由で行います。Python 版の int(...) による明示的な変換は不要です。
30.8. 実行方法
# Python 版
uv run python src/part3/microgpt_mlx.py
# Mojo 版(ディレクトリに移動してから実行)
cd src/part3/microgpt_mlx_mojo
uv run mojo run main.mojo
30.9. 全実装の比較まとめ
本書で実装してきた microgpt の全バリエーションを整理します。
実装 |
学習エンジン |
推論エンジン |
Apple Silicon 対応 |
主な学び |
|---|---|---|---|---|
microgpt.py |
Python Value(スカラー) |
Python Value |
なし |
autograd の仕組み |
microgpt_mojo |
Mojo Tape(スカラー) |
Mojo Tape |
なし |
Mojo の型安全 |
microgpt_mojo_max |
Mojo Tape |
MAX Graph |
CPU のみ |
Mojo + Python interop |
microgpt_torch.py |
PyTorch autograd |
PyTorch |
MPS(GPU) |
フレームワーク活用 |
microgpt_torch_mojo |
PyTorch autograd |
PyTorch |
MPS(GPU) |
Mojo から PyTorch を操る |
microgpt_mlx.py |
MLX autograd |
MLX |
Unified Memory |
Lazy Evaluation |
microgpt_mlx_mojo |
MLX autograd |
MLX |
Unified Memory |
Mojo から MLX を操る |
Mojo 版が Python 版に対して示すもの:
struct GPTConfigによる型付きハイパーパラメータ管理var device: Stringなどの明示的な型アノテーション.__bool__(),Float64(step)による明示的な型変換Python.evaluate()による Python オブジェクトの生成Python 関数(
nn.Module、loss_fn)は Python 側に残し、ループ制御は Mojo で行う役割分離
30.10. まとめ
MLX の遅延評価パターン(
loss_and_grad+mx.eval())は、Mojo からPythonObject経由でそのまま使えるPython のアンパック代入(
a, b = fn())は Mojo では使えず、result[0],result[1]で代替するMojo の
structを使うとハイパーパラメータを型安全に管理でき、コンパイル時にミスを検出できるPython が得意な部分(文字列処理、
nn.Module定義)は Python に残し、Mojo はループ制御と型管理を担う