@@ -176,6 +176,21 @@ def _reduction_op(self, val, torch_op, list_op):
176176 return list_op (val )
177177 return val
178178
179+ def _to_int (self , x , ctx , context_name = "operation" ):
180+ """Convert value to int, handling tensors and nested lists recursively"""
181+ if self ._is_tensor (x ):
182+ if x .numel () == 1 :
183+ return int (x .item ())
184+ else :
185+ raise ValueError (f"{ ctx .start .line } :{ ctx .start .column } : { context_name } expects scalar dimensions, got tensor with shape { x .shape } " )
186+ elif self ._is_list (x ):
187+ if len (x ) == 1 :
188+ return self ._to_int (x [0 ], ctx , context_name )
189+ else :
190+ raise ValueError (f"{ ctx .start .line } :{ ctx .start .column } : { context_name } expects scalar dimensions, got list with { len (x )} elements" )
191+ else :
192+ return int (float (x ))
193+
179194 # ========================
180195 # Visitors
181196 # ========================
@@ -996,14 +1011,7 @@ def visitReshapeFunc(self, ctx):
9961011 elif isinstance (new_shape , (list , tuple )):
9971012 result = []
9981013 for d in new_shape :
999- if self ._is_tensor (d ):
1000- # Handle tensor elements in list
1001- if d .numel () == 1 :
1002- result .append (int (d .item ()))
1003- else :
1004- raise ValueError (f"{ ctx .start .line } :{ ctx .start .column } : reshape expects scalar dimensions, got tensor with shape { d .shape } " )
1005- else :
1006- result .append (int (float (d )))
1014+ result .append (self ._to_int (d , ctx , "reshape" ))
10071015 new_shape = result
10081016 elif isinstance (new_shape , (int , float )):
10091017 new_shape = [int (float (new_shape ))]
0 commit comments