(MAX でコードを書く)= # MAX でコードを書く ## この章で学ぶこと - MAX の Python API(`max.graph`)でグラフを定義・実行する基本パターン - `Graph`・`ops`・`Weight`・`Module` を使った具体的なコード例 - microgpt の演算(linear・rmsnorm・softmax・attention)を MAX で書く方法 --- ## インストール ```bash pip install max ``` または `uv` を使う場合: ```bash uv add max ``` --- ## 最小の Graph 例 — テンソルの加算 まず最も小さな例として、二つのテンソルを加算するグラフを作ります。 ```python import numpy as np from max.graph import Graph, TensorType from max.graph import ops from max.dtype import DType from max import engine # グラフを定義 with Graph("add_example") as g: # 入力の型・形状を宣言(実際の値はまだ渡さない) a = g.input(TensorType(DType.float32, shape=[4])) b = g.input(TensorType(DType.float32, shape=[4])) out = ops.add(a, b) g.output(out) # InferenceSession でコンパイル・実行 session = engine.InferenceSession() model = session.load(g) a_val = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32) b_val = np.array([10.0, 20.0, 30.0, 40.0], dtype=np.float32) result = model.execute(a=a_val, b=b_val)[0] print(result) # [11. 22. 33. 44.] ``` **重要なポイント:** - `Graph` ブロック内では演算を**定義するだけ**で、実際の計算は走らない - `session.load(g)` でコンパイル・最適化が行われる - `model.execute(...)` で初めて値が確定する これは Mojo の Tape 方式(演算のたびに即座に値を計算)とは逆の順序です。 --- ## linear 層を Graph で書く microgpt の `linear(x, w)` = $\mathbf{out} = W\mathbf{x}$ に相当するグラフです。 ```python from max.graph import Graph, TensorType, Weight from max.graph import ops from max.dtype import DType from max import engine import numpy as np n_in = 16 # 入力次元 n_out = 27 # 出力次元(語彙サイズ相当) # 重みを初期化(学習済みの値を想定) w_init = np.random.randn(n_out, n_in).astype(np.float32) * 0.08 with Graph("linear_example") as g: x = g.input(TensorType(DType.float32, shape=[n_in])) # Weight: 学習可能なパラメータとして宣言 w = Weight(TensorType(DType.float32, shape=[n_out, n_in]), name="w", default_value=w_init) # matmul: out = W @ x (形状 [n_out, n_in] × [n_in] → [n_out]) out = ops.matmul(w, ops.reshape(x, [n_in, 1])) out = ops.reshape(out, [n_out]) g.output(out) session = engine.InferenceSession() model = session.load(g) x_val = np.random.randn(n_in).astype(np.float32) result = model.execute(x=x_val)[0] print(result.shape) # (27,) ``` `Weight` を使うと、入力テンソル(毎回変わる値)と重み(固定またはロード済みの値)が**グラフ上で明確に区別**されます。 --- ## Module を使ったモデル定義 `Module` を使うと、複数の重みと演算をひとまとまりで管理できます。 PyTorch の `nn.Module` に近い感覚です。 ```python from max.graph import Graph, TensorType, Weight, Module from max.graph import ops from max.dtype import DType from max import engine import numpy as np class LinearLayer(Module): """1層の線形変換を表す Module""" def __init__(self, n_in: int, n_out: int, w_init: np.ndarray): self.w = Weight( TensorType(DType.float32, shape=[n_out, n_in]), name="w", default_value=w_init, ) def __call__(self, x): # x: [n_in] → out: [n_out] out = ops.matmul(self.w, ops.reshape(x, [-1, 1])) return ops.reshape(out, [-1]) n_in, n_out = 16, 27 w_init = np.random.randn(n_out, n_in).astype(np.float32) * 0.08 with Graph("module_example") as g: layer = LinearLayer(n_in, n_out, w_init) x = g.input(TensorType(DType.float32, shape=[n_in])) g.output(layer(x)) session = engine.InferenceSession() model = session.load(g) result = model.execute(x=np.random.randn(n_in).astype(np.float32))[0] print(result.shape) # (27,) ``` --- ## microgpt の主要演算を MAX で書く ### rmsnorm $$\text{out}[i] = \frac{x[i]}{\sqrt{\frac{1}{n}\sum_j x[j]^2 + \varepsilon}}$$ ```python def rmsnorm_graph(x, eps: float = 1e-5): """x: [n_embd] → out: [n_embd]""" sq = ops.mul(x, x) # x² ms = ops.mean(sq, axis=0) # mean(x²) inv_rms = ops.rsqrt(ops.add(ms, eps)) # 1/sqrt(mean(x²)+ε) return ops.mul(x, inv_rms) ``` ### softmax ```python def softmax_stable(logits): """logits: [vocab_size] → probs: [vocab_size]""" return ops.softmax(logits, axis=0) # MAX の ops.softmax は数値安定化(max引き算)を内部で行う ``` ### scaled dot-product attention(1ヘッド分) $$\text{out} = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$ ```python import math def attention_head(q_h, k_cache, v_cache, head_dim: int): """ q_h: [head_dim] k_cache: [seq_len, head_dim] v_cache: [seq_len, head_dim] """ scale = 1.0 / math.sqrt(head_dim) # QK^T / sqrt(d_k): [seq_len] scores = ops.mul(ops.matmul(k_cache, ops.reshape(q_h, [head_dim, 1])), scale) scores = ops.reshape(scores, [-1]) weights = ops.softmax(scores, axis=0) # [seq_len] # 重み付き和: [head_dim] weights_col = ops.reshape(weights, [1, -1]) # [1, seq_len] out = ops.matmul(weights_col, v_cache) # [1, head_dim] return ops.reshape(out, [head_dim]) ``` --- ## Mojo カスタム ops の組み込み方(概念) 標準の `ops` にない演算や、特定のハードウェアに最適化したい演算は Mojo で `custom op` として書いて Graph に組み込めます。 ```mojo # my_op.mojo from max.graph import CustomOp, TensorRef struct MyFusedOp(CustomOp): fn forward(self, x: TensorRef[DType.float32]) -> TensorRef[DType.float32]: # カスタムの高速実装 ... ``` ```python # Python 側でグラフに組み込む from max.graph import ops out = ops.custom("my_fused_op", inputs=[x], output_types=[x.type]) ``` MAX の `custom op` は、演算の核だけを Mojo で書き、グラフ全体の最適化(融合・スケジューリング)は MAX に任せるというハイブリッドな設計です。 --- ## まとめ | 概念 | 役割 | 対応する microgpt の概念 | |------|------|------------------------| | `Graph` | 計算の設計図 | Tape(ノードの列) | | `TensorType` | 入力の型・形状の宣言 | — | | `ops.*` | 演算ノードの追加 | `t.add`, `t.mul`, ... | | `Weight` | 学習パラメータの宣言 | `flatten_params` の各ノード | | `Module` | 層のまとまり | StateDict の各行列 | | `InferenceSession` | コンパイル・実行 | `gpt_forward` の呼び出し | 次章では、Mojo 版 microgpt で学習した重みを MAX に渡し、推論を MAX で実行する具体例を示します。 ## 次に読む章 {numref}`microgpt Mojo 版を MAX で高速化する`({ref}`microgpt Mojo 版を MAX で高速化する`)へ進みます。