diff --git a/thunder/core/update_aliases.py b/thunder/core/update_aliases.py
index c6c7271e5f..448bb4e9ae 100644
--- a/thunder/core/update_aliases.py
+++ b/thunder/core/update_aliases.py
@@ -144,9 +144,6 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li
     if not any(_is_inplace_op(bsym) for bsym in computation_trace.bound_symbols):
         return computation_trace
 
-    swap_map = dict()
-    bsyms = []
-
     # First pass: identify inputs which are views of each other and swap them out with a default,
     # reshaping if necessary.
     computation_trace, view_groups = replace_args_with_alias_map(computation_trace, alias_tensor_indices)
@@ -173,10 +170,17 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li
     view_groups = [group for group in view_groups if len(group.intersection(inplace_inputs)) != 0]
     viewed = set(reduce(set.union, view_groups, set()))
 
+    swap_map = dict()
+    swap_map_by_update_aliases = dict()
+    bsyms = []
+
     # Third pass: insert alias updates
     for bsym in computation_trace.bound_symbols:
+        bsym = bsym.from_bsym_swap_proxies(swap_map)
         in_tensors = list(map(variableify, filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_args)))
-        unswapped_in_tensors = _unswap(swap_map, in_tensors)
+        # We do not unswap the out_tensor of an inplace bsym into its in_tensor, because the functional dependency
+        # is already captured by that reference to out_tensor.
+        unswapped_in_tensors = _unswap(swap_map_by_update_aliases, in_tensors)
         if (
             _is_inplace_op(bsym)
             or _is_view_creation_op(bsym)
@@ -189,10 +193,17 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li
             in_tensors = set(in_tensors)
             out_tensors = set(map(variableify, filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_outs)))
             encountered.update(in_tensors)
-            group = set().union(*filter(lambda g: g.intersection(unswapped_in_tensors), view_groups))
-            if not group or not (views_encountered := group.intersection(encountered)):
-                # If group is empty, this is a view creation with operands that are not involved in any inplace ops.
-                bsyms.append(bsym.from_bsym_swap_proxies(swap_map, skip_output=True))
+            involved_view_groups = [g for g in view_groups if g.intersection(unswapped_in_tensors)]
+            involved_views = set().union(*involved_view_groups)
+            views_encountered = tuple(involved_views.intersection(encountered))
+
+            if _is_inplace_op(bsym):
+                # This is a hack to insert a fusion break because nvFuser doesn't support mutation on intermediates.
+                views_encountered = tuple(unswapped_in_tensors.union(views_encountered))
+
+            if not views_encountered:
+                # This is a view creation with operands that are not involved in any inplace ops.
+                bsyms.append(bsym)
                 continue
 
             new_aliases = _get_new_aliases(views_encountered, computation_trace)
@@ -202,14 +213,17 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li
             if has_tags(bsym, {BoundSymbolTag.BACKWARD}):
                 update_bsym.tags.add(BoundSymbolTag.BACKWARD)
             bsyms.append(update_bsym)
-            encountered.update(out_tensors)
+            encountered.update(out_tensors, map(variableify, new_aliases))
             bsyms.append(new_bsym)
 
             if _is_inplace_op(bsym) and len(out_tensors) == 1 and len(in_tensors) == 1:
                 # This relies on these being one-element sets (ltorch.setitem_ yields no outs).
                 swap_map = _update_swap_map(swap_map, in_tensors.pop(), unvariableify(out_tensors.pop()))
+            for alias, new_alias in zip(views_encountered, new_aliases):
+                _update_swap_map(swap_map_by_update_aliases, alias, new_alias)
+
         else:
-            bsyms.append(bsym.from_bsym_swap_proxies(swap_map))
+            bsyms.append(bsym)
 
     alias_updated_trace = from_trace(computation_trace)
     alias_updated_trace.set_provenance(TraceProvenance("Update aliases for in-place ops"))
diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py
index 4d89116153..cedcd4a8d7 100644
--- a/thunder/executors/nvfuserex_impl.py
+++ b/thunder/executors/nvfuserex_impl.py
@@ -818,10 +818,17 @@ def map_redundant(x: Any) -> Any:
     new_symbols = [new_bsyms.get(bsym, bsym) for bsym in trace.bound_symbols]
     cse_trace.bound_symbols = list(filterfalse(lambda a: a is None, new_symbols))
 
-    return_bsym = cse_trace.bound_symbols[-1]
-    assert return_bsym.sym.id == prims.PrimIDs.RETURN
+    # TODO: Remove this and assert that return_bsym is at the end of the trace.
+    # This is a temporary workaround until https://github.com/Lightning-AI/lightning-thunder/issues/2776 is fixed.
+    return_bsym = None
+    for idx, bsym in enumerate(cse_trace.bound_symbols):
+        if bsym.sym.id == prims.PrimIDs.RETURN:
+            return_bsym = cse_trace.bound_symbols.pop(idx)
+            break
+    assert return_bsym is not None
+
     trace_output = tree_map(map_redundant, return_bsym.args)
-    cse_trace.bound_symbols[-1] = prims.python_return.bind(*trace_output, output=None)
+    cse_trace.bound_symbols.append(prims.python_return.bind(*trace_output, output=None))
 
     end_time_ns = time.perf_counter_ns()
     elapsed_time_ns = end_time_ns - start_time_ns
diff --git a/thunder/tests/test_update_aliases.py b/thunder/tests/test_update_aliases.py
index 52cf6e1de9..063e733d24 100644
--- a/thunder/tests/test_update_aliases.py
+++ b/thunder/tests/test_update_aliases.py
@@ -19,7 +19,6 @@
     NOTHING,
     TorchExecutor,
     TorchCompileExecutor,
-    nvFuserExecutor,
     requiresCUDA,
     xfail_if_args_tensor_mask_removed,
 )
@@ -361,9 +360,6 @@ def f(x, y, z):
     decorators=(pytest.mark.parametrize("cache", ("constant values", "symbolic values")),),
 )
 def test_write_to_intermediate_result(executor, device, dtype, cache):
-    if executor == nvFuserExecutor:
-        pytest.xfail("nvFuser does not support writing to intermediate results")
-
     def fn(x):
         y = x.view(-1)
         y.add_(1)
@@ -376,6 +372,24 @@ def fn(x):
     torch.testing.assert_close(actual, expected)
 
 
+@instantiate(
+    dtypes=NOTHING,
+    decorators=(pytest.mark.parametrize("requires_grad", (False, True)),),
+)
+def test_write_to_viewed_intermediate(executor, device, dtype, requires_grad):
+    def fn(a):
+        b = a * 2
+        c = b[:]
+        c.tanh_()
+        return a * b
+
+    a = make_tensor((2, 3), dtype=torch.float32, device=device, requires_grad=requires_grad)
+    jfn = executor.make_callable(fn, fusion_type="dataflow")
+    actual = jfn(a)
+    expected = fn(a)
+    torch.testing.assert_close(actual, expected)
+
+
 @instantiate(
     dtypes=(dtypes.float32,),
 )
@@ -550,3 +564,39 @@ def f(x, y, z):
     torch.testing.assert_close(a, a_)
     torch.testing.assert_close(b, b_)
     torch.testing.assert_close(c, c_)
+
+
+@instantiate(
+    dtypes=(dtypes.float32,),
+)
+def test_update_aliases_count(executor, device, dtype):
+    def f(x):
+        x.sin_()
+        return x * x * x * x
+
+    def g(x):
+        x.sin_()
+        x.cos_()
+        return x * x * x * x
+
+    def h(x):
+        y = x[:]
+        y.sin_()
+        return x * x * x * x
+
+    expected_num_update_aliases = {
+        f: 1,  # before sin_
+        g: 2,  # before sin_ and cos_; the latter is a hack to cause a fusion break
+        h: 5,  # before sin_ and every mul
+    }
+
+    for fn in [f, g, h]:
+        a = make_tensor((2, 3), dtype=dtypes.to_torch_dtype(dtype), device=device)
+        a_ = a.clone().detach()
+        jfn = executor.make_callable(fn)
+        actual = jfn(a)
+        expected = fn(a_)
+        torch.testing.assert_close(actual, expected)
+        extrace = thunder.last_traces(jfn)[-1]
+        actual_num_update_aliases = len([bsym for bsym in extrace.bound_symbols if bsym.sym.name == "update_aliases"])
+        assert actual_num_update_aliases == expected_num_update_aliases[fn]