47 commits
9d4110e  Register methods to NumberProxy (shino16, Nov 30, 2025)
23a677b  Register Python's scalar ops for pythonex (shino16, Nov 30, 2025)
68a5bcc  Python's floor/ceil/round/trunc returns int (shino16, Nov 30, 2025)
88d971f  Support all-number inputs to where, register to pythonex (shino16, Nov 30, 2025)
0d4d1db  Add tests, recompile when number type changes (shino16, Nov 30, 2025)
40f4e7e  Add test_where_on_numbers (shino16, Nov 30, 2025)
13a2d01  Reduce test time (shino16, Dec 1, 2025)
6c925fa  Do not cache ComplexProxy, test on where (shino16, Dec 1, 2025)
135a835  Reflect INT_FOR_NUMBER behavior in nvfuserex (shino16, Dec 2, 2025)
592b3de  Add py_min, py_max (shino16, Nov 29, 2025)
c55ac4a  Add torch.sym_min/max (shino16, Nov 29, 2025)
522b3f5  Add test (shino16, Nov 29, 2025)
3b73a42  Add builtin min/max lookaside (shino16, Nov 29, 2025)
e2578e3  Add test (shino16, Nov 29, 2025)
df72dad  Remove py_max/min and merge into prims.maximum (shino16, Dec 2, 2025)
2806c40  Fix up (shino16, Dec 2, 2025)
37bfd92  Apply DCE to subsymbols (beverlylytle, Oct 16, 2025)
5261098  try in symbol.__call__ instead (beverlylytle, Nov 12, 2025)
681492f  come on, ruff, it's a test (beverlylytle, Nov 14, 2025)
6af4d20  remove print (beverlylytle, Nov 18, 2025)
10d0564  respond to comments (beverlylytle, Nov 21, 2025)
2b7d2b3  where's my coffee (beverlylytle, Nov 21, 2025)
d9e4f9c  MAKE SYMBOLIC VALUES DEFAULT (shino16, Dec 10, 2025)
8293bd5  Shift bsym indices (shino16, Dec 3, 2025)
2e0015f  Add torch.set_autocast_enabled (shino16, Dec 3, 2025)
211a2a4  Add StringProxy.__bool__ (shino16, Dec 3, 2025)
97a0032  Defer cotangents creation after forward pass (shino16, Dec 3, 2025)
afb6a5a  .numel needs a preceding prims.shape (shino16, Dec 3, 2025)
77d15b4  Prologue can't be skipped (shino16, Dec 3, 2025)
a262150  Support set_grad_enabled(symbolic) (shino16, Dec 3, 2025)
b9df532  Support isinstance(symbolic, type) (shino16, Dec 3, 2025)
06baafb  a.numel() instead of a.numel (shino16, Dec 3, 2025)
ec8ce20  Remove _take_check numel check (shino16, Dec 3, 2025)
b965ef9  FIXME Allow a cycle in swap_map in OpExProcessor (shino16, Dec 3, 2025)
ae5806e  Adjust test (shino16, Dec 6, 2025)
4b8312f  Adjust test (shino16, Dec 6, 2025)
a78049d  Adjust test (shino16, Dec 6, 2025)
b1013ba  Adjust test (shino16, Dec 6, 2025)
dab3ad2  Remove obsolete xfail mark (shino16, Dec 6, 2025)
18d34b3  Do not treat str as symbolic values (shino16, Dec 6, 2025)
634de48  Adjust test (shino16, Dec 6, 2025)
c668c1c  Skip test (shino16, Dec 9, 2025)
829074c  Avoid len(a.shape) in executor checker (shino16, Dec 9, 2025)
d747849  Do not treat bool inputs symbolically (shino16, Dec 11, 2025)
fa642ac  Adjust test (shino16, Dec 11, 2025)
172320d  Update test (shino16, Dec 11, 2025)
b3a78f1  Merge branch 'main' of ssh://github.com/Lightning-AI/lightning-thunde… (shino16, Dec 15, 2025)
9 changes: 7 additions & 2 deletions thunder/__init__.py
@@ -15,7 +15,7 @@
# imports unused in this file, but referenced as thunder.* elsewhere
from thunder.common import trace
import thunder.core.devices as devices
from thunder.core.proxies import Proxy
from thunder.core.proxies import NumberProxy, Proxy

from thunder.common import (
CompileData,
@@ -895,7 +895,12 @@ def fn_(*args, **kwargs) -> Any:
result = call_epilogue(cache_entry, result, pro_to_epi)

# Reflect the state of is_grad_enabled, as its changes were tracked only inside Thunder
pytorch.set_grad_enabled(cd.is_grad_enabled)
is_grad_enabled = cd.is_grad_enabled
if isinstance(is_grad_enabled, NumberProxy):
# TODO: Verify this assumption
assert is_grad_enabled.is_static_constrained()
is_grad_enabled = is_grad_enabled.value
pytorch.set_grad_enabled(is_grad_enabled)

cs.last_computation = cache_entry.computation_fn
return result
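Note: a minimal standalone sketch of what this hunk does; is_static_constrained() and .value are the NumberProxy members used above, while the helper name itself is made up for illustration.

```python
from thunder.core.proxies import NumberProxy

def _resolve_grad_enabled(is_grad_enabled) -> bool:
    # Under symbolic values, CompileData may hold a NumberProxy here; only a
    # statically constrained proxy can be collapsed back to a plain Python bool.
    if isinstance(is_grad_enabled, NumberProxy):
        assert is_grad_enabled.is_static_constrained()
        is_grad_enabled = is_grad_enabled.value
    return bool(is_grad_enabled)
```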
14 changes: 11 additions & 3 deletions thunder/core/jit_ext.py
@@ -38,6 +38,7 @@
Proxy,
ProxyInterface,
ProxyTag,
StringProxy,
TensorProxy,
Variable,
is_proxy_name_available,
@@ -303,15 +304,17 @@ def proxify(self, value: WrappedValue) -> Any:
assert p.history is not None, f"{p.history}, {value.provenance} {type(p)}"

co: CACHE_OPTIONS = get_cache_option()
if co is CACHE_OPTIONS.CONSTANT_VALUES:
if co is CACHE_OPTIONS.CONSTANT_VALUES or isinstance(uvalue, bool):
if isinstance(uvalue, str):
self.add_constraint((clang.check_string_value, p, uvalue))
elif isinstance(uvalue, slice):
self.add_constraint((clang.check_slice_value, p, uvalue))
else:
self.add_constraint((clang.check_number_type_and_value, p, uvalue))
elif co is CACHE_OPTIONS.SYMBOLIC_VALUES:
if p is not uvalue:
if isinstance(uvalue, str):
self.add_constraint((clang.check_string_value, p, uvalue))
elif p is not uvalue:
self.add_constraint((clang.check_instance, p, (type(uvalue),)))
value.register_proxy(p)
elif co not in (CACHE_OPTIONS.SAME_INPUT, CACHE_OPTIONS.NO_CACHING):
@@ -468,6 +471,8 @@ def _general_jit_getattr_lookaside(obj: Any, name: str, *maybe_default: Any):

@register_general_jit_lookaside(isinstance)
def _general_jit_isinstance_lookaside(obj: Any, cls: type | UnionType | tuple[type | UnionType]):
from thunder.core.baseutils import check

uobj = unwrap(obj)
ucls = unwrap(cls)
if isinstance(uobj, TensorProxy):
@@ -479,6 +484,9 @@ def _general_jit_isinstance_lookaside(obj: Any, cls: type | UnionType | tuple[ty
ucls = (ucls,)
if torch.nn.Parameter in ucls:
res = issubclass(obj.python_typ, ucls)
elif isinstance(uobj, NumberProxy):
check(uobj.value is not None, lambda: "isinstance does not support NumberProxy with no value")
res = isinstance(uobj.value, ucls)
else:
res = isinstance(uobj, ucls)

@@ -642,7 +650,7 @@ def _general_jit_hasattr_lookaside(obj: Any, name: str):
def _general_jit_bool_lookaside(wrapped_x: Any) -> bool | INTERPRETER_SIGNALS:
assert isinstance(wrapped_x, WrappedValue)
# It doesn't feel right to insert constraints in bool lookaside, constraints here only applies when the bool value is used in control flow.
if isinstance(wrapped_x.value, NumberProxy):
if isinstance(wrapped_x.value, (NumberProxy, StringProxy)):
if wrapped_x.value.is_dynamic():
raise NotImplementedError(f"conversion to bool is not allowed on dynamic proxy={wrapped_x.value}")
wrapped_x.value.make_static_constrained()
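Note: together, these jit_ext changes let proxied scalar and string inputs answer isinstance checks and bool conversions through their wrapped values. A usage sketch, assuming the public thunder.jit entry point and the new symbolic-values default; the function body is illustrative only.

```python
import torch
import thunder

def fn(x, n):
    # Under symbolic values `n` is traced as a NumberProxy; the isinstance
    # lookaside now checks the wrapped value, so this branch behaves as in
    # plain Python.
    if isinstance(n, int):
        return x * n
    return x

jfn = thunder.jit(fn)
out = jfn(torch.randn(4), 3)
```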
2 changes: 1 addition & 1 deletion thunder/core/options.py
@@ -68,7 +68,7 @@ def _string_to_cache_option(s: str, /) -> None | CACHE_OPTIONS:
def resolve_cache_option(x: Any, /) -> CACHE_OPTIONS:
co: None | CACHE_OPTIONS
if x is None:
co = CACHE_OPTIONS.CONSTANT_VALUES
co = CACHE_OPTIONS.SYMBOLIC_VALUES
elif isinstance(x, CACHE_OPTIONS):
co = x
elif isinstance(x, str):
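Note: with this change, leaving the cache option unspecified now resolves to symbolic values. A usage sketch; the cache keyword and option strings mirror the existing CACHE_OPTIONS names, but treat the exact thunder.jit signature as an assumption.

```python
import torch
import thunder

def scale(x, s):
    return x * s

jfn = thunder.jit(scale)                                  # defaults to symbolic values now
jfn_const = thunder.jit(scale, cache="constant values")   # opt back into the previous default

x = torch.randn(4)
jfn(x, 2.0)
jfn(x, 3.0)  # reuses the symbolic trace; under "constant values" this would recompile
```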
9 changes: 7 additions & 2 deletions thunder/core/proxies.py
@@ -1275,7 +1275,7 @@ def _infer_tensor_properties(
else:
# deferred computation of numel
# TODO: similar to how `shape` is handled, this should be CSE or lifted for efficiency
_numel = lambda *args: reduce(operator.mul, _shape, 1)
_numel = lambda self: reduce(operator.mul, self.shape, 1)

# TODO Alias rank to ndim?
_ndim = len(_shape)
@@ -1465,7 +1465,7 @@ def __init__(
self._device,
self._dtype,
self._true_dtype,
self._numel,
_numel,
self._ndim,
self._requires_grad,
self._grad,
@@ -1482,6 +1482,11 @@ def __init__(
thunder_fsdp_padding_size,
)

if not using_symbolic_values():
self._numel = _numel
else:
self._numel = lambda self=self: _numel(self)

# NOTE The following properties DO NOT depend on the language context or record
# themselves into the trace, so they can be used when working with tensor proxies
# outside of a trace or language context
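Note: the point of the self-taking lambda is that numel is now derived from the proxy's current (possibly symbolic) shape at call time, not from the shape captured when the proxy was constructed. A toy illustration of the binding pattern, not Thunder code:

```python
from functools import reduce
import operator

class FakeProxy:
    # Stand-in for TensorProxy, only to show the deferred-numel binding.
    def __init__(self, shape):
        self.shape = shape
        # Same binding as the hunk above: a zero-argument callable that reads
        # self.shape when it is invoked, not when the proxy is constructed.
        self._numel = lambda self=self: reduce(operator.mul, self.shape, 1)

t = FakeProxy((2, 3))
assert t._numel() == 6
t.shape = (4, 5)          # shape refined later, e.g. under symbolic values
assert t._numel() == 20   # numel follows the updated shape
```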
2 changes: 1 addition & 1 deletion thunder/core/rematerialization.py
@@ -458,7 +458,7 @@ def rematerialize(trace: TraceCtx) -> TraceCtx:
computed_cuts_for_producers[producer] += cut

rematerialized_trace = from_trace(trace)
rematerialized_trace.bound_symbols = tuple(new_bsyms.get(bsym, bsym) for bsym in trace.bound_symbols)
rematerialized_trace.bound_symbols = list(new_bsyms.get(bsym, bsym) for bsym in trace.bound_symbols)

end_time_ns = time.perf_counter_ns()
elapsed_time_ns = end_time_ns - start_time_ns
18 changes: 15 additions & 3 deletions thunder/core/symbol.py
@@ -350,6 +350,12 @@ def tag_tensorproxy_output_as_detached(proxy):
exception_type=AssertionError,
)

# When using symbolic values, there may be duplicate prims.eq and prims.shape subsymbols that can be removed.
from thunder.core.transform_common import dce_bsyms

subsymbols = dce_bsyms(subsymbols, result)
bsym = bsym.from_bsym(subsymbols=subsymbols)

symbols_list.append(bsym)
return result

@@ -447,6 +453,7 @@ def from_bsym_swap_proxies(
skip_inputs: bool = False,
skip_output: bool = False,
skip_subsymbols: bool = False,
allow_cycles: bool = False,
) -> BoundSymbol:
"""Create a new :class:`BoundSymbol` with its inputs, output, and subsymbols updated with ``swap_map``.

@@ -481,9 +488,14 @@ def swap(c):
while vfa in swap_map:
if swap_map[vfa] is fa:
break
baseutils.check(
vfa not in visited, lambda: f"Detected a cycle while swapping; the cycle includes {visited}"
)

if vfa in visited:
baseutils.check(
allow_cycles,
lambda: f"Detected a cycle while swapping; the cycle includes {visited}",
)
break

visited.add(vfa)

fa = swap_map[vfa]
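Note: a sketch of the kind of swap map the new allow_cycles flag tolerates; variableify and from_bsym_swap_proxies are the real names, while the two-proxy cycle and the helper are made up for illustration.

```python
from thunder.core.proxies import variableify

def swap_with_possible_cycle(bsym, a, b):
    # a and b are proxies that end up mapped to each other, for example after
    # an executor substitutes one for the other and back again. Previously this
    # tripped the "Detected a cycle while swapping" check unconditionally.
    swap_map = {variableify(a): b, variableify(b): a}
    # With allow_cycles=True the walk stops at the repeated entry instead of
    # raising; callers such as TraceSubstitutionProcessor opt in explicitly.
    return bsym.from_bsym_swap_proxies(swap_map, allow_cycles=True)
```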
9 changes: 5 additions & 4 deletions thunder/core/trace_interpreter.py
@@ -262,11 +262,12 @@ class TraceSubstitutionProcessor:

NULL = object()

def __init__(self, trace, *args, **kwargs):
def __init__(self, trace, allow_swap_map_cycles=False, *args, **kwargs):
self.env = {}
self.trace = trace
self.new_trace = from_trace(self.trace)
self.have_processed_args = False
self.allow_swap_map_cycles = allow_swap_map_cycles

def read(self, x: VariableInterface | Any) -> Any:
if isinstance(x, VariableInterface):
@@ -398,9 +399,9 @@ def __call__(self):
for new_bsym in self.new_bsyms:
# TODO: what to do with bsym header? Maybe have a combined from_bsym_swap_proxies and from_bsym?
self.new_trace.bound_symbols.append(
new_bsym.from_bsym_swap_proxies(self.swap_map).from_bsym(
source_filename=bsym.source_filename, source_positions=bsym.source_positions
)
new_bsym.from_bsym_swap_proxies(
self.swap_map, allow_cycles=self.allow_swap_map_cycles
).from_bsym(source_filename=bsym.source_filename, source_positions=bsym.source_positions)
)

result = tree_map(self.do_swap, self.replacement_result)
39 changes: 30 additions & 9 deletions thunder/core/transform_common.py
@@ -142,20 +142,32 @@ def keep_or_swap(p):
# that only produce non-proxy objects
# NOTE needed_proxies is an in/out argument, it takes an initial set of Variables you want to keep, and return
# all the needed proxies of the input trace
def dce(trace: Trace, needed_proxies: None | set[Variable] = None) -> Trace:
start_time_ns = time.perf_counter_ns()
def dce_bsyms(
bsyms: list[BoundSymbolInterface],
output: Any,
needed_proxies: None | set[Variable] = None,
) -> Trace | list[BoundSymbolInterface]:
"""Runs a Dead Code Elimination (DCE) pass

Args:
bsyms: The list of bound symbols to run the DCE pass on.
needed_proxies: The set of variables to keep.
output: The output of the list of bound symbols.

producer_map: ProxyDict = producers(trace)
Returns:
The list of bound symbols after the DCE pass.
"""
producer_map: ProxyDict = producers(bsyms)

flat_trace_outputs, _ = tree_flatten(trace.output)
flat_trace_outputs, _ = tree_flatten(output)
if needed_proxies is None:
needed_proxies: set[Variable] = set(tuple(variableify(x) for x in flat_trace_outputs if isinstance(x, Proxy)))
else:
needed_proxies.update(tuple(variableify(x) for x in flat_trace_outputs if isinstance(x, Proxy)))
dced = []

bsym: BoundSymbol
for bsym in reversed(trace.bound_symbols):
for bsym in reversed(bsyms):
# Preserves symbols that should never be collected
if has_tags(bsym, {prims.OpTags.DONT_DCE}):
needed = True
@@ -182,19 +194,28 @@ def dce(trace: Trace, needed_proxies: None | set[Variable] = None) -> Trace:
for x in nbsym.flat_proxy_args:
needed_proxies.add(variableify(x))

dcetrace = from_trace(trace)
dced_bound_symbols = list(reversed(dced))
# duplicate number proxies happen with the symbolic shapes and are
# not covered by the above (due to being in tuples?).
dced_bound_symbols = remove_duplicate_number_proxies(dced_bound_symbols)
dcetrace.bound_symbols = dced_bound_symbols

return dced_bound_symbols


def dce(trace: Trace, needed_proxies: set[Variable] = None) -> Trace:
start_time_ns = time.perf_counter_ns()

bsyms = trace.bound_symbols
dced_bsyms = dce_bsyms(bsyms, trace.output, needed_proxies)
result = from_trace(trace)
result.bound_symbols = dced_bsyms

end_time_ns = time.perf_counter_ns()
elapsed_time_ns = end_time_ns - start_time_ns
elapsed_time_millis = elapsed_time_ns // 1000000
dcetrace.set_provenance(TraceProvenance(f"Dead Code Elimination (took {elapsed_time_millis} milliseconds)"))

return dcetrace
result.set_provenance(TraceProvenance(f"Dead Code Elimination (took {elapsed_time_millis} milliseconds)"))
return result


#
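Note: a sketch of the new split, mirroring the hunk above: dce() stays the trace-level entry point and delegates to dce_bsyms(), which call sites that only hold a list of bound symbols (such as Symbol.__call__ earlier in this diff) can use directly.

```python
from thunder.core.trace import from_trace
from thunder.core.transform_common import dce_bsyms

def dce_trace(trace):
    # What dce() now does, minus the timing and provenance bookkeeping.
    dced = dce_bsyms(trace.bound_symbols, trace.output)
    result = from_trace(trace)
    result.bound_symbols = dced
    return result
```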
34 changes: 19 additions & 15 deletions thunder/core/transforms.py
@@ -3031,10 +3031,24 @@ def vjp_call(primals, cotangents, trace: Trace, **kwargs):
primals = (primals,)

result, env = augmented_forward_pass(*primals, trace=trace, **kwargs)
check(
len(result) == len(cotangents) if isinstance(result, Sequence) else True,
lambda: f"Expected cotangents to be a sequence of length {len(result)}, got a sequence of length {len(cotangents)}",
)

if cotangents is None:

def ones_like(x):
if isinstance(x, TensorProxy):
return full_like(x, fill_value=1)
elif isinstance(x, NumberProxy):
return type(x.value)(1)
else:
return None

cotangents = tree_map(lambda v: ones_like(v), result)
else:
check(
len(result) == len(cotangents) if isinstance(result, Sequence) else True,
lambda: f"Expected cotangents to be a sequence of length {len(result)}, got a sequence of length {len(cotangents)}",
)

return result, backward_pass(env, trace, cotangents)


@@ -3075,18 +3089,8 @@ def value_and_grad(func):
func (Callable): Function to be differentiated.
"""

def ones_like(x):
if isinstance(x, TensorProxy):
return full_like(x, fill_value=1)
elif isinstance(x, NumberProxy):
return type(x.value)(1)
else:
return None

def _value_and_grad(*args, **kwargs):
trace = construct_trace()(func, *args, **kwargs)
cotangents = tree_map(lambda v: ones_like(v), trace.output)
return vjp(func)(args, cotangents, **kwargs)
return vjp(func)(args, None, **kwargs)

return _value_and_grad
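Note: deferring cotangent creation means vjp can now be handed cotangents=None and will seed all-ones cotangents from the actual forward outputs, which is what _value_and_grad relies on above. A usage sketch; the thunder.jit wrapping and the output unpacking are assumptions.

```python
import torch
import thunder
from thunder.core.transforms import value_and_grad

def f(x):
    return (x * x).sum()

# value_and_grad no longer builds a throwaway trace just to shape the
# cotangents; vjp creates the all-ones seeds itself when given None.
jfn = thunder.jit(value_and_grad(f))
out, grads = jfn(torch.randn(3))
```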

15 changes: 9 additions & 6 deletions thunder/dynamo/utils.py
@@ -259,20 +259,20 @@ def get_backed_value(s):
return tuple(map(get_backed_value, vals))


def get_proxy_inputs_from_node(node: torch.fx.Node) -> tuple[tuple, dict]:
def get_proxy_inputs_from_node(node: torch.fx.Node, tracectx) -> tuple[tuple, dict]:
"""Creates proxy inputs from a torch.fx.Node for use with Thunder.

This function generates proxy inputs for a given torch.fx.Node

Args:
node (torch.fx.Node): The FX graph node to create proxy inputs for.
tracectx (TraceCtx): The trace context to use to generate proxy inputs.
"""
import thunder
from thunder.core.trace import TraceCtx
from thunder.core.proxies import proxy

# We need to be under trace context to generate proxies.
with thunder.core.trace.tracectx(TraceCtx()):
with thunder.core.trace.tracectx(tracectx):

def make_input_proxy(arg_node):
# This is a Node in the graph representing a Tensor or tuple of Tensors or
@@ -380,8 +380,10 @@ def _run_with_cache_info():
cache_info["default_dtype"] = torch.get_default_dtype()
cache_info["default_device"] = torch.get_default_device()

tracectx = TraceCtx()

try:
proxy_args, proxy_kwargs = get_proxy_inputs_from_node(node)
proxy_args, proxy_kwargs = get_proxy_inputs_from_node(node, tracectx)
except Exception as e:
return False, SplitReason(
SplitReasonType.EXCEPTION_PROXY_THUNDER_OP,
@@ -395,7 +397,7 @@
else thunder_symbol
)
# We need to be under trace context to generate proxies.
with thunder.core.trace.tracectx(TraceCtx()):
with thunder.core.trace.tracectx(tracectx):
try:
function_to_run(*proxy_args, **proxy_kwargs)
except Exception as e:
@@ -478,6 +480,7 @@ def is_node_supported_by_thunder(
"""
Determine whether thunder can execute the operation described by this node.
"""
from thunder.core.trace import TraceCtx
# Docs from the torch.fx.Node - https://pytorch.org/docs/stable/fx.html#torch.fx.Node
# Each Node has a function specified by its op property
# Below are the details for the ones this function is interested in -
@@ -555,7 +558,7 @@ def is_node_supported_by_thunder(
if torchctx.has_method(node.target):
# `torchctx.get_method` requires args and kwargs to resolve which overload of the method is picked.
try:
args, kwargs = get_proxy_inputs_from_node(node)
args, kwargs = get_proxy_inputs_from_node(node, TraceCtx())
except Exception as e:
return False, SplitReason(
SplitReasonType.EXCEPTION_PROXY_THUNDER_OP,
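Note: the motivation for threading tracectx through is that the input proxies and the trial run of the thunder symbol should live in one shared TraceCtx rather than in two fresh ones. A condensed sketch of the flow in _run_with_cache_info; node and function_to_run stand in for the values already in scope there.

```python
import thunder
from thunder.core.trace import TraceCtx
from thunder.dynamo.utils import get_proxy_inputs_from_node

def try_node(node, function_to_run):
    # One trace context is shared between proxy creation and the trial call.
    tracectx = TraceCtx()
    proxy_args, proxy_kwargs = get_proxy_inputs_from_node(node, tracectx)
    with thunder.core.trace.tracectx(tracectx):
        function_to_run(*proxy_args, **proxy_kwargs)
```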