Skip to content

Commit 86c74a9

Browse files
committed
[Frontend+Test] Support scatter pattern with a test case
1 parent eff63d0 commit 86c74a9

4 files changed

Lines changed: 205 additions & 50 deletions

File tree

PyTorchSimFrontend/extension_device.cpp

Lines changed: 152 additions & 43 deletions
Original file line number · Diff line number · Diff line change
@@ -424,56 +424,165 @@ void custom_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack
424424
}
425425

426426
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
427-
m.impl("add.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
428-
m.impl("add.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
429-
m.impl("abs.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
430-
m.impl("sub.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
431-
m.impl("mul.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
432-
m.impl("div.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
433-
m.impl("pow.Tensor_Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
434-
m.impl("zero_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
435-
m.impl("_foreach_add.List", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
436-
m.impl("index.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
437-
m.impl("triu_indices", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
438-
m.impl("neg.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
439-
m.impl("sum.IntList_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
440-
m.impl("eq.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
427+
m.impl("abs", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
428+
m.impl("abs.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
429+
m.impl("abs_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
430+
m.impl("absolute", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
431+
m.impl("absolute.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
432+
m.impl("absolute_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
433+
m.impl("add.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
434+
m.impl("add.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
435+
m.impl("add.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
436+
m.impl("add_.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
437+
m.impl("add_.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
438+
439+
m.impl("cat", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
440+
m.impl("cat.names", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
441+
m.impl("cat.names_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
442+
m.impl("cat.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
443+
444+
m.impl("div.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
445+
m.impl("div.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
446+
m.impl("div.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
447+
m.impl("div_.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
448+
m.impl("div_.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
449+
450+
m.impl("eq.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
451+
m.impl("eq.Scalar_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
452+
m.impl("eq.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
453+
m.impl("eq.Tensor_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
454+
m.impl("equal", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
455+
456+
m.impl("erf", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
457+
m.impl("erf.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
458+
m.impl("erf_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
459+
m.impl("erfc", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
460+
m.impl("erfc.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
461+
m.impl("erfc_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
462+
463+
m.impl("exp", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
464+
m.impl("exp.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
465+
466+
m.impl("ge.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
467+
m.impl("ge.Scalar_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
468+
m.impl("ge.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
469+
m.impl("ge.Tensor_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
470+
m.impl("gt.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
471+
m.impl("gt.Scalar_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
472+
m.impl("gt.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
473+
m.impl("gt.Tensor_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
474+
m.impl("le.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
475+
m.impl("le.Scalar_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
476+
m.impl("le.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
477+
m.impl("le.Tensor_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
478+
m.impl("lt.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
479+
m.impl("lt.Scalar_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
480+
m.impl("lt.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
481+
m.impl("lt.Tensor_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
482+
m.impl("ne.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
483+
m.impl("ne.Scalar_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
484+
m.impl("ne.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
485+
m.impl("ne.Tensor_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
486+
487+
m.impl("logical_and", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
488+
m.impl("logical_and.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
489+
m.impl("logical_and_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
490+
m.impl("logical_not", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
491+
m.impl("logical_not.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
492+
m.impl("logical_not_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
493+
m.impl("logical_or", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
494+
m.impl("logical_or.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
495+
m.impl("logical_or_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
496+
m.impl("logical_xor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
497+
m.impl("logical_xor.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
498+
m.impl("logical_xor_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
499+
500+
m.impl("neg", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
501+
m.impl("neg.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
502+
m.impl("neg_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
503+
504+
m.impl("mul.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
505+
m.impl("mul.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
506+
m.impl("mul_.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
507+
508+
m.impl("pow.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
509+
m.impl("pow.Scalar_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
510+
m.impl("pow.Tensor_Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
511+
m.impl("pow.Tensor_Scalar_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
512+
m.impl("pow.Tensor_Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
513+
m.impl("pow.Tensor_Tensor_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
514+
m.impl("pow_.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
515+
m.impl("pow_.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
516+
517+
m.impl("sub.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
518+
m.impl("sub.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
519+
m.impl("sub.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
520+
m.impl("sub_.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
521+
m.impl("sub_.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
522+
523+
m.impl("sum", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
524+
m.impl("sum.DimnameList_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
525+
m.impl("sum.IntList_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
526+
m.impl("sum.dim_DimnameList", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
527+
m.impl("sum.dim_IntList", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
528+
529+
m.impl("resize_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
530+
m.impl("resize_as_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
531+
532+
// Foreach ops
533+
m.impl("_foreach_add.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
534+
m.impl("_foreach_add_.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
535+
m.impl("_foreach_add_.ScalarList", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
536+
m.impl("_foreach_add.List", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
537+
m.impl("_foreach_add_.List", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
538+
539+
// Indexed
540+
m.impl("index_add.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
541+
m.impl("index_add_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
542+
m.impl("index_copy.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
543+
m.impl("index_copy_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
544+
m.impl("index_fill.int_Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
545+
m.impl("index_fill.int_Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
546+
m.impl("index_fill.int_Scalar_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
547+
m.impl("index_fill.int_Tensor_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
548+
m.impl("index_fill_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
549+
550+
m.impl("tril", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
551+
m.impl("tril_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
552+
m.impl("triu", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
553+
m.impl("triu_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
554+
m.impl("triu_indices", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
555+
556+
m.impl("nll_loss2d_forward", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
557+
m.impl("nll_loss2d_backward", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
558+
m.impl("nll_loss_backward", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
559+
m.impl("nll_loss_forward", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
560+
561+
m.impl("scatter.src_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
562+
m.impl("scatter.value_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
563+
564+
m.impl("index_put.Default", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
565+
m.impl("index.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
566+
567+
m.impl("mm.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
568+
m.impl("sigmoid.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
569+
m.impl("gather.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
570+
m.impl("silu.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
571+
441572
m.impl("all.all_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
442573
m.impl("_local_scalar_dense", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
443574
m.impl("_log_softmax", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
444575
m.impl("_log_softmax_backward_data", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
445576
m.impl("mse_loss.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
446-
m.impl("nll_loss_forward", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
447-
m.impl("nll_loss_backward", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
448-
m.impl("_foreach_lerp_.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
449-
m.impl("_foreach_mul_.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
450-
m.impl("_foreach_addcmul_.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
451-
m.impl("_foreach_sqrt", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
452-
m.impl("_foreach_div_.ScalarList", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
453-
m.impl("_foreach_add_.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
454-
m.impl("_foreach_addcdiv_.ScalarList", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
455-
m.impl("_foreach_add_.List", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
456-
m.impl("cat.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
457577
m.impl("_native_multi_head_attention", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
458-
m.impl("resize_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
459-
m.impl("exp.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
460578
m.impl("where.self", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
461-
m.impl("ge.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
462-
m.impl("ge.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
463-
m.impl("le.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
464-
m.impl("le.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
465-
m.impl("lt.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
466-
m.impl("lt.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
467-
m.impl("gt.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
468-
m.impl("gt.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
469-
m.impl("triu", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
470-
m.impl("tril", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
471-
m.impl("logical_and.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
472-
m.impl("logical_and.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
473-
m.impl("logical_or.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
474-
m.impl("logical_or.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
475-
m.impl("logical_not.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
476-
m.impl("logical_not.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
579+
m.impl("min", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
580+
m.impl("max", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
581+
m.impl("index_select", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
582+
m.impl("nonzero", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
583+
584+
m.impl("zero_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
585+
m.impl("zeros_like", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
477586
}
478587

479588
// This basic implementation doesn't bother dealing with different device indices

PyTorchSimFrontend/mlir/mlir_codegen_backend.py

Lines changed: 17 additions & 5 deletions
Original file line number · Diff line number · Diff line change
@@ -1157,12 +1157,24 @@ def load(self, name: str, index: sympy.Expr):
11571157
self.spad_buffer_dict[str(out)] = [sram_var, local_tile_desc.get_tile_size(), tile_numel_per_lane, sram_index_var, tile_shape, vshape]
11581158
return out
11591159

1160-
def store(self, name: str, index: sympy.Expr, value, *args, **kwargs):
1160+
def store(self, name: str, index: sympy.Expr, value, mode=None, *args, **kwargs):
11611161
index = self.rename_indexing(index)
1162-
dram_var = self.kernel_group.args.output(name)
11631162
dtype = V.graph.get_dtype(name)
11641163
mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype]
11651164

1165+
# Handle scatter store
1166+
if "tmp" in str(index):
1167+
if mode == "atomic_add":
1168+
# Convert the output buffer type to the inplace buffer
1169+
arg_name = V.graph.scheduler.mutation_real_name.get(name, name)
1170+
if arg_name not in self.kernel_group.args.inplace_buffers:
1171+
self.kernel_group.args.make_inplace(arg_name, arg_name)
1172+
1173+
loaded_value = ops.load(name, index)
1174+
value = ops.add(loaded_value, value)
1175+
index, _ = self.convert_indirect_indexing(index)
1176+
dram_var = self.kernel_group.args.output(name)
1177+
11661178
# Prepare dma instruction
11671179
local_tile_desc, index_var, dram_stride = self.get_dma_info(name, index)
11681180
vlane_split_axis = local_tile_desc.vlane_split_axis
@@ -1654,9 +1666,9 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe
16541666
total_dims = [int(str(i)[5:]) for i in self.itervars]
16551667
local_tile_desc = mlir_common.MLIRMultiDimTile([1], self.vector_lane)
16561668
local_dims.sort() # Assume that smaller index is placed in the outer loop
1657-
indirect_dims = [f"{i}" for i in index.free_symbols if "tmp" in str(i)]
1658-
for indirect_dim in indirect_dims:
1659-
index = index.replace(sympy.Symbol(indirect_dim), 0)
1669+
indirect_syms = [s for s in index.free_symbols if "tmp" in s.name]
1670+
index = index.subs({s: 0 for s in indirect_syms}, simultaneous=True)
1671+
indirect_dims = [f"{i}" for i in indirect_syms]
16601672

16611673
# Reduction can have two type of tile size
16621674
if broadcast and (total_dims != local_dims or (self.reduction_depth!=len(total_dims) and total_dims[:self.reduction_depth] == local_dims)):

0 commit comments

Comments (0)