11. メタプログラミング

11.1. この章で学ぶこと

  • コンパイル時評価の基本

  • []() の違い

  • traits / generics / constraints の考え方

  • materialization と reflection の入口

Mojo のメタプログラミングの核心は、コンパイル時に決められることを先に決め、実行時コストを消すことです。

  • comptime if で選ばれなかった枝はバイナリに残らない

  • comptime for で展開されたループは実行時ループではない

  • [] で渡した型・値は、特化されたコードを生成する材料になる

何がコンパイル時に消えるのか」を意識すると、この章の各節がつながりやすくなります。

11.2. Metaprogramming

次の例は comptime if です。branchコンパイル時にだけ意味を持ち、選ばれた枝のコードだけが生成されます。

def main():
    comptime branch = 1
    comptime if branch == 1:
        print(10)
    else:
        print(0)

リスト-1: meta_comptime_if.mojo

comptime branch = 1 は実行時変数ではなく、コンパイル時の分岐材料です。実行バイナリには print(10) 側だけが残る、と考えるとイメージしやすいです。

詳細: Metaprogramming

出典: Mojo Manual — metaprogramming

補足: GPU やデータ配置の話では、「形をコンパイル時に決める」ことが特に重要になります。

11.3. compile-time evaluation

ここでは、コードの一部をコンパイル時に実行する という考え方を見ます。

11.3.1. 何をするのか

Mojo では、comptime を使って、 実行時ではなくコンパイル時にだけ動く処理 を書けます。

よく出てくるものは次の通りです。

  • comptime ブロック

  • comptime if

  • comptime for

11.3.2. どう考えればよいか

これは、プログラムを実行する前の段階で、

  • 定数を決める

  • 型を組み立てる

  • 分岐を先に確定する

ための仕組みです。

つまり、あとで決める必要がないものは、先に決めてしまう ということです。

ここで大事なのは、何がコンパイル時に評価され、何が実行時に評価されるか を区別することです。

comptime for は、ループを コンパイル時に展開するイメージです。次の例では range(3) の各 i について print がコンパイル時に処理され、実行時には結果の機械語だけが残ります。

def main():
    comptime for i in range(3):
        print(i)

リスト-2: comptime_for_unroll.mojo

実行時の for とは別物なので、ループ上限に comptime な値を使う、などの制約がある場合があります(詳細はマニュアル)。

詳細: compile-time evaluation

出典: Mojo Manual — comptime-evaluation

補足: この節は、デバッグやエラーメッセージの理解にもつながります。

11.4. Parameters

この節で大事なのは、[]() は役割が違う ということです。

11.4.1. 基本

  • [] は compile-time parameter

  • () は runtime argument

つまり、

  • [] はコンパイル時に決まる

  • () は実行時に渡す

という違いがあります。

11.4.2. どう使い分けるか

[] には、型・値・別名などを渡せます。 これは、コードの形そのものをコンパイル時に決めたいときに使います。

一方で () は、実行中に与える普通の引数です。

この違いを意識すると、Mojo のコードはかなり読みやすくなります。

11.4.3. なぜ大事なのか

[] を使うと、コンパイラはその値に合わせて特化したコードを作れます。 そのため、抽象的に書いても実行時には速くしやすくなります。

laneswidth[] で渡すことで、呼び出しごとに別の特化が可能になります。scale() の実行時引数です。

fn lanes[width: Int](scale: Int) -> Int:
    return width * scale


def main():
    print(lanes[8](2))

リスト-3: params_width_runtime_value.mojo

lanes[8](2)「幅 8 に特化した関数」に、実行時の 2 を渡す、と読みます。

実際にコンパイルされたコードは以下のようになり、幅 8 に特化した関数 が生成されていることがわかります。

params_width_runtime_value.o:
(__TEXT,__text) section
_params_width_runtime_value::lanes[::Int](::Int),width=8:
; fn lanes[width: Int](scale: Int) -> Int:
       0:	ff 43 00 d1	sub	sp, sp, #0x10
       4:	e0 07 00 f9	str	x0, [sp, #0x8]
;     return width * scale
       8:	00 f0 7d d3	lsl	x0, x0, #3
       c:	ff 43 00 91	add	sp, sp, #0x10
      10:	c0 03 5f d6	ret
_params_width_runtime_value::main():
; def main():
      14:	ff 83 00 d1	sub	sp, sp, #0x20
      18:	fd 7b 01 a9	stp	x29, x30, [sp, #0x10]
;     print(lanes[8](2))
      1c:	48 00 80 52	mov	w8, #0x2
      20:	e0 03 08 aa	mov	x0, x8
      24:	00 00 00 94	bl	"_params_width_runtime_value::lanes[::Int](::Int),width=8"
      28:	e0 07 00 f9	str	x0, [sp, #0x8]
      2c:	01 00 00 90	adrp	x1, _static_string_a8d4ace0dc8d360e@PAGE
      30:	21 00 00 91	add	x1, x1, _static_string_a8d4ace0dc8d360e@PAGEOFF
      34:	28 00 80 52	mov	w8, #0x1
      38:	e6 03 08 aa	mov	x6, x8
      3c:	e2 03 06 aa	mov	x2, x6
      40:	03 00 00 90	adrp	x3, _static_string_bbe01a6a523daf15@PAGE
      44:	63 00 00 91	add	x3, x3, _static_string_bbe01a6a523daf15@PAGEOFF
      48:	e4 03 06 aa	mov	x4, x6
      4c:	08 00 80 52	mov	w8, #0x0
      50:	05 01 00 12	and	w5, w8, #0x1
      54:	00 00 00 94	bl	"_std::io::io::print[*::Writable](*$0,sep:::StringSlice[::Bool(False), StaticConstantOrigin, *?],end:::StringSlice[::Bool(False), StaticConstantOrigin, *?],flush:::Bool,file:::FileDescriptor$),Ts=[[typevalue<#kgen.instref<\"std::builtin::int::Int\">>, index]]"
; def main():
      58:	fd 7b 41 a9	ldp	x29, x30, [sp, #0x10]
      5c:	ff 83 00 91	add	sp, sp, #0x20
      60:	c0 03 5f d6	ret

リスト-4: params_width_runtime_value.asm

詳細: Parameters

出典: Mojo Manual — parameters

補足: generics と似て見えますが、parameters は「コンパイル時に決まる情報」を直接扱う感覚です。

11.5. Traits

trait は、型がどんな振る舞いを持つかを表す約束 です。

11.5.1. ざっくり言うと

trait は、 この型はこういう操作ができます と示すための仕組みです。

ポイントは次の通りです。

  • trait はメソッドの契約を表す

  • 実際の実装は struct 側に書く

  • 共通の振る舞いをそろえやすくなる

11.5.2. たとえば

標準では、次のような trait が出てきます。

  • Copyable

  • Movable

  • Stringable

これらは、

  • コピーできるか

  • ムーブできるか

  • 文字列として扱いやすいか

といった性質をそろえるために使われます。

Python の interface や protocol を思い出すと、入りやすいです。

Copyable を付けると、コピー用の API.copy() など)が使える、という約束になります。

@fieldwise_init
struct Label(Copyable):
    var text: String


def main():
    var a = Label("x")
    var b = a.copy()
    print(a.text, b.text)

リスト-5: traits_copyable_label.mojo

Labelフィールドの並びと trait で「どう複製できるか」が決まります。別の trait も、満たすメソッドを struct 側に実装する、というパターンです。

詳細: Traits

出典: Mojo Manual — traits

補足: 最初は「trait は振る舞いの条件」と考えるだけで十分です。

11.6. Generics

generics は、型をあとから差し替えられる形でコードを書く仕組み です。

11.6.1. 基本

たとえば T のような型パラメータを使って、 複数の型に対応できる関数や型を書けます。

さらに、T: Writable のように、 この trait を満たす型だけ使える という条件も付けられます。

11.6.2. 何がうれしいのか

  • 同じ形のコードを何度も書かなくてよい

  • 型ごとに安全な制約を付けられる

  • 使った組み合わせごとにコードが生成される

ここで大事なのは、抽象的に書いても、実行時には具体的な型に合わせたコードになる ことです。

また、数値そのものをパラメータにして、長さや幅を固定する書き方もあります。 これが、SIMD や配列サイズの話とつながります。

byte_lenT: Writable な型だけを受け取ります。StringWritable を満たすので、この呼び出しが成立します。

def byte_len[T: Writable](x: T) -> Int:
    return len(String(x))


def main():
    var s = String("abc")
    print(byte_len(s))

リスト-6: generics_writable_len.mojo

String(x)文字列化してから len しています。別の Writable 型を渡すと、別の特化が生成される、という読み方ができます。

これも実際にコンパイルされたコードは以下のようになり、特化した関数 が生成されていることがわかります。

generics_writable_len.o:
(__TEXT,__text) section
_generics_writable_len::byte_len[::Writable]($0),T=[typevalue<#kgen.instref<"std::collections::string::string::String">>, struct<(pointer<none>, index, index) memoryOnly>]:
; def byte_len[T: Writable](x: T) -> Int:
       0:	ff 43 02 d1	sub	sp, sp, #0x90
       4:	fd 7b 08 a9	stp	x29, x30, [sp, #0x80]
       8:	e8 03 00 aa	mov	x8, x0
       c:	e8 2f 00 f9	str	x8, [sp, #0x58]
      10:	03 00 00 90	adrp	x3, _static_string_2d06800538d394c2@PAGE
      14:	63 00 00 91	add	x3, x3, _static_string_2d06800538d394c2@PAGEOFF
;     return len(String(x))
      18:	e1 03 03 aa	mov	x1, x3
      1c:	04 00 80 d2	mov	x4, #0x0
      20:	e2 03 04 aa	mov	x2, x4
      24:	e5 83 00 91	add	x5, sp, #0x20
      28:	e5 0b 00 f9	str	x5, [sp, #0x10]
;     return len(String(x))
      2c:	00 00 00 94	bl	"_std::collections::string::string::String::__init__[*::Writable](*$0,sep:::StringSlice[::Bool(False), StaticConstantOrigin, *?],end:::StringSlice[::Bool(False), StaticConstantOrigin, *?]),Ts=[[typevalue<#kgen.instref<\"std::collections::string::string::String\">>, struct<(pointer<none>, index, index) memoryOnly>]]"
      30:	e0 0b 40 f9	ldr	x0, [sp, #0x10]
      34:	1f 20 03 d5	nop
      38:	00 00 00 94	bl	"_std::collections::string::string::String::__len__(::String)"
      3c:	e0 0f 00 f9	str	x0, [sp, #0x18]
;     return len(String(x))
      40:	e8 1b 40 f9	ldr	x8, [sp, #0x30]
      44:	28 05 f0 b6	tbz	x8, #0x3e, 0xe8
      48:	01 00 00 14	b	0x4c
      4c:	e8 13 40 f9	ldr	x8, [sp, #0x20]
      50:	1f 20 03 d5	nop
      54:	e8 33 00 f9	str	x8, [sp, #0x60]
      58:	08 21 00 f1	subs	x8, x8, #0x8
      5c:	e8 07 00 f9	str	x8, [sp, #0x8]
      60:	e9 03 08 aa	mov	x9, x8
      64:	e9 1f 00 f9	str	x9, [sp, #0x38]
;     return len(String(x))
      68:	e8 23 00 f9	str	x8, [sp, #0x40]
      6c:	1f 20 03 d5	nop
      70:	01 00 00 14	b	0x74
      74:	e9 07 40 f9	ldr	x9, [sp, #0x8]
      78:	e8 03 09 aa	mov	x8, x9
      7c:	e8 27 00 f9	str	x8, [sp, #0x48]
      80:	e8 03 09 aa	mov	x8, x9
      84:	e8 2b 00 f9	str	x8, [sp, #0x50]
      88:	28 00 80 52	mov	w8, #0x1
      8c:	e8 03 08 cb	neg	x8, x8
      90:		.long	0xf8680128
      94:	e8 37 00 f9	str	x8, [sp, #0x68]
      98:	e8 03 00 f9	str	x8, [sp]
      9c:	01 00 00 14	b	0xa0
      a0:	e8 03 40 f9	ldr	x8, [sp]
;     return len(String(x))
      a4:	08 05 00 f1	subs	x8, x8, #0x1
      a8:	c1 01 00 54	b.ne	0xe0
      ac:	01 00 00 14	b	0xb0
      b0:	01 00 00 14	b	0xb4
      b4:	bf 39 03 d5	dmb	ishld
      b8:	01 00 00 14	b	0xbc
      bc:	e0 07 40 f9	ldr	x0, [sp, #0x8]
;     return len(String(x))
      c0:	1f 20 03 d5	nop
      c4:	e8 03 00 aa	mov	x8, x0
      c8:	e8 3b 00 f9	str	x8, [sp, #0x70]
      cc:	1f 20 03 d5	nop
      d0:	e8 03 00 aa	mov	x8, x0
      d4:	e8 3f 00 f9	str	x8, [sp, #0x78]
      d8:	00 00 00 94	bl	_KGEN_CompilerRT_AlignedFree
;     return len(String(x))
      dc:	02 00 00 14	b	0xe4
      e0:	01 00 00 14	b	0xe4
      e4:	02 00 00 14	b	0xec
      e8:	01 00 00 14	b	0xec
      ec:	e0 0f 40 f9	ldr	x0, [sp, #0x18]
;     return len(String(x))
      f0:	fd 7b 48 a9	ldp	x29, x30, [sp, #0x80]
      f4:	ff 43 02 91	add	sp, sp, #0x90
      f8:	c0 03 5f d6	ret
_generics_writable_len::main():
; def main():
      fc:	ff 83 02 d1	sub	sp, sp, #0xa0
     100:	fd 7b 09 a9	stp	x29, x30, [sp, #0x90]
     104:	e0 83 00 91	add	x0, sp, #0x20
;     var s = String("abc")
     108:	68 00 80 52	mov	w8, #0x3
     10c:	e8 17 00 f9	str	x8, [sp, #0x28]
     110:	08 00 00 90	adrp	x8, _static_string_fd849f6ed691be56@PAGE
     114:	08 01 00 91	add	x8, x8, _static_string_fd849f6ed691be56@PAGEOFF
     118:	e9 03 08 aa	mov	x9, x8
     11c:	e9 1f 00 f9	str	x9, [sp, #0x38]
     120:	e9 03 08 aa	mov	x9, x8
     124:	e9 23 00 f9	str	x9, [sp, #0x40]
     128:	e8 13 00 f9	str	x8, [sp, #0x20]
     12c:	08 00 e4 d2	mov	x8, #0x2000000000000000
     130:	e8 1b 00 f9	str	x8, [sp, #0x30]
;     print(byte_len(s))
     134:	00 00 00 94	bl	"_generics_writable_len::byte_len[::Writable]($0),T=[typevalue<#kgen.instref<\"std::collections::string::string::String\">>, struct<(pointer<none>, index, index) memoryOnly>]"
     138:	e0 0f 00 f9	str	x0, [sp, #0x18]
     13c:	e8 1b 40 f9	ldr	x8, [sp, #0x30]
     140:	28 05 f0 b6	tbz	x8, #0x3e, 0x1e4
     144:	01 00 00 14	b	0x148
     148:	e8 13 40 f9	ldr	x8, [sp, #0x20]
     14c:	1f 20 03 d5	nop
     150:	e8 3b 00 f9	str	x8, [sp, #0x70]
     154:	08 21 00 f1	subs	x8, x8, #0x8
     158:	e8 0b 00 f9	str	x8, [sp, #0x10]
     15c:	e9 03 08 aa	mov	x9, x8
     160:	e9 27 00 f9	str	x9, [sp, #0x48]
;     print(byte_len(s))
     164:	e8 2b 00 f9	str	x8, [sp, #0x50]
     168:	1f 20 03 d5	nop
     16c:	01 00 00 14	b	0x170
     170:	e9 0b 40 f9	ldr	x9, [sp, #0x10]
     174:	e8 03 09 aa	mov	x8, x9
     178:	e8 2f 00 f9	str	x8, [sp, #0x58]
     17c:	e8 03 09 aa	mov	x8, x9
     180:	e8 33 00 f9	str	x8, [sp, #0x60]
     184:	28 00 80 52	mov	w8, #0x1
     188:	e8 03 08 cb	neg	x8, x8
     18c:		.long	0xf8680128
     190:	e8 3f 00 f9	str	x8, [sp, #0x78]
     194:	e8 07 00 f9	str	x8, [sp, #0x8]
     198:	01 00 00 14	b	0x19c
     19c:	e8 07 40 f9	ldr	x8, [sp, #0x8]
;     print(byte_len(s))
     1a0:	08 05 00 f1	subs	x8, x8, #0x1
     1a4:	c1 01 00 54	b.ne	0x1dc
     1a8:	01 00 00 14	b	0x1ac
     1ac:	01 00 00 14	b	0x1b0
     1b0:	bf 39 03 d5	dmb	ishld
     1b4:	01 00 00 14	b	0x1b8
     1b8:	e0 0b 40 f9	ldr	x0, [sp, #0x10]
;     print(byte_len(s))
     1bc:	1f 20 03 d5	nop
     1c0:	e8 03 00 aa	mov	x8, x0
     1c4:	e8 43 00 f9	str	x8, [sp, #0x80]
     1c8:	1f 20 03 d5	nop
     1cc:	e8 03 00 aa	mov	x8, x0
     1d0:	e8 47 00 f9	str	x8, [sp, #0x88]
     1d4:	00 00 00 94	bl	_KGEN_CompilerRT_AlignedFree
;     print(byte_len(s))
     1d8:	02 00 00 14	b	0x1e0
     1dc:	01 00 00 14	b	0x1e0
     1e0:	02 00 00 14	b	0x1e8
     1e4:	01 00 00 14	b	0x1e8
     1e8:	e0 0f 40 f9	ldr	x0, [sp, #0x18]
;     print(byte_len(s))
     1ec:	e0 37 00 f9	str	x0, [sp, #0x68]
     1f0:	01 00 00 90	adrp	x1, _static_string_a8d4ace0dc8d360e@PAGE
     1f4:	21 00 00 91	add	x1, x1, _static_string_a8d4ace0dc8d360e@PAGEOFF
     1f8:	28 00 80 52	mov	w8, #0x1
     1fc:	e6 03 08 aa	mov	x6, x8
     200:	e2 03 06 aa	mov	x2, x6
     204:	03 00 00 90	adrp	x3, _static_string_bbe01a6a523daf15@PAGE
     208:	63 00 00 91	add	x3, x3, _static_string_bbe01a6a523daf15@PAGEOFF
     20c:	e4 03 06 aa	mov	x4, x6
     210:	08 00 80 52	mov	w8, #0x0
     214:	05 01 00 12	and	w5, w8, #0x1
     218:	00 00 00 94	bl	"_std::io::io::print[*::Writable](*$0,sep:::StringSlice[::Bool(False), StaticConstantOrigin, *?],end:::StringSlice[::Bool(False), StaticConstantOrigin, *?],flush:::Bool,file:::FileDescriptor$),Ts=[[typevalue<#kgen.instref<\"std::builtin::int::Int\">>, index]]"
; def main():
     21c:	fd 7b 49 a9	ldp	x29, x30, [sp, #0x90]
     220:	ff 83 02 91	add	sp, sp, #0xa0
     224:	c0 03 5f d6	ret

リスト-7: generics_writable_len.asm

詳細: Generics

出典: Mojo Manual — generics

補足: generics は parameters とセットで読むと理解しやすいです。

11.7. Constraints

constraints は、使ってよい型や条件をさらに細かくしぼる仕組み です。

11.7.1. 何をするのか

  • where 句で追加条件を書く

  • comptime の条件で分岐する

  • comptime_assert でコンパイル時に失敗させる

11.7.2. なぜ必要か

抽象的なコードは便利ですが、何でも受け入れると誤用が起きやすくなります。 そこで constraints を使って、 このコードはこういう条件のときだけ使える と明示します。

すると、間違いを実行前に見つけやすくなります。

Buf[size: Int where size > 0]size が正のときだけ型が成立します。Buf[0] のような誤用は コンパイルエラーにできます。

struct Buf[size: Int where size > 0]:
    var data: Int

    def __init__(out self):
        self.data = Self.size


def main():
    var b = Buf[4]()
    print(b.data)

リスト-8: constraints_buf_positive.mojo

Self.sizeその struct のパラメータ size を指します。where の条件は、ジェネリクスや最適化の前提をコンパイラに伝える役割もあります。

詳細: Constraints

出典: Mojo Manual — constraints

補足: ここは少し高度です。必要になった時点で深掘りすれば十分です。

11.8. Materialization

この節は少しわかりにくいですが、 コンパイル時の情報を、実行時に使える値へ落とし込む仕組み だと捉えるとよいでしょう。

11.8.1. イメージ

たとえば、リテラルや comptime の値を、 実際に使う値として形にする場面があります。 そのときに materialization の考え方が出てきます。

11.8.2. まず押さえること

  • コンパイル時の情報を実行時の値へつなぐ

  • materialize() などの仕組みがある

  • 数値型や SIMD の幅の決定とも関わる

ここは最初から完璧に理解しなくても大丈夫です。 まずは、コンパイル時の情報をそのまま終わらせず、実際の値として使う場面がある とわかれば十分です。

リテラルや Int の値を、別の数値型の演算に載せるときは、明示的な変換Float64(n) など)で materialize する、とマニュアルでも説明されます。

def main():
    var n: Int = 21
    var x = Float64(n) / 2.0
    print(x)

リスト-9: materialize_float_from_int.mojo

n は実行時の整数でも、Float64(n)浮動小数の式に載せ替えられます。より高度な materialize() は、型や SIMD の話と合わせてマニュアルを参照してください。

詳細: Materialization

出典: Mojo Manual — materialization

補足: 型やリテラルの話と一緒に読むと、つながりが見えやすくなります。

11.9. Reflection

reflection は、型の情報を調べる仕組み です。

11.9.1. Mojo の reflection の特徴

Mojo では、reflection は主に コンパイル時限定 です。

たとえば次のような情報を調べられます。

  • 型名

  • フィールド情報

  • 型に関する構造

11.9.2. Python との違い

Python では、実行時にかなり自由に型情報を調べられます。 一方で Mojo では、実行時リフレクションは基本的に強く使わない 設計です。

これは、性能や予測しやすさを保つためです。

つまり、Mojo の reflection は、 何でも実行中に調べるためのものではなく、コンパイル時の補助として使うもの と理解するとよいでしょう。

struct_field_countコンパイル時にフィールド数を求めます。comptime n = ... の結果を 実行時の戻り値として返す、という橋渡しもできます。

from std.reflection import struct_field_count

struct Point:
    var x: Int
    var y: Int


def field_count[T: AnyType]() -> Int:
    comptime n = struct_field_count[T]()
    return n


def main():
    print(field_count[Point]())

リスト-10: reflection_field_count.mojo

Point のフィールドは xy2 つなので、2 が表示されます。フィールド名の列挙なども std.reflection にあります(マニュアル参照)。

詳細: Reflection

出典: Mojo Manual — reflection

補足: Python の reflection と同じ感覚で読むとずれやすいので注意です。

11.10. この章を一文で言うと

Mojo のメタプログラミングは、コンパイル時に型や条件を先に決めて、安全で速いコードを作るための仕組みです。

11.11. まとめ

  • Mojo ではコンパイル時に多くのことを決められる

  • [] はコンパイル時、() は実行時の情報を渡す

  • trait は振る舞いの約束を表す

  • generics は型を抽象化しつつ、安全に再利用する仕組み

  • constraints は誤用を早く防ぐための条件づけ

  • materialization はコンパイル時の情報を実際の値につなぐ

  • reflection は主にコンパイル時に型情報を調べるために使う

この章は少し抽象的ですが、 「先に決められることは先に決める」 という一本の考え方で読むと、かなり整理しやすくなります。