
[symbolic values] Trace bakes in the initial value of input used for trace construction #2789

@kshitij12345

Description


Repro:

import thunder
import torch

def fn(x, y):
    return x + y

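# Treat scalar inputs as symbolic values so the trace is not specialized to them
tfn = thunder.jit(fn, cache="symbolic values")

# First call: the trace is constructed with x=1, y=2
tfn(1, 2)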

print(thunder.last_traces(tfn)[-1])
# def computation(x, y):
#   # x: "int 1"
#   # y: "int 2"

#   # /opt/pytorch/lightning-thunder/test.py:6:       return x + y
#   i3 = operator.add(x, y)  # i3: "int 3"
#     # i3 = prims.add(x, y)  # i3: "int 3"
#   return (i3,)

# Second call with a different y: the printed trace still shows y as "int 2"
tfn(1, 3)

print(thunder.last_traces(tfn)[-1])
# def computation(x, y):
#   # x: "int 1"
#   # y: "int 2"

#   # /opt/pytorch/lightning-thunder/test.py:6:       return x + y
#   i3 = operator.add(x, y)  # i3: "int 3"
#     # i3 = prims.add(x, y)  # i3: "int 3"
#   return (i3,)

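The trace bakes the values from the first call into its input annotations (# x: "int 1", # y: "int 2"). After tfn(1, 3), the printed trace is unchanged and still shows y as "int 2". This is confusing, because with symbolic values the trace is not meant to be specialized to those particular values. It would be better to print just x: int.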
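A minimal sketch of the suggested rendering, in plain Python. The helpers format_input_annotation and format_header are hypothetical and are not part of thunder's API. The sketch assumes the trace printer knows whether an input is treated symbolically. If it is, only the type is printed; otherwise the value is shown, as it is today.

```python
# Hypothetical sketch, not thunder's actual trace printer. It only illustrates
# the proposed change to how input annotations are rendered in a trace header.


def format_input_annotation(name: str, value: object, symbolic: bool) -> str:
    """Render the comment line describing one trace input."""
    type_name = type(value).__name__  # e.g. "int" for 1
    if symbolic:
        # Symbolic input: the trace is not tied to this value, so print only the type.
        return f'# {name}: "{type_name}"'
    # Static input: the value was specialized into the trace, so showing it is informative.
    return f'# {name}: "{type_name} {value}"'


def format_header(fn_name: str, inputs: dict, symbolic: bool) -> str:
    """Render the signature line plus one annotation per input (2-space indent)."""
    params = ", ".join(inputs)
    lines = [f"def {fn_name}({params}):"]
    for name, value in inputs.items():
        lines.append("  " + format_input_annotation(name, value, symbolic))
    return "\n".join(lines)


print(format_header("computation", {"x": 1, "y": 2}, symbolic=True))
```

With this rendering, the header printed after tfn(1, 2) and after tfn(1, 3) would be identical, so it no longer suggests the trace is tied to the first call's values.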
