(microgpt PyTorch 版)= # microgpt を PyTorch で書き直す ## この章で学ぶこと - PyTorch の `autograd` + `nn.Module` を使って microgpt.py を書き直す方法 - Apple Silicon の MPS(GPU)バックエンドで学習・推論を高速化する方法 - 同じロジックを **Python 版**と **Mojo 版**で書き比べ、Python interop の境界を理解する --- ## なぜ PyTorch か 前章(MAX Graph)は**推論専用**でした。学習も含めて高速化したい場合、 PyTorch の `autograd` エンジンが有力な選択肢です。 | 実装 | 学習 | 推論 | Apple Silicon | |------|------|------|---------------| | microgpt.py | Python スカラーループ | Python スカラーループ | 非対応 | | microgpt_mojo | Mojo スカラーループ | Mojo スカラーループ | 非対応 | | microgpt_mojo_max | Mojo Tape | MAX Graph | CPU のみ | | **microgpt_torch.py** | **PyTorch + MPS** | **PyTorch + MPS** | **MPS(GPU)対応** | | **microgpt_torch_mojo** | **PyTorch + MPS** | **PyTorch + MPS** | **MPS(GPU)対応** | --- ## microgpt.py との対応 | microgpt.py | PyTorch 版 | |-------------|-----------| | `class Value` | `torch.Tensor`(autograd 内蔵なので不要) | | `linear(x, w)` | `nn.Linear(in, out, bias=False)` | | `rmsnorm(x)` | カスタム `RMSNorm(nn.Module)` | | `softmax(logits)` | `F.softmax(x, dim=-1)` | | `gpt(token_id, ...)` | `MicroGPT.forward(tokens)` | | `loss.backward()` | `loss.backward()` | | Adam バッファ手書き | `torch.optim.Adam` | | 手書きの学習率減衰 | `pg["lr"] = lr_t` で同じロジックを再現 | --- ## ディレクトリ構成 ### Python 版(`microgpt_torch.py`) 単一ファイルにすべてを記述したシンプルな実装です。 ``` src/part3/ └── microgpt_torch.py ← データ・モデル・学習・推論をすべて含む ``` ### Mojo 版(`microgpt_torch_mojo/`) Python と Mojo の役割を分離した構成です。 ``` src/part3/microgpt_torch_mojo/ ├── dataset.py ← データ読み込み・トークナイザー(Python が得意な文字列処理) ├── model.py ← nn.Module クラス定義(PyTorch の型システムをそのまま活かす) └── main.mojo ← 設定・学習ループ・推論(Mojo の型安全を活かす) ``` --- ## Python 版ソースコード ```{literalinclude} ../../../src/part3/microgpt_torch.py :language: python ``` --- ## Mojo 版ソースコード ### dataset.py(データ処理) ```{literalinclude} ../../../src/part3/microgpt_torch_mojo/dataset.py :language: python ``` ### model.py(PyTorch モデル定義) ```{literalinclude} ../../../src/part3/microgpt_torch_mojo/model.py :language: python ``` ### main.mojo(Mojo 版メインループ) ```{literalinclude} ../../../src/part3/microgpt_torch_mojo/main.mojo :language: mojo ``` --- ## Python 版 vs Mojo 版:差分解説 ### 1. デバイス選択 **Python 版:** ```python device = "mps" if torch.backends.mps.is_available() else "cpu" ``` **Mojo 版:** ```mojo var device: String = "cpu" if torch.backends.mps.is_available().__bool__(): device = "mps" ``` Mojo では `PythonObject`(PyTorch が返す Python の bool)を Mojo の `Bool` に変換するために `.__bool__()` が必要です。Python では型変換が暗黙的に行われますが、Mojo では明示的に記述します。 --- ### 2〜6. その他の主な差分 | 項目 | Python 版 | Mojo 版 | ポイント | |------|-----------|--------|--------| | ハイパーパラメータ | 変数バラ持ち | `struct GPTConfig` | 型付きで渡し忘れをコンパイル時に検出 | | ループ変数 | Python `int` | Mojo `Int` | 浮動小数との混算に `Float64(step)` が必要 | | Python タプル | `(0.85, 0.99)` | `Python.evaluate("(0.85, 0.99)")` | Mojo タプルは Python タプルと別型 | | dict への書き込み | `pg["lr"] = lr_t` | `pg.__setitem__("lr", value=lr_t)` | `[]` 代入は Mojo では使えない | | マルチインデックス | `logits[0, -1, :]` | `logits[0][-1]` | 段階的にインデックスして代替 | --- ## 実行方法 ```bash # Python 版 uv run python src/part3/microgpt_torch.py # Mojo 版(ディレクトリに移動してから実行) cd src/part3/microgpt_torch_mojo uv run mojo run main.mojo ``` どちらも同じ乱数シード(`seed=42`)を使うため、学習の loss 推移は**一致**します。 --- ## まとめ | 比較項目 | Python 版 | Mojo 版 | |---------|-----------|--------| | ファイル構成 | 単一ファイル | 役割分離(dataset/model/main) | | ハイパーパラメータ | モジュール変数(動的型) | `struct GPTConfig`(静的型) | | デバイス選択 | 暗黙の型変換 | `.__bool__()` で明示変換 | | ループ変数 | Python int | Mojo `Int` | | Python タプル | タプルリテラル | `Python.evaluate()` で生成 | | `nn.Module` の定義 | Python 内 | Python 内(interop で呼び出す) | 次章では、同じ書き直しを Apple Silicon 専用フレームワーク **MLX** で行います。