Repro:
# Minimal reproduction: with cache="symbolic values", the printed trace's
# input annotations show the concrete values seen on the first call, and a
# later call with different inputs prints the very same annotations.
import thunder
import torch
import math


def fn(x, y):
    return x + y


tfn = thunder.jit(fn, cache="symbolic values")

tfn(1, 2)
print(thunder.last_traces(tfn)[-1])
# Observed output:
# def computation(x, y):
#   # x: "int 1"
#   # y: "int 2"
#   # /opt/pytorch/lightning-thunder/test.py:6: return x + y
#   i3 = operator.add(x, y)  # i3: "int 3"
#     # i3 = prims.add(x, y)  # i3: "int 3"
#   return (i3,)

# Call again with a different value for y. The printed trace is unchanged:
# it still annotates the inputs with the first call's values ("int 1",
# "int 2") and the result as "int 3", which no longer matches these inputs.
tfn(1, 3)
print(thunder.last_traces(tfn)[-1])
# Observed output (identical to the first call's trace):
# def computation(x, y):
#   # x: "int 1"
#   # y: "int 2"
#   # /opt/pytorch/lightning-thunder/test.py:6: return x + y
#   i3 = operator.add(x, y)  # i3: "int 3"
#     # i3 = prims.add(x, y)  # i3: "int 3"
#   return (i3,)
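The trace bakes in the values from the first call (`# x: "int 1"`, `# y: "int 2"`). The second call, `tfn(1, 3)`, prints exactly the same annotations, which is confusing because they no longer match the actual inputs. Since `cache="symbolic values"` is requested, it would be clearer to print only the type, e.g. `# x: "int"`.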