
Commit af36fe7

Merge branch 'main' of ssh://github.com/Lightning-AI/lightning-thunder into make-symbolic-default
2 parents: 172320d + fb989d4

16 files changed: +243 -68 lines

thunder/benchmarks/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -2081,7 +2081,7 @@ def fn(self) -> Callable:
         from litgpt.model import CausalSelfAttention
 
         module = (
-            CausalSelfAttention(self.config)
+            CausalSelfAttention(self.config, 0)
             .to(device=self.device, dtype=self.tdtype)
             .requires_grad_(self.requires_grad)
         )

thunder/benchmarks/benchmark_litgpt.py

Lines changed: 4 additions & 5 deletions
@@ -691,7 +691,10 @@ def setup_compile(self, model):
                executors.insert(0, transformer_engine_ex)
                transforms.insert(0, TransformerEngineTransform())
 
-            if "dynamo" in self.compile:
+            if "jit" in self.compile:
+                model = thunder.jit(model, executors=executors, transforms=transforms, **jit_options)
+
+            else:
                if self.distributed_mode == "fsdp2":
                    print("Resetting cache size for when fsdp2 and using thunder as backend torch.compile")
                    import torch._dynamo.config as dynamo_config
@@ -704,10 +707,6 @@ def setup_compile(self, model):
                # using __wrapped__ to access the original torch.compile function did not work
                # so we are using the lower level torch._dynamo.optimize function
                model = torch._dynamo.optimize(backend=self.backend)(model)
-            else:
-                jit_options = {}
-                jit_options["fp8_shard_intermediate_activation"] = self.fp8_shard_intermediate_activation
-                model = thunder.jit(model, executors=executors, transforms=transforms, **jit_options)
        elif self.compile != "eager":
            raise ValueError(f"Invalid compile option: {self.compile}")
 
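As an aside, a minimal sketch (not taken from the benchmark itself) of the dispatch this change introduces: the compile string now selects thunder.jit directly when it contains "jit", and everything else in the Thunder branch goes through the torch._dynamo path. The function name compile_model and its parameters are illustrative assumptions; in the real setup_compile, executors, transforms, jit_options and self.backend are built elsewhere in the method.

import thunder
import torch


def compile_model(model, compile_option: str, backend, executors=None, transforms=None, jit_options=None):
    # Hypothetical simplification of setup_compile's selection logic.
    jit_options = dict(jit_options or {})
    if "jit" in compile_option:
        # Hand the module to Thunder's JIT directly.
        return thunder.jit(model, executors=executors, transforms=transforms, **jit_options)
    # Otherwise use the dynamo entry point; `backend` is assumed to be a Thunder-backed
    # torch.compile backend (e.g. a ThunderCompiler instance).
    return torch._dynamo.optimize(backend=backend)(model)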

thunder/core/interpreter.py

Lines changed: 7 additions & 0 deletions
@@ -39,6 +39,8 @@
     TracebackType,
 )
 
+import torch
+
 from thunder.core.baseutils import Singleton, init_colors, extract_callable_name, is_likely_from_collections_namedtuple
 from thunder.core.codeutils import Positions
 
@@ -399,6 +401,11 @@ def __init__(
         if with_provenance_tracking:
             assert isinstance(uncacheable_classes, (list, tuple))
             uncacheable_classes = tuple(set(uncacheable_classes) | {NoneType, int, str, float, bool, complex})
+        if uncacheable_classes is None:
+            uncacheable_classes = ()
+        uncacheable_classes = tuple(
+            set(uncacheable_classes) | {NoneType, int, str, float, bool, complex, torch.Tensor}
+        )
 
         self._uncacheable_classes = uncacheable_classes
 
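With this change (and the jit_ext.py hunk below), callers no longer have to pass uncacheable_classes themselves: the interpreter always folds a default set, now including torch.Tensor, into whatever it is given. A rough standalone sketch of that normalization, assuming Python 3.10+ for types.NoneType; the helper name and _DEFAULT_UNCACHEABLE are illustrative, not interpreter API.

from types import NoneType

import torch

# Assumed default set, mirroring the classes listed in the hunk above.
_DEFAULT_UNCACHEABLE = {NoneType, int, str, float, bool, complex, torch.Tensor}


def normalize_uncacheable_classes(uncacheable_classes=None):
    # None means "no extra classes"; the defaults are always included.
    if uncacheable_classes is None:
        uncacheable_classes = ()
    return tuple(set(uncacheable_classes) | _DEFAULT_UNCACHEABLE)


# User-supplied extras are merged with the always-on defaults.
print(normalize_uncacheable_classes([dict]))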

thunder/core/jit_ext.py

Lines changed: 0 additions & 1 deletion
@@ -2199,7 +2199,6 @@ def thunder_general_jit(
         callbacks=general_jit_callbacks,
         with_provenance_tracking=True,
         unwrap_result=False,
-        uncacheable_classes=(torch.Tensor, int, float, str, NoneType),
         record_history=compile_data.debug_options.record_interpreter_history,
     )
 
thunder/core/update_aliases.py

Lines changed: 22 additions & 7 deletions
@@ -50,10 +50,8 @@ def _is_view_creation_op(bsym):
     return bsym.sym in ltorch._syms_returning_views or bsym.sym in ltorch._syms_that_may_return_views
 
 
-def _involves_viewed_args(bsym, viewed):
-    if bsym.sym.id == prims.PrimIDs.RETURN:
-        return False
-    return any(isinstance(p, TensorProxy) and variableify(p) in viewed for p in bsym.flat_proxy_args)
+def _involves_viewed_args(in_tensors, viewed):
+    return bool(in_tensors.intersection(viewed))
 
 
 def _can_be_reshaped(arg, arg_to_replace):
@@ -131,6 +129,17 @@ def replace_args_with_alias_map(
     return no_implicit_alias_trace, view_groups
 
 
+def _unswap(swap_map, aliases):
+    reversed_swap_map = {variableify(v): unvariableify(k) for k, v in swap_map.items()}
+
+    def _helper(alias):
+        while (valias := variableify(alias)) in reversed_swap_map:
+            alias = reversed_swap_map[valias]
+        return variableify(alias)
+
+    return list(map(_helper, aliases))
+
+
 def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[list[int]]) -> Trace:
     if not any(_is_inplace_op(bsym) for bsym in computation_trace.bound_symbols):
         return computation_trace
@@ -166,15 +175,21 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li
 
     # Third pass: insert alias updates
     for bsym in computation_trace.bound_symbols:
-        if _is_inplace_op(bsym) or _is_view_creation_op(bsym) or _involves_viewed_args(bsym, viewed):
-            in_tensors = list(map(variableify, filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_args)))
+        in_tensors = list(map(variableify, filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_args)))
+        unswapped_in_tensors = _unswap(swap_map, in_tensors)
+        if (
+            _is_inplace_op(bsym)
+            or _is_view_creation_op(bsym)
+            or (bsym.sym.id != prims.PrimIDs.RETURN and _involves_viewed_args(set(unswapped_in_tensors), viewed))
+        ):
             if _is_inplace_op(bsym) and in_tensors:
                 in_tensors = {in_tensors[0]}
+                unswapped_in_tensors = {unswapped_in_tensors[0]}
             else:
                 in_tensors = set(in_tensors)
             out_tensors = set(map(variableify, filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_outs)))
             encountered.update(in_tensors)
-            group = set(reduce(set.union, filter(lambda g: any(g.intersection(in_tensors)), view_groups), set()))
+            group = set().union(*filter(lambda g: g.intersection(unswapped_in_tensors), view_groups))
             if not group or not (views_encountered := group.intersection(encountered)):
                 # If group is empty, this is a view creation with operands that are not involved in any inplace ops.
                 bsyms.append(bsym.from_bsym_swap_proxies(swap_map, skip_output=True))
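To illustrate what the new _unswap helper does, here is a self-contained toy that uses plain strings in place of Thunder's variableified proxies (an assumption made purely for readability): a swap map records original-to-replacement renames, and _unswap walks each name backwards through the reversed map until it reaches the proxy the trace originally used.

# Toy model of the back-walk in _unswap; the real code operates on Variable/TensorProxy objects.
swap_map = {"t0": "t0_updated", "t0_updated": "t0_final"}  # original -> replacement, possibly chained
reversed_swap_map = {v: k for k, v in swap_map.items()}


def unswap_one(name):
    # Follow replacement -> original links until no entry matches.
    while name in reversed_swap_map:
        name = reversed_swap_map[name]
    return name


assert [unswap_one(n) for n in ["t0_final", "x"]] == ["t0", "x"]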

thunder/dynamo/benchmark_utils.py

Lines changed: 7 additions & 1 deletion
@@ -150,7 +150,7 @@ def compile(self, fn, *, inputs, **kwargs):
 
     # to_source will always use symbolic trace
     def to_source(self, fn_name):
-        return f"TorchInductorSpecification.torch_inductor({fn_name}, inputs)"
+        return f"TorchInductorSpecification.torch_inductor({fn_name}, inputs, skip_symbolic_trace={self.skip_symbolic_trace})"
 
     def import_str(self):
         return ["import torch", "from thunder.dynamo.benchmark_utils import TorchInductorSpecification"]
@@ -353,6 +353,12 @@ def time(self, stmt="pass", setup="pass", globals=None) -> Measurement:
             Measurement: A benchmarking result containing execution time statistics, see :class:`torch.utils.benchmark.utils.common.Measurement`.
         """
         t = TorchBenchmarkTimer(stmt=stmt, setup=setup, globals=globals, timer=self.inner_timer)
+        # If the timer measures an extremely short execution time, adaptive_autorange may hang.
+        # To prevent this, we perform a preliminary run to check for such cases, e.g. measure kernel time on a cpu-only graph.
+        # If detected, we return the time of a single run, avoiding potential hangs.
+        pre_run = t.timeit(1)
+        if pre_run.median <= 1e-9:
+            return pre_run
         measurement = t.adaptive_autorange(
             threshold=self.threshold, min_run_time=self.min_run_time, max_run_time=self.max_run_time
         )
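The same guard can be reproduced with a plain torch.utils.benchmark.Timer. A minimal sketch, assuming the 1e-9 cutoff from the hunk above; the statement being timed and the threshold passed to adaptive_autorange are arbitrary illustrative choices, not the values the class above uses.

import torch
from torch.utils.benchmark import Timer


def safe_time(stmt="pass", setup="pass", globals=None):
    t = Timer(stmt=stmt, setup=setup, globals=globals)
    # A near-zero measurement suggests there is nothing meaningful to time
    # (e.g. a kernel timer on a CPU-only graph); adaptive_autorange could spin, so bail out early.
    pre_run = t.timeit(1)
    if pre_run.median <= 1e-9:
        return pre_run
    return t.adaptive_autorange(threshold=0.1)


measurement = safe_time(stmt="torch.ones(128) @ torch.ones(128)", globals={"torch": torch})
print(measurement)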

thunder/dynamo/report.py

Lines changed: 57 additions & 23 deletions
@@ -520,7 +520,7 @@ def write_repro(
         code_str = f"{code_str}\n{main_code.format(graph_name=self.graph_name)}\n{comment_str}"
 
         if file_name is None:
-            file_name = f"{self.graph_name}.py"
+            file_name = f"{self.graph_name}_{compile_fn.name}_repro.py"
         with open(folder / file_name, "w") as f:
             print(code_str, file=f)
         format_python_file(folder / file_name)
@@ -633,7 +633,7 @@ def write_benchmark(
 
         code_str = f"{code_str}\n{main_code.format(graph_name=self.graph_name)}\n{comment_str}"
         if file_name is None:
-            file_name = f"{self.graph_name}.py"
+            file_name = f"{self.graph_name}_{compile_fn.name}_{time_fn.name}_benchmark.py"
         with open(folder / file_name, "w") as f:
             print(code_str, file=f)
         format_python_file(folder / file_name)
@@ -924,7 +924,7 @@ def write_nvfuser_benchmark(self, folder, time_fn: TimerInterface, file_name=Non
 {comment_str}
 """
         if file_name is None:
-            file_name = f"{self.name}_benchmark_nvfuser.py"
+            file_name = f"{self.name}_benchmark_nvfuser_{time_fn.name}.py"
         with open(folder / file_name, "w") as f:
             print(code_str, file=f)
         format_python_file(folder / file_name)
@@ -983,7 +983,7 @@ def write_inductor_benchmark(self, folder: PathLike, time_fn: TimerInterface, fi
     print(measurement)
 """
         if file_name is None:
-            file_name = f"{self.name}_benchmark_inductor.py"
+            file_name = f"{self.name}_benchmark_inductor_{time_fn.name}.py"
         with open(folder / file_name, "w") as f:
             f.write(code_str)
         format_python_file(folder / file_name)
@@ -1428,22 +1428,39 @@ def save_thunderfx_repros(
     Saves reproduction scripts for ThunderFX subgraphs.
 
     This function:
-    1. Creates a folder structure to organize the repros
-        .
-        └── graph0
-            ├── fusion_reports
-            │   ├── graph0_thunder_0_nvFusion0_forward_repro_nvfuser.py
-            │   ├── graph0_thunder_0_nvFusion1_forward_repro_nvfuser.py
-            │   ├── graph0_thunder_0_nvFusion2_backward_repro_nvfuser.py
-            ├── graph0_thunder_0_bwd_trace.py
-            ├── graph0_thunder_0_fwd_trace.py
-            └── graph0_thunder_0.py
+    1. Creates a folder structure to organize the repro or benchmark scripts:
+
+       If use_benchmark is True:
+        graph0/
+        ├── fusion_reports/
+        │   ├── graph0_thunder_0_nvFusion0_forward_benchmark_inductor_KernelTime.py
+        │   ├── graph0_thunder_0_nvFusion0_forward_benchmark_inductor_WallTimeWithMemoryUsage.py
+        │   ├── graph0_thunder_0_nvFusion0_forward_benchmark_nvfuser_KernelTime.py
+        │   └── graph0_thunder_0_nvFusion0_forward_benchmark_nvfuser_WallTimeWithMemoryUsage.py
+        ├── graph0_repro_torchcompile.py
+        ├── graph0_thunder_0_bwd_trace.py
+        ├── graph0_thunder_0_fwd_trace.py
+        ├── graph0_thunder_0_inductor_KernelTime_benchmark.py
+        ├── graph0_thunder_0_inductor_WallTimeWithMemoryUsage_benchmark.py
+        ├── graph0_thunder_0_thunder_KernelTime_benchmark.py
+        └── graph0_thunder_0_thunder_WallTimeWithMemoryUsage_benchmark.py
+
+       If use_benchmark is False:
+        graph0/
+        ├── fusion_reports/
+        │   ├── graph0_thunder_0_nvFusion0_forward_repro_inductor.py
+        │   └── graph0_thunder_0_nvFusion0_forward_repro_nvfuser.py
+        ├── graph0_repro_torchcompile.py
+        ├── graph0_thunder_0_fwd_trace.py
+        ├── graph0_thunder_0_bwd_trace.py
+        ├── graph0_thunder_0_inductor_repro.py
+        └── graph0_thunder_0_thunder_repro.py
 
     2. For each Thunder FX graph and its subgraphs:
-        - Checks runnability if requested
-        - Saves benchmark or repro scripts
-        - Saves trace information if requested
-        - Saves nvFusion repros if requested
+       - Checks runnability if requested
+       - Saves benchmark or repro scripts
+       - Saves trace information if requested
+       - Saves nvFusion repros if requested
 
     Args:
         fn: The callable to analyze
@@ -1452,7 +1469,7 @@ def save_thunderfx_repros(
         check_runnability: If True, checks if graphs can run with Thunder
         save_fusion: If True, saves nvFusion repros
        save_trace: If True, saves trace information
-        stream: Stream to write output log informationto
+        stream: Stream to write output log information to
         force_overwrite: If True, overwrites existing folder at folder_path
         **compile_kwargs: Keyword arguments for Thunder and torch.compile
 
@@ -1472,6 +1489,7 @@ def inner_fn(*args, **kwargs):
         for thunder_fxgraph_report in thunder_fxgraph_reports:
             graph_folder = folder_path / thunder_fxgraph_report.graph_name
             graph_folder.mkdir(exist_ok=True, parents=True)
+            thunder_fxgraph_report.write_inductor_repro(graph_folder)
             for split_report in thunder_fxgraph_report.subgraph_reports:
                 if check_runnability or save_trace or save_fusion:
                     try:
@@ -1484,22 +1502,38 @@ def inner_fn(*args, **kwargs):
                         continue
                     else:
                         stream.write(f"Successfully ran the {split_report.graph_name} using Thunder\n")
+
+                from torch._inductor.compile_fx import graph_returns_tuple
+
+                # torch._inductor.compile requires the output to be tuple, if not, the symbolic trace is necessary
+                skip_symbolic_trace = graph_returns_tuple(split_report.graph)
+                torchinductor = TorchInductorSpecification(skip_symbolic_trace=skip_symbolic_trace)
                 if use_benchmark:
-                    split_report.write_benchmark(graph_folder, thunderjit, WallTime)
+                    split_report.write_benchmark(graph_folder, thunderjit, WallTimeWithMemoryUsage)
+                    split_report.write_benchmark(graph_folder, thunderjit, KernelTime)
+
+                    split_report.write_benchmark(graph_folder, torchinductor, WallTimeWithMemoryUsage)
+                    split_report.write_benchmark(graph_folder, torchinductor, KernelTime)
                 else:
                     split_report.write_repro(graph_folder, thunderjit)
+                    split_report.write_repro(graph_folder, torchinductor)
                 if save_trace:
                     with open(graph_folder / f"{split_report.graph_name}_fwd_trace.py", "w") as f:
                         f.write(str(split_report.fwd_trc))
-                    with open(graph_folder / f"{split_report.graph_name}_bwd_trace.py", "w") as f:
-                        f.write(str(split_report.bwd_trc))
+                    if split_report.bwd_trc is not None:
+                        with open(graph_folder / f"{split_report.graph_name}_bwd_trace.py", "w") as f:
+                            f.write(str(split_report.bwd_trc))
                 if save_fusion:
                     fusion_folder = graph_folder / "fusion_reports"
                     fusion_folder.mkdir(exist_ok=True, parents=True)
                     for fusion_report in split_report.fusion_reports:
                         if use_benchmark:
-                            fusion_report.write_nvfuser_benchmark(fusion_folder, WallTime)
+                            fusion_report.write_nvfuser_benchmark(fusion_folder, WallTimeWithMemoryUsage)
+                            fusion_report.write_inductor_benchmark(fusion_folder, WallTimeWithMemoryUsage)
+                            fusion_report.write_nvfuser_benchmark(fusion_folder, KernelTime)
+                            fusion_report.write_inductor_benchmark(fusion_folder, KernelTime)
                         else:
                             fusion_report.write_nvfuser_repro(fusion_folder)
+                            fusion_report.write_inductor_repro(fusion_folder)
 
     return inner_fn
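The graph_returns_tuple check added above decides whether the Inductor benchmark spec can skip its own symbolic-trace step. A rough, standalone sketch of that decision on a toy torch.fx graph; the module below and the expectation noted in the comment are illustrative assumptions, not code from the report module.

import torch
from torch._inductor.compile_fx import graph_returns_tuple


class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1  # single tensor output, not a tuple


gm = torch.fx.symbolic_trace(AddOne())
# If the traced graph already returns a tuple, the spec can pass skip_symbolic_trace=True;
# otherwise it retraces so that torch._inductor.compile sees a tuple-returning graph.
skip_symbolic_trace = graph_returns_tuple(gm)
print(skip_symbolic_trace)  # expected to be False here, since forward returns a bare tensor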

thunder/numpy/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 from numbers import Number
 from collections.abc import Callable
 
-from thunder.core.langctx import langctx, Languages
+from thunder.core.langctxs import langctx, Languages
 from thunder.numpy.langctx import register_method
 
 from thunder.core.proxies import TensorProxy

thunder/tests/opinfos.py

Lines changed: 1 addition & 1 deletion
@@ -6687,7 +6687,7 @@ def arange_sample_generator(op, device, dtype, requires_grad, **kwargs):
     )
 
     for case in partial_cases:
-        yield SampleInput(*case)
+        yield SampleInput(*case, dtype=dtype, device=device)
 
 
 arange_opinfo = OpInfo(
