Skip to content

Commit 57127c9

Browse files
committed
Fix utils modules naming with lazy loading support
1 parent 9eb3317 commit 57127c9

File tree

12 files changed

+84
-176
lines changed

12 files changed

+84
-176
lines changed

examples/ingress/convert-kernel-bench-to-mlir.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from typing import Iterable
1616

1717
from mlir import ir, passmanager
18-
from lighthouse.ingress import torch as torch_ingress
18+
import lighthouse.ingress as lh_ingress
1919

2020
project_root = Path(__file__).parent.parent.parent
2121
torch_kernels_dir = project_root / "third_party" / "KernelBench" / "KernelBench"
@@ -173,7 +173,7 @@ def process_task(task: KernelConversionTask):
173173
print("Processing:", kernel_relative_name)
174174

175175
try:
176-
mlir_kernel = torch_ingress.import_from_file(task.torch_path, ir_context=ctx)
176+
mlir_kernel = lh_ingress.torch.import_from_file(task.torch_path, ir_context=ctx)
177177
assert isinstance(mlir_kernel, ir.Module)
178178
except Exception as e:
179179
print(

examples/llama/test_llama3.py

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,16 @@
66
import pytest
77
import torch
88

9-
109
from mlir import ir
1110
from mlir.dialects import transform, func, linalg, tensor, arith, complex, math
1211
from mlir.dialects.linalg import ElementwiseKind
1312
from mlir.dialects.transform import structured, bufferization, interpreter
1413
from mlir.passmanager import PassManager
15-
from mlir.runtime.np_to_memref import (
16-
get_ranked_memref_descriptor,
17-
)
14+
from mlir.runtime.np_to_memref import get_ranked_memref_descriptor
1815
from mlir.execution_engine import ExecutionEngine
1916

17+
from lighthouse import utils as lh_utils
18+
2019
from ref_model import (
2120
Attention,
2221
ModelArgs,
@@ -26,7 +25,6 @@
2625
TransformerBlock,
2726
Transformer,
2827
)
29-
from lighthouse.utils.runtime import ffi as ffi_utils, torch as torch_utils
3028

3129

3230
def with_mlir_ctx_and_location(func):
@@ -1021,7 +1019,7 @@ def bin_op(a, b, out):
10211019
eng = ExecutionEngine(module, opt_level=2)
10221020
func_ptr = eng.lookup("bin_op")
10231021

1024-
torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type)
1022+
torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
10251023
a = torch.randn(*shape, dtype=torch_dtype)
10261024
b = torch.randn(*shape, dtype=torch_dtype)
10271025
out_ref = references[op](a, b)
@@ -1031,7 +1029,7 @@ def bin_op(a, b, out):
10311029
a_mem = get_ranked_memref_descriptor(a.numpy())
10321030
b_mem = get_ranked_memref_descriptor(b.numpy())
10331031
out_mem = get_ranked_memref_descriptor(out.numpy())
1034-
args = ffi_utils.memrefs_to_packed_args([a_mem, b_mem, out_mem])
1032+
args = lh_utils.memref.to_packed_args([a_mem, b_mem, out_mem])
10351033
func_ptr(args)
10361034

10371035
assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True)
@@ -1077,14 +1075,14 @@ def unary_op(a, out):
10771075
eng = ExecutionEngine(module, opt_level=2)
10781076
func_ptr = eng.lookup("unary_op")
10791077

1080-
torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type)
1078+
torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
10811079
a = torch.randn(*shape, dtype=torch_dtype)
10821080
out_ref = references[op](a)
10831081
out = torch.empty_like(out_ref)
10841082

10851083
a_mem = get_ranked_memref_descriptor(a.numpy())
10861084
out_mem = get_ranked_memref_descriptor(out.numpy())
1087-
args = ffi_utils.memrefs_to_packed_args([a_mem, out_mem])
1085+
args = lh_utils.memref.to_packed_args([a_mem, out_mem])
10881086
func_ptr(args)
10891087

10901088
assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True)
@@ -1113,13 +1111,13 @@ def rms_norm(a, out):
11131111

11141112
eng = ExecutionEngine(module, opt_level=2)
11151113
func_ptr = eng.lookup("rms_norm")
1116-
torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type)
1114+
torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
11171115
a = torch.randn(*shape, dtype=torch_dtype)
11181116
out_ref = references[get_l2_norm](a, eps)
11191117
out = torch.empty_like(out_ref)
11201118
a_mem = get_ranked_memref_descriptor(a.numpy())
11211119
out_mem = get_ranked_memref_descriptor(out.numpy())
1122-
args = ffi_utils.memrefs_to_packed_args([a_mem, out_mem])
1120+
args = lh_utils.memref.to_packed_args([a_mem, out_mem])
11231121
func_ptr(args)
11241122

11251123
assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True)
@@ -1161,7 +1159,7 @@ def linear_op(x, w, b, out):
11611159

11621160
eng = ExecutionEngine(module, opt_level=2)
11631161
func_ptr = eng.lookup("linear_op")
1164-
torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type)
1162+
torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
11651163
x = torch.randn(*shape, in_features, dtype=torch_dtype)
11661164
w = torch.randn(out_features, in_features, dtype=torch_dtype)
11671165
b = torch.randn(out_features, dtype=torch_dtype)
@@ -1172,7 +1170,7 @@ def linear_op(x, w, b, out):
11721170
w_mem = get_ranked_memref_descriptor(w.numpy())
11731171
b_mem = get_ranked_memref_descriptor(b.numpy())
11741172
out_mem = get_ranked_memref_descriptor(out.numpy())
1175-
args = ffi_utils.memrefs_to_packed_args([x_mem, w_mem, b_mem, out_mem])
1173+
args = lh_utils.memref.to_packed_args([x_mem, w_mem, b_mem, out_mem])
11761174
func_ptr(args)
11771175
assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True)
11781176

@@ -1202,15 +1200,15 @@ def polar_op(magnitude, angle, out):
12021200

12031201
eng = ExecutionEngine(module, opt_level=2)
12041202
func_ptr = eng.lookup("polar_op")
1205-
torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type)
1203+
torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
12061204
magnitude = torch.randn(4, 16, dtype=torch_dtype)
12071205
angle = torch.randn(4, 16, dtype=torch_dtype)
12081206
out_ref = references[get_polar](magnitude, angle)
12091207
out = torch.empty_like(out_ref)
12101208
magnitude_mem = get_ranked_memref_descriptor(magnitude.numpy())
12111209
angle_mem = get_ranked_memref_descriptor(angle.numpy())
12121210
out_mem = get_ranked_memref_descriptor(out.numpy())
1213-
args = ffi_utils.memrefs_to_packed_args([magnitude_mem, angle_mem, out_mem])
1211+
args = lh_utils.memref.to_packed_args([magnitude_mem, angle_mem, out_mem])
12141212
func_ptr(args)
12151213
assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True)
12161214

@@ -1238,14 +1236,14 @@ def repeat_kv_op(x, out):
12381236
eng = ExecutionEngine(module, opt_level=2)
12391237
func_ptr = eng.lookup("repeat_kv_op")
12401238

1241-
torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type)
1239+
torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
12421240
x = torch.randn(2, 512, 8, 64, dtype=torch_dtype)
12431241
out_ref = references[get_repeat_kv](x, n_rep)
12441242
out = torch.empty_like(out_ref)
12451243

12461244
x_mem = get_ranked_memref_descriptor(x.numpy())
12471245
out_mem = get_ranked_memref_descriptor(out.numpy())
1248-
args = ffi_utils.memrefs_to_packed_args([x_mem, out_mem])
1246+
args = lh_utils.memref.to_packed_args([x_mem, out_mem])
12491247
func_ptr(args)
12501248

12511249
assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True)
@@ -1276,7 +1274,7 @@ def reshape_for_broadcast_op(freqs_cis, x, out):
12761274
eng = ExecutionEngine(module, opt_level=2)
12771275
func_ptr = eng.lookup("reshape_for_broadcast")
12781276

1279-
torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type)
1277+
torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
12801278
freqs_cis = torch.randn(512, 64, dtype=torch_dtype)
12811279
x = torch.randn(2, 512, 32, 128, dtype=torch_dtype)
12821280
# Convert x to complex view as expected by reshape_for_broadcast
@@ -1287,7 +1285,7 @@ def reshape_for_broadcast_op(freqs_cis, x, out):
12871285
freqs_cis_mem = get_ranked_memref_descriptor(freqs_cis.numpy())
12881286
x_mem = get_ranked_memref_descriptor(x.numpy())
12891287
out_mem = get_ranked_memref_descriptor(out.numpy())
1290-
args = ffi_utils.memrefs_to_packed_args([freqs_cis_mem, x_mem, out_mem])
1288+
args = lh_utils.memref.to_packed_args([freqs_cis_mem, x_mem, out_mem])
12911289
func_ptr(args)
12921290

12931291
assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True)
@@ -1318,15 +1316,15 @@ def view_as_complex_op(x, out):
13181316
eng = ExecutionEngine(module, opt_level=2)
13191317
func_ptr = eng.lookup("view_as_complex_op")
13201318

1321-
torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type)
1319+
torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
13221320
x = torch.randn(2, 512, 32, 128, dtype=torch_dtype)
13231321
x_reshaped = x.reshape(2, 512, 32, 64, 2)
13241322
out_ref = torch.view_as_complex(x_reshaped)
13251323
out = torch.empty_like(out_ref)
13261324

13271325
x_mem = get_ranked_memref_descriptor(x_reshaped.numpy())
13281326
out_mem = get_ranked_memref_descriptor(out.numpy())
1329-
args = ffi_utils.memrefs_to_packed_args([x_mem, out_mem])
1327+
args = lh_utils.memref.to_packed_args([x_mem, out_mem])
13301328
func_ptr(args)
13311329

13321330
assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True)
@@ -1354,15 +1352,15 @@ def as_real_op(x, out):
13541352
eng = ExecutionEngine(module, opt_level=2)
13551353
func_ptr = eng.lookup("as_real_op")
13561354

1357-
torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type)
1355+
torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
13581356
x = torch.randn(2, 512, 32, 64, 2, dtype=torch_dtype)
13591357
x_complex = torch.view_as_complex(x)
13601358
out_ref = torch.view_as_real(x_complex)
13611359
out = torch.empty_like(out_ref)
13621360

13631361
x_mem = get_ranked_memref_descriptor(x_complex.numpy())
13641362
out_mem = get_ranked_memref_descriptor(out.numpy())
1365-
args = ffi_utils.memrefs_to_packed_args([x_mem, out_mem])
1363+
args = lh_utils.memref.to_packed_args([x_mem, out_mem])
13661364
func_ptr(args)
13671365

13681366
assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True)
@@ -1395,7 +1393,7 @@ def rotary_emb(xq, xk, freqs_cis, xq_out, xk_out):
13951393
return module
13961394

13971395
ir_type = to_ir_type(elem_type)
1398-
torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type)
1396+
torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
13991397
xq_shape = (batch_size, seq_len, n_heads, head_dim)
14001398
xk_shape = (batch_size, seq_len, n_kv_heads, head_dim)
14011399
freqs_cis_shape = (seq_len, head_dim // 2)
@@ -1424,7 +1422,7 @@ def rotary_emb(xq, xk, freqs_cis, xq_out, xk_out):
14241422
freqs_cis_mem = get_ranked_memref_descriptor(freqs_cis.numpy())
14251423
out1_mem = get_ranked_memref_descriptor(out1.numpy())
14261424
out2_mem = get_ranked_memref_descriptor(out2.numpy())
1427-
args = ffi_utils.memrefs_to_packed_args(
1425+
args = lh_utils.memref.to_packed_args(
14281426
[a_mem, b_mem, freqs_cis_mem, out1_mem, out2_mem]
14291427
)
14301428
func_ptr(args)
@@ -1489,7 +1487,7 @@ def feed_forward(x, w1, b1, w2, b2, w3, b3, out):
14891487
eng = ExecutionEngine(module, opt_level=2)
14901488
func_ptr = eng.lookup("feed_forward")
14911489

1492-
torch_dtype = torch_utils.mlir_type_to_torch_dtype(ir_type)
1490+
torch_dtype = lh_utils.torch.dtype_from_mlir_type(ir_type)
14931491
x = torch.randn(4, 16, dtype=torch_dtype)
14941492
w1 = torch.randn(64, 16, dtype=torch_dtype)
14951493
b1 = torch.randn(64, dtype=torch_dtype)
@@ -1512,7 +1510,7 @@ def feed_forward(x, w1, b1, w2, b2, w3, b3, out):
15121510
w3_mem = get_ranked_memref_descriptor(w3.numpy())
15131511
b3_mem = get_ranked_memref_descriptor(b3.numpy())
15141512
out_mem = get_ranked_memref_descriptor(out.numpy())
1515-
args = ffi_utils.memrefs_to_packed_args(
1513+
args = lh_utils.memref.to_packed_args(
15161514
[x_mem, w1_mem, b1_mem, w2_mem, b2_mem, w3_mem, b3_mem, out_mem]
15171515
)
15181516
func_ptr(args)
@@ -1645,7 +1643,7 @@ def attention_op(x, wq, wk, wv, wo, freqs_cis, mask, out):
16451643
freqs_cis_mem = get_ranked_memref_descriptor(freqs_cis_real.numpy())
16461644
mask_mem = get_ranked_memref_descriptor(mask.numpy())
16471645
out_mem = get_ranked_memref_descriptor(out.numpy())
1648-
args = ffi_utils.memrefs_to_packed_args(
1646+
args = lh_utils.memref.to_packed_args(
16491647
[x_mem, wq_mem, wk_mem, wv_mem, wo_mem, freqs_cis_mem, mask_mem, out_mem]
16501648
)
16511649
func_ptr(args)
@@ -1792,7 +1790,7 @@ def transformer_block_op(
17921790
b3_mem = get_ranked_memref_descriptor(b3.numpy())
17931791
out_mem = get_ranked_memref_descriptor(out.numpy())
17941792

1795-
args = ffi_utils.memrefs_to_packed_args(
1793+
args = lh_utils.memref.to_packed_args(
17961794
[
17971795
x_mem,
17981796
wq_mem,
@@ -1981,7 +1979,7 @@ def transformer_op(*params):
19811979
out_mem = get_ranked_memref_descriptor(out.numpy())
19821980
memrefs.append(out_mem)
19831981

1984-
args = ffi_utils.memrefs_to_packed_args(memrefs)
1982+
args = lh_utils.memref.to_packed_args(memrefs)
19851983
func_ptr(args)
19861984

19871985
assert torch.allclose(out, out_ref, rtol=0.01, atol=0.01, equal_nan=True)

examples/mlir/compile_and_run.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from mlir.execution_engine import ExecutionEngine
1111
from mlir.passmanager import PassManager
1212

13-
from lighthouse.utils.runtime import torch as torch_utils
13+
import lighthouse.utils as lh_utils
1414

1515

1616
def create_kernel(ctx: ir.Context) -> ir.Module:
@@ -168,7 +168,7 @@ def main(args):
168168
out = torch.empty_like(out_ref)
169169

170170
# Execute the kernel.
171-
args = torch_utils.torch_to_packed_args([a, b, out])
171+
args = lh_utils.torch.to_packed_args([a, b, out])
172172
add_func(args)
173173

174174
### Verification ###

examples/workload/example.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,27 +5,20 @@
55
"""
66
Workload example: Element-wise sum of two (M, N) float32 arrays on CPU.
77
"""
8+
import ctypes
9+
from contextlib import contextmanager
10+
from functools import cached_property
11+
from typing import Optional
812

913
import numpy as np
1014
from mlir import ir
1115
from mlir.runtime.np_to_memref import get_ranked_memref_descriptor
1216
from mlir.dialects import func, linalg, bufferization
1317
from mlir.dialects import transform
1418
from mlir.execution_engine import ExecutionEngine
15-
from contextlib import contextmanager
16-
from functools import cached_property
17-
import ctypes
18-
from typing import Optional
19-
from lighthouse.utils.mlir import (
20-
apply_registered_pass,
21-
canonicalize,
22-
match,
23-
)
24-
from lighthouse.workload import (
25-
Workload,
26-
execute,
27-
benchmark,
28-
)
19+
20+
from lighthouse.utils.mlir import apply_registered_pass, canonicalize, match
21+
from lighthouse.workload import Workload, execute, benchmark
2922

3023

3124
class ElementwiseSum(Workload):

examples/workload/example_mlir.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
In this example, allocation and deallocation of input arrays is done in MLIR.
99
"""
1010

11+
import ctypes
12+
from contextlib import contextmanager
13+
1114
import numpy as np
1215
from mlir import ir
1316
from mlir.runtime.np_to_memref import (
@@ -17,18 +20,11 @@
1720
)
1821
from mlir.dialects import func, linalg, arith, memref
1922
from mlir.execution_engine import ExecutionEngine
20-
import ctypes
21-
from contextlib import contextmanager
22-
from lighthouse.utils.runtime.ffi import (
23-
get_packed_arg,
24-
memrefs_to_packed_args,
25-
memref_to_ctype,
26-
)
23+
24+
from lighthouse.workload import execute, benchmark
25+
import lighthouse.utils as lh_utils
26+
2727
from example import ElementwiseSum
28-
from lighthouse.workload import (
29-
execute,
30-
benchmark,
31-
)
3228

3329

3430
def emit_host_alloc(suffix: str, element_type: ir.Type, rank: int = 2):
@@ -114,16 +110,16 @@ def _allocate_array(
114110
# construct a memref descriptor for the result memref
115111
shape = (self.M, self.N)
116112
mref = make_nd_memref_descriptor(len(shape), as_ctype(self.dtype))()
117-
ptr_mref = memref_to_ctype(mref)
113+
ptr_mref = lh_utils.memref.to_ctype(mref)
118114
ptr_dims = [ctypes.pointer(ctypes.c_int32(d)) for d in shape]
119-
alloc_func(get_packed_arg([ptr_mref, *ptr_dims]))
115+
alloc_func(lh_utils.memref.get_packed_arg([ptr_mref, *ptr_dims]))
120116
self.memrefs[name] = mref
121117
return mref
122118

123119
def _deallocate_all(self, execution_engine: ExecutionEngine):
124120
for mref in self.memrefs.values():
125121
dealloc_func = execution_engine.lookup("host_dealloc_f32")
126-
dealloc_func(memrefs_to_packed_args([mref]))
122+
dealloc_func(lh_utils.memref.to_packed_args([mref]))
127123
self.memrefs = {}
128124

129125
def get_input_arrays(
@@ -136,10 +132,9 @@ def get_input_arrays(
136132
# initialize with MLIR
137133
fill_zero_func = execution_engine.lookup("host_fill_constant_zero_f32")
138134
fill_random_func = execution_engine.lookup("host_fill_random_f32")
139-
fill_zero_func(memrefs_to_packed_args([C]))
140-
fill_random_func(memrefs_to_packed_args([A]))
141-
fill_random_func(memrefs_to_packed_args([B]))
142-
135+
fill_zero_func(lh_utils.memref.to_packed_args([C]))
136+
fill_random_func(lh_utils.memref.to_packed_args([A]))
137+
fill_random_func(lh_utils.memref.to_packed_args([B]))
143138
return [A, B, C]
144139

145140
@contextmanager

0 commit comments

Comments
 (0)