66import pytest
77import torch
88
9-
109from mlir import ir
1110from mlir .dialects import transform , func , linalg , tensor , arith , complex , math
1211from mlir .dialects .linalg import ElementwiseKind
1312from mlir .dialects .transform import structured , bufferization , interpreter
1413from mlir .passmanager import PassManager
15- from mlir .runtime .np_to_memref import (
16- get_ranked_memref_descriptor ,
17- )
14+ from mlir .runtime .np_to_memref import get_ranked_memref_descriptor
1815from mlir .execution_engine import ExecutionEngine
1916
17+ from lighthouse import utils as lh_utils
18+
2019from ref_model import (
2120 Attention ,
2221 ModelArgs ,
2625 TransformerBlock ,
2726 Transformer ,
2827)
29- from lighthouse .utils .runtime import ffi as ffi_utils , torch as torch_utils
3028
3129
3230def with_mlir_ctx_and_location (func ):
@@ -1021,7 +1019,7 @@ def bin_op(a, b, out):
10211019 eng = ExecutionEngine (module , opt_level = 2 )
10221020 func_ptr = eng .lookup ("bin_op" )
10231021
1024- torch_dtype = torch_utils . mlir_type_to_torch_dtype (ir_type )
1022+ torch_dtype = lh_utils . torch . dtype_from_mlir_type (ir_type )
10251023 a = torch .randn (* shape , dtype = torch_dtype )
10261024 b = torch .randn (* shape , dtype = torch_dtype )
10271025 out_ref = references [op ](a , b )
@@ -1031,7 +1029,7 @@ def bin_op(a, b, out):
10311029 a_mem = get_ranked_memref_descriptor (a .numpy ())
10321030 b_mem = get_ranked_memref_descriptor (b .numpy ())
10331031 out_mem = get_ranked_memref_descriptor (out .numpy ())
1034- args = ffi_utils . memrefs_to_packed_args ([a_mem , b_mem , out_mem ])
1032+ args = lh_utils . memref . to_packed_args ([a_mem , b_mem , out_mem ])
10351033 func_ptr (args )
10361034
10371035 assert torch .allclose (out , out_ref , rtol = 0.01 , atol = 0.01 , equal_nan = True )
@@ -1077,14 +1075,14 @@ def unary_op(a, out):
10771075 eng = ExecutionEngine (module , opt_level = 2 )
10781076 func_ptr = eng .lookup ("unary_op" )
10791077
1080- torch_dtype = torch_utils . mlir_type_to_torch_dtype (ir_type )
1078+ torch_dtype = lh_utils . torch . dtype_from_mlir_type (ir_type )
10811079 a = torch .randn (* shape , dtype = torch_dtype )
10821080 out_ref = references [op ](a )
10831081 out = torch .empty_like (out_ref )
10841082
10851083 a_mem = get_ranked_memref_descriptor (a .numpy ())
10861084 out_mem = get_ranked_memref_descriptor (out .numpy ())
1087- args = ffi_utils . memrefs_to_packed_args ([a_mem , out_mem ])
1085+ args = lh_utils . memref . to_packed_args ([a_mem , out_mem ])
10881086 func_ptr (args )
10891087
10901088 assert torch .allclose (out , out_ref , rtol = 0.01 , atol = 0.01 , equal_nan = True )
@@ -1113,13 +1111,13 @@ def rms_norm(a, out):
11131111
11141112 eng = ExecutionEngine (module , opt_level = 2 )
11151113 func_ptr = eng .lookup ("rms_norm" )
1116- torch_dtype = torch_utils . mlir_type_to_torch_dtype (ir_type )
1114+ torch_dtype = lh_utils . torch . dtype_from_mlir_type (ir_type )
11171115 a = torch .randn (* shape , dtype = torch_dtype )
11181116 out_ref = references [get_l2_norm ](a , eps )
11191117 out = torch .empty_like (out_ref )
11201118 a_mem = get_ranked_memref_descriptor (a .numpy ())
11211119 out_mem = get_ranked_memref_descriptor (out .numpy ())
1122- args = ffi_utils . memrefs_to_packed_args ([a_mem , out_mem ])
1120+ args = lh_utils . memref . to_packed_args ([a_mem , out_mem ])
11231121 func_ptr (args )
11241122
11251123 assert torch .allclose (out , out_ref , rtol = 0.01 , atol = 0.01 , equal_nan = True )
@@ -1161,7 +1159,7 @@ def linear_op(x, w, b, out):
11611159
11621160 eng = ExecutionEngine (module , opt_level = 2 )
11631161 func_ptr = eng .lookup ("linear_op" )
1164- torch_dtype = torch_utils . mlir_type_to_torch_dtype (ir_type )
1162+ torch_dtype = lh_utils . torch . dtype_from_mlir_type (ir_type )
11651163 x = torch .randn (* shape , in_features , dtype = torch_dtype )
11661164 w = torch .randn (out_features , in_features , dtype = torch_dtype )
11671165 b = torch .randn (out_features , dtype = torch_dtype )
@@ -1172,7 +1170,7 @@ def linear_op(x, w, b, out):
11721170 w_mem = get_ranked_memref_descriptor (w .numpy ())
11731171 b_mem = get_ranked_memref_descriptor (b .numpy ())
11741172 out_mem = get_ranked_memref_descriptor (out .numpy ())
1175- args = ffi_utils . memrefs_to_packed_args ([x_mem , w_mem , b_mem , out_mem ])
1173+ args = lh_utils . memref . to_packed_args ([x_mem , w_mem , b_mem , out_mem ])
11761174 func_ptr (args )
11771175 assert torch .allclose (out , out_ref , rtol = 0.01 , atol = 0.01 , equal_nan = True )
11781176
@@ -1202,15 +1200,15 @@ def polar_op(magnitude, angle, out):
12021200
12031201 eng = ExecutionEngine (module , opt_level = 2 )
12041202 func_ptr = eng .lookup ("polar_op" )
1205- torch_dtype = torch_utils . mlir_type_to_torch_dtype (ir_type )
1203+ torch_dtype = lh_utils . torch . dtype_from_mlir_type (ir_type )
12061204 magnitude = torch .randn (4 , 16 , dtype = torch_dtype )
12071205 angle = torch .randn (4 , 16 , dtype = torch_dtype )
12081206 out_ref = references [get_polar ](magnitude , angle )
12091207 out = torch .empty_like (out_ref )
12101208 magnitude_mem = get_ranked_memref_descriptor (magnitude .numpy ())
12111209 angle_mem = get_ranked_memref_descriptor (angle .numpy ())
12121210 out_mem = get_ranked_memref_descriptor (out .numpy ())
1213- args = ffi_utils . memrefs_to_packed_args ([magnitude_mem , angle_mem , out_mem ])
1211+ args = lh_utils . memref . to_packed_args ([magnitude_mem , angle_mem , out_mem ])
12141212 func_ptr (args )
12151213 assert torch .allclose (out , out_ref , rtol = 0.01 , atol = 0.01 , equal_nan = True )
12161214
@@ -1238,14 +1236,14 @@ def repeat_kv_op(x, out):
12381236 eng = ExecutionEngine (module , opt_level = 2 )
12391237 func_ptr = eng .lookup ("repeat_kv_op" )
12401238
1241- torch_dtype = torch_utils . mlir_type_to_torch_dtype (ir_type )
1239+ torch_dtype = lh_utils . torch . dtype_from_mlir_type (ir_type )
12421240 x = torch .randn (2 , 512 , 8 , 64 , dtype = torch_dtype )
12431241 out_ref = references [get_repeat_kv ](x , n_rep )
12441242 out = torch .empty_like (out_ref )
12451243
12461244 x_mem = get_ranked_memref_descriptor (x .numpy ())
12471245 out_mem = get_ranked_memref_descriptor (out .numpy ())
1248- args = ffi_utils . memrefs_to_packed_args ([x_mem , out_mem ])
1246+ args = lh_utils . memref . to_packed_args ([x_mem , out_mem ])
12491247 func_ptr (args )
12501248
12511249 assert torch .allclose (out , out_ref , rtol = 0.01 , atol = 0.01 , equal_nan = True )
@@ -1276,7 +1274,7 @@ def reshape_for_broadcast_op(freqs_cis, x, out):
12761274 eng = ExecutionEngine (module , opt_level = 2 )
12771275 func_ptr = eng .lookup ("reshape_for_broadcast" )
12781276
1279- torch_dtype = torch_utils . mlir_type_to_torch_dtype (ir_type )
1277+ torch_dtype = lh_utils . torch . dtype_from_mlir_type (ir_type )
12801278 freqs_cis = torch .randn (512 , 64 , dtype = torch_dtype )
12811279 x = torch .randn (2 , 512 , 32 , 128 , dtype = torch_dtype )
12821280 # Convert x to complex view as expected by reshape_for_broadcast
@@ -1287,7 +1285,7 @@ def reshape_for_broadcast_op(freqs_cis, x, out):
12871285 freqs_cis_mem = get_ranked_memref_descriptor (freqs_cis .numpy ())
12881286 x_mem = get_ranked_memref_descriptor (x .numpy ())
12891287 out_mem = get_ranked_memref_descriptor (out .numpy ())
1290- args = ffi_utils . memrefs_to_packed_args ([freqs_cis_mem , x_mem , out_mem ])
1288+ args = lh_utils . memref . to_packed_args ([freqs_cis_mem , x_mem , out_mem ])
12911289 func_ptr (args )
12921290
12931291 assert torch .allclose (out , out_ref , rtol = 0.01 , atol = 0.01 , equal_nan = True )
@@ -1318,15 +1316,15 @@ def view_as_complex_op(x, out):
13181316 eng = ExecutionEngine (module , opt_level = 2 )
13191317 func_ptr = eng .lookup ("view_as_complex_op" )
13201318
1321- torch_dtype = torch_utils . mlir_type_to_torch_dtype (ir_type )
1319+ torch_dtype = lh_utils . torch . dtype_from_mlir_type (ir_type )
13221320 x = torch .randn (2 , 512 , 32 , 128 , dtype = torch_dtype )
13231321 x_reshaped = x .reshape (2 , 512 , 32 , 64 , 2 )
13241322 out_ref = torch .view_as_complex (x_reshaped )
13251323 out = torch .empty_like (out_ref )
13261324
13271325 x_mem = get_ranked_memref_descriptor (x_reshaped .numpy ())
13281326 out_mem = get_ranked_memref_descriptor (out .numpy ())
1329- args = ffi_utils . memrefs_to_packed_args ([x_mem , out_mem ])
1327+ args = lh_utils . memref . to_packed_args ([x_mem , out_mem ])
13301328 func_ptr (args )
13311329
13321330 assert torch .allclose (out , out_ref , rtol = 0.01 , atol = 0.01 , equal_nan = True )
@@ -1354,15 +1352,15 @@ def as_real_op(x, out):
13541352 eng = ExecutionEngine (module , opt_level = 2 )
13551353 func_ptr = eng .lookup ("as_real_op" )
13561354
1357- torch_dtype = torch_utils . mlir_type_to_torch_dtype (ir_type )
1355+ torch_dtype = lh_utils . torch . dtype_from_mlir_type (ir_type )
13581356 x = torch .randn (2 , 512 , 32 , 64 , 2 , dtype = torch_dtype )
13591357 x_complex = torch .view_as_complex (x )
13601358 out_ref = torch .view_as_real (x_complex )
13611359 out = torch .empty_like (out_ref )
13621360
13631361 x_mem = get_ranked_memref_descriptor (x_complex .numpy ())
13641362 out_mem = get_ranked_memref_descriptor (out .numpy ())
1365- args = ffi_utils . memrefs_to_packed_args ([x_mem , out_mem ])
1363+ args = lh_utils . memref . to_packed_args ([x_mem , out_mem ])
13661364 func_ptr (args )
13671365
13681366 assert torch .allclose (out , out_ref , rtol = 0.01 , atol = 0.01 , equal_nan = True )
@@ -1395,7 +1393,7 @@ def rotary_emb(xq, xk, freqs_cis, xq_out, xk_out):
13951393 return module
13961394
13971395 ir_type = to_ir_type (elem_type )
1398- torch_dtype = torch_utils . mlir_type_to_torch_dtype (ir_type )
1396+ torch_dtype = lh_utils . torch . dtype_from_mlir_type (ir_type )
13991397 xq_shape = (batch_size , seq_len , n_heads , head_dim )
14001398 xk_shape = (batch_size , seq_len , n_kv_heads , head_dim )
14011399 freqs_cis_shape = (seq_len , head_dim // 2 )
@@ -1424,7 +1422,7 @@ def rotary_emb(xq, xk, freqs_cis, xq_out, xk_out):
14241422 freqs_cis_mem = get_ranked_memref_descriptor (freqs_cis .numpy ())
14251423 out1_mem = get_ranked_memref_descriptor (out1 .numpy ())
14261424 out2_mem = get_ranked_memref_descriptor (out2 .numpy ())
1427- args = ffi_utils . memrefs_to_packed_args (
1425+ args = lh_utils . memref . to_packed_args (
14281426 [a_mem , b_mem , freqs_cis_mem , out1_mem , out2_mem ]
14291427 )
14301428 func_ptr (args )
@@ -1489,7 +1487,7 @@ def feed_forward(x, w1, b1, w2, b2, w3, b3, out):
14891487 eng = ExecutionEngine (module , opt_level = 2 )
14901488 func_ptr = eng .lookup ("feed_forward" )
14911489
1492- torch_dtype = torch_utils . mlir_type_to_torch_dtype (ir_type )
1490+ torch_dtype = lh_utils . torch . dtype_from_mlir_type (ir_type )
14931491 x = torch .randn (4 , 16 , dtype = torch_dtype )
14941492 w1 = torch .randn (64 , 16 , dtype = torch_dtype )
14951493 b1 = torch .randn (64 , dtype = torch_dtype )
@@ -1512,7 +1510,7 @@ def feed_forward(x, w1, b1, w2, b2, w3, b3, out):
15121510 w3_mem = get_ranked_memref_descriptor (w3 .numpy ())
15131511 b3_mem = get_ranked_memref_descriptor (b3 .numpy ())
15141512 out_mem = get_ranked_memref_descriptor (out .numpy ())
1515- args = ffi_utils . memrefs_to_packed_args (
1513+ args = lh_utils . memref . to_packed_args (
15161514 [x_mem , w1_mem , b1_mem , w2_mem , b2_mem , w3_mem , b3_mem , out_mem ]
15171515 )
15181516 func_ptr (args )
@@ -1645,7 +1643,7 @@ def attention_op(x, wq, wk, wv, wo, freqs_cis, mask, out):
16451643 freqs_cis_mem = get_ranked_memref_descriptor (freqs_cis_real .numpy ())
16461644 mask_mem = get_ranked_memref_descriptor (mask .numpy ())
16471645 out_mem = get_ranked_memref_descriptor (out .numpy ())
1648- args = ffi_utils . memrefs_to_packed_args (
1646+ args = lh_utils . memref . to_packed_args (
16491647 [x_mem , wq_mem , wk_mem , wv_mem , wo_mem , freqs_cis_mem , mask_mem , out_mem ]
16501648 )
16511649 func_ptr (args )
@@ -1792,7 +1790,7 @@ def transformer_block_op(
17921790 b3_mem = get_ranked_memref_descriptor (b3 .numpy ())
17931791 out_mem = get_ranked_memref_descriptor (out .numpy ())
17941792
1795- args = ffi_utils . memrefs_to_packed_args (
1793+ args = lh_utils . memref . to_packed_args (
17961794 [
17971795 x_mem ,
17981796 wq_mem ,
@@ -1981,7 +1979,7 @@ def transformer_op(*params):
19811979 out_mem = get_ranked_memref_descriptor (out .numpy ())
19821980 memrefs .append (out_mem )
19831981
1984- args = ffi_utils . memrefs_to_packed_args (memrefs )
1982+ args = lh_utils . memref . to_packed_args (memrefs )
19851983 func_ptr (args )
19861984
19871985 assert torch .allclose (out , out_ref , rtol = 0.01 , atol = 0.01 , equal_nan = True )
0 commit comments