(microgpt MLX 版)= # microgpt を MLX で書き直す ## この章で学ぶこと - Apple Silicon 専用の **MLX** フレームワークで microgpt.py を書き直す方法 - MLX の **Lazy Evaluation(遅延評価)** と `mx.eval()` の役割を理解する - 同じロジックを **Python 版**と **Mojo 版**で書き比べ、MLX 固有の interop パターンを習得する --- ## なぜ MLX か MLX が本書で最速(3.0 秒)を出した理由は 2 つです。 - **Unified Memory** — Apple Silicon では CPU と GPU がメモリを共有し、データコピーが不要。PyTorch MPS は CPU-GPU 間コピーが発生するがMLX は不要。 - **Lazy Evaluation(遅延評価)** — すべての演算がデフォルトで遅延実行され、`mx.eval()` で確定するまでグラフ全体を最適化できる。 MLX は Apple が開発した、Apple Silicon(M1/M2/M3)専用の機械学習フレームワークです。 | 特徴 | PyTorch(MPS) | MLX | |------|--------------|-----| | CPU/GPU メモリ | 別々(コピーが発生) | 統合(コピー不要) | | API スタイル | オブジェクト指向 | NumPy に近い関数型 | | 遅延評価 | 一部のみ | 全演算がデフォルトで遅延 | | 対応ハードウェア | 汎用 + MPS | Apple Silicon 専用 | --- ## microgpt.py との対応 | microgpt.py | MLX 版 | |-------------|--------| | `class Value` | `mx.array`(autograd 内蔵なので不要) | | `linear(x, w)` | `nn.Linear(in, out, bias=False)` | | `rmsnorm(x)` | カスタム `RMSNorm(nn.Module)` | | `softmax(logits)` | `mx.softmax(x, axis=-1)` | | `gpt(token_id, ...)` | `MicroGPT.__call__(tokens)` | | `loss.backward()` | `nn.value_and_grad()` で loss と勾配を同時計算 | | Adam バッファ手書き | `optim.Adam` | | 手書きの学習率減衰 | `optimizer.learning_rate = lr_t` | --- ## ディレクトリ構成 ### Python 版(`microgpt_mlx.py`) 単一ファイルにすべてを記述したシンプルな実装です。 ``` src/part3/ └── microgpt_mlx.py ← データ・モデル・学習・推論をすべて含む ``` ### Mojo 版(`microgpt_mlx_mojo/`) Python と Mojo の役割を分離した構成です。 ``` src/part3/microgpt_mlx_mojo/ ├── dataset.py ← データ読み込み・トークナイザー ├── model.py ← MLX nn.Module・loss 関数・パラメータカウント └── main.mojo ← 設定・学習ループ・推論(Mojo の型安全を活かす) ``` --- ## Python 版ソースコード ```{literalinclude} ../../../src/part3/microgpt_mlx.py :language: python ``` --- ## Mojo 版ソースコード ### dataset.py(データ処理) ```{literalinclude} ../../../src/part3/microgpt_mlx_mojo/dataset.py :language: python ``` ### model.py(MLX モデル定義) ```{literalinclude} ../../../src/part3/microgpt_mlx_mojo/model.py :language: python ``` ### main.mojo(Mojo 版メインループ) ```{literalinclude} ../../../src/part3/microgpt_mlx_mojo/main.mojo :language: mojo ``` --- ## Python 版 vs Mojo 版:差分解説 ### 1. nn.value_and_grad の扱い **Python 版:** ```python def loss_fn(model, tokens): logits = model(tokens[:, :-1]) return mx.mean(nn.losses.cross_entropy(...)) loss_and_grad = nn.value_and_grad(model, loss_fn) # 学習ループ内 loss, grads = loss_and_grad(model, tokens) ``` **Mojo 版(model.py):** ```python def make_loss_fn(vocab_size: int): """vocab_size を捕捉したクロージャを返す。Mojo から渡しやすくする。""" def loss_fn(model, tokens): ... return loss_fn ``` **Mojo 版(main.mojo):** ```mojo var loss_fn = model_mod.make_loss_fn(cfg.vocab_size) var loss_and_grad = mlx_nn.value_and_grad(model, loss_fn) # 学習ループ内 var result = loss_and_grad(model, tokens) var loss = result[0] var grads = result[1] ``` Python 版では `loss_fn` を同じスコープで定義して直接渡せますが、Mojo では Python 関数を直接定義できません。`make_loss_fn()` でクロージャを生成して `PythonObject` として受け取り、`mlx_nn.value_and_grad()` に渡します。 Python 版の `loss, grads = loss_and_grad(...)` はアンパック代入ですが、Mojo ではタプルのアンパックができないため `result[0]`, `result[1]` で個別に取得します。 --- ### 2. mx.eval() の呼び出し **Python 版:** ```python optimizer.update(model, grads) mx.eval(model.parameters(), optimizer.state, loss) ``` **Mojo 版:** ```mojo optimizer.update(model, grads) mx.eval(model.parameters(), optimizer.state, loss) ``` この部分は Python 版と Mojo 版でほぼ同じです。`mx.eval()` は可変長引数を受け取る Python 関数で、Mojo からも同様に呼べます。MLX の Lazy Evaluation(遅延評価)を「確定」させる重要なステップです。 --- ### 3. tokens テンソルの生成 **Python 版:** ```python tokens = mx.array(tokens_list[:n + 1], dtype=mx.int32)[None, :] ``` **Mojo 版:** ```mojo var tokens = mx.array(tok_list[0 : n + 1], dtype=mx.int32).__getitem__( Python.evaluate("(None, slice(None))") ) ``` Python の `[None, :]` は「先頭に次元を追加して全要素を取る」スライスです。Mojo ではこのスライス記法が使えないため、`Python.evaluate()` でスライスオブジェクトを生成して `__getitem__()` に渡します。 --- ### 4. 学習率の更新 **Python 版:** ```python optimizer.learning_rate = lr_t ``` **Mojo 版:** ```mojo var lr_t = learning_rate * (1.0 - Float64(step) / Float64(num_steps)) optimizer.learning_rate = lr_t ``` 属性への代入は Python 版と同じ構文で書けます。ただし `lr_t` は Mojo の `Float64` であり、MLX の Python コードは Mojo の `Float64` を Python の `float` として受け取ります。 --- ### 5. 乱数サンプリング **Python 版:** ```python probs_np = np.array(probs.tolist(), dtype=np.float64) probs_np = probs_np / probs_np.sum() token_id = int(np.random.choice(len(probs_np), p=probs_np)) ``` **Mojo 版:** ```mojo var probs_np = np.array(probs.tolist(), dtype="float64") probs_np = probs_np / probs_np.sum() var token_id = np.random.choice(len(probs_np), p=probs_np) if token_id == bos: break ``` Mojo 版では `token_id` を `PythonObject` のまま扱い、`bos`(Mojo `Int`)との比較も Python の `__eq__` 経由で行います。Python 版の `int(...)` による明示的な変換は不要です。 --- ## 実行方法 ```bash # Python 版 uv run python src/part3/microgpt_mlx.py # Mojo 版(ディレクトリに移動してから実行) cd src/part3/microgpt_mlx_mojo uv run mojo run main.mojo ``` --- ## 全実装の比較まとめ 本書で実装してきた microgpt の全バリエーションを整理します。 | 実装 | 学習エンジン | 推論エンジン | Apple Silicon 対応 | 主な学び | |------|------------|------------|-------------------|---------| | microgpt.py | Python Value(スカラー) | Python Value | なし | autograd の仕組み | | microgpt_mojo | Mojo Tape(スカラー) | Mojo Tape | なし | Mojo の型安全 | | microgpt_mojo_max | Mojo Tape | MAX Graph | CPU のみ | Mojo + Python interop | | microgpt_torch.py | PyTorch autograd | PyTorch | MPS(GPU) | フレームワーク活用 | | microgpt_torch_mojo | PyTorch autograd | PyTorch | MPS(GPU) | Mojo から PyTorch を操る | | microgpt_mlx.py | MLX autograd | MLX | Unified Memory | Lazy Evaluation | | microgpt_mlx_mojo | MLX autograd | MLX | Unified Memory | Mojo から MLX を操る | **Mojo 版が Python 版に対して示すもの:** - `struct GPTConfig` による型付きハイパーパラメータ管理 - `var device: String` などの明示的な型アノテーション - `.__bool__()`, `Float64(step)` による明示的な型変換 - `Python.evaluate()` による Python オブジェクトの生成 - Python 関数(`nn.Module`、`loss_fn`)は Python 側に残し、ループ制御は Mojo で行う役割分離 --- ## まとめ - MLX の遅延評価パターン(`loss_and_grad` + `mx.eval()`)は、Mojo から `PythonObject` 経由でそのまま使える - Python のアンパック代入(`a, b = fn()`)は Mojo では使えず、`result[0]`, `result[1]` で代替する - Mojo の `struct` を使うとハイパーパラメータを型安全に管理でき、コンパイル時にミスを検出できる - Python が得意な部分(文字列処理、`nn.Module` 定義)は Python に残し、Mojo はループ制御と型管理を担う