Skip to content

Commit bfe0fd2

Browse files
committed
fix wierd bug with reshape (AI)
1 parent 7f4e0e5 commit bfe0fd2

1 file changed

Lines changed: 16 additions & 8 deletions

File tree

more_math/Parser/UnifiedMathVisitor.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,21 @@ def _reduction_op(self, val, torch_op, list_op):
176176
return list_op(val)
177177
return val
178178

179+
def _to_int(self, x, ctx, context_name="operation"):
180+
"""Convert value to int, handling tensors and nested lists recursively"""
181+
if self._is_tensor(x):
182+
if x.numel() == 1:
183+
return int(x.item())
184+
else:
185+
raise ValueError(f"{ctx.start.line}:{ctx.start.column}: {context_name} expects scalar dimensions, got tensor with shape {x.shape}")
186+
elif self._is_list(x):
187+
if len(x) == 1:
188+
return self._to_int(x[0], ctx, context_name)
189+
else:
190+
raise ValueError(f"{ctx.start.line}:{ctx.start.column}: {context_name} expects scalar dimensions, got list with {len(x)} elements")
191+
else:
192+
return int(float(x))
193+
179194
# ========================
180195
# Visitors
181196
# ========================
@@ -996,14 +1011,7 @@ def visitReshapeFunc(self, ctx):
9961011
elif isinstance(new_shape, (list, tuple)):
9971012
result = []
9981013
for d in new_shape:
999-
if self._is_tensor(d):
1000-
# Handle tensor elements in list
1001-
if d.numel() == 1:
1002-
result.append(int(d.item()))
1003-
else:
1004-
raise ValueError(f"{ctx.start.line}:{ctx.start.column}: reshape expects scalar dimensions, got tensor with shape {d.shape}")
1005-
else:
1006-
result.append(int(float(d)))
1014+
result.append(self._to_int(d, ctx, "reshape"))
10071015
new_shape = result
10081016
elif isinstance(new_shape, (int, float)):
10091017
new_shape = [int(float(new_shape))]

0 commit comments

Comments
 (0)