34 changes: 24 additions & 10 deletions thunder/core/update_aliases.py
@@ -144,9 +144,6 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li
if not any(_is_inplace_op(bsym) for bsym in computation_trace.bound_symbols):
return computation_trace

swap_map = dict()
bsyms = []

# First pass: identify inputs which are views of each other and swap them out with a default,
# reshaping if necessary.
computation_trace, view_groups = replace_args_with_alias_map(computation_trace, alias_tensor_indices)
@@ -173,10 +170,17 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li
view_groups = [group for group in view_groups if len(group.intersection(inplace_inputs)) != 0]
viewed = set(reduce(set.union, view_groups, set()))

swap_map = dict()
swap_map_by_update_aliases = dict()
bsyms = []

# Third pass: insert alias updates
for bsym in computation_trace.bound_symbols:
bsym = bsym.from_bsym_swap_proxies(swap_map)
in_tensors = list(map(variableify, filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_args)))
unswapped_in_tensors = _unswap(swap_map, in_tensors)
# We do not unswap the out_tensor of an in-place bsym back into its in_tensor, because the functional
# dependency is already captured by the reference to out_tensor.
unswapped_in_tensors = _unswap(swap_map_by_update_aliases, in_tensors)
if (
_is_inplace_op(bsym)
or _is_view_creation_op(bsym)
@@ -189,10 +193,17 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li
in_tensors = set(in_tensors)
out_tensors = set(map(variableify, filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_outs)))
encountered.update(in_tensors)
group = set().union(*filter(lambda g: g.intersection(unswapped_in_tensors), view_groups))
if not group or not (views_encountered := group.intersection(encountered)):
# If group is empty, this is a view creation with operands that are not involved in any inplace ops.
bsyms.append(bsym.from_bsym_swap_proxies(swap_map, skip_output=True))
involved_view_groups = [g for g in view_groups if g.intersection(unswapped_in_tensors)]
Review comment (Collaborator): qq: wouldn't this call g.intersection len(view_groups) times?

involved_views = set().union(*involved_view_groups)
views_encountered = tuple(involved_views.intersection(encountered))

if _is_inplace_op(bsym):
# This is a hack to insert a fusion break because nvFuser doesn't support mutation on intermediates
views_encountered = tuple(unswapped_in_tensors.union(views_encountered))

if not views_encountered:
# This is a view creation with operands that are not involved in any inplace ops.
bsyms.append(bsym)
continue

new_aliases = _get_new_aliases(views_encountered, computation_trace)
@@ -202,14 +213,17 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li
if has_tags(bsym, {BoundSymbolTag.BACKWARD}):
update_bsym.tags.add(BoundSymbolTag.BACKWARD)
bsyms.append(update_bsym)
encountered.update(out_tensors)
encountered.update(out_tensors, map(variableify, new_aliases))
bsyms.append(new_bsym)
if _is_inplace_op(bsym) and len(out_tensors) == 1 and len(in_tensors) == 1:
# This relies on these being one element sets (ltorch.setitem_ yields no outs).
swap_map = _update_swap_map(swap_map, in_tensors.pop(), unvariableify(out_tensors.pop()))

for alias, new_alias in zip(views_encountered, new_aliases):
_update_swap_map(swap_map_by_update_aliases, alias, new_alias)

else:
bsyms.append(bsym.from_bsym_swap_proxies(swap_map))
bsyms.append(bsym)

alias_updated_trace = from_trace(computation_trace)
alias_updated_trace.set_provenance(TraceProvenance("Update aliases for in-place ops"))
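Taken together, the update_aliases.py changes maintain two maps: swap_map captures the functional dependency from a mutated input to the in-place op's output, while the new swap_map_by_update_aliases tracks the fresh aliases produced by each inserted update_aliases bsym. The following is a minimal, hypothetical sketch of the underlying swap-map idea only, using plain tuples instead of thunder's BoundSymbol/TensorProxy machinery; insert_updates, the op triples, and the naming scheme are invented for illustration and are not thunder's API.

def insert_updates(ops, view_group):
    # ops: (name, inputs, outputs) triples; view_group: names that alias each other.
    swap_map = {}  # mutated input -> in-place output (functional dependency)
    result = []
    step = 0
    for name, ins, outs in ops:
        ins = [swap_map.get(i, i) for i in ins]  # later uses read the mutated value
        if name.endswith("_") and any(i in view_group for i in ins):
            # Before an in-place op, emit an update over every tensor in the view group.
            fresh = [f"{v}_upd{step}" for v in sorted(view_group)]
            result.append(("update_aliases", sorted(view_group), fresh))
            step += 1
        result.append((name, ins, list(outs)))
        if name.endswith("_"):
            swap_map[ins[0]] = outs[0]  # assume the first input is the mutated one
    return result

# x and y view the same storage; add_ mutates y, so the final mul must see y's new value.
print(insert_updates(
    [("view", ["x"], ["y"]), ("add_", ["y"], ["y1"]), ("mul", ["x", "y"], ["z"])],
    {"x", "y"},
))
# -> [('view', ['x'], ['y']),
#     ('update_aliases', ['x', 'y'], ['x_upd0', 'y_upd0']),
#     ('add_', ['y'], ['y1']),
#     ('mul', ['x', 'y1'], ['z'])]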
13 changes: 10 additions & 3 deletions thunder/executors/nvfuserex_impl.py
@@ -818,10 +818,17 @@ def map_redundant(x: Any) -> Any:
new_symbols = [new_bsyms.get(bsym, bsym) for bsym in trace.bound_symbols]
cse_trace.bound_symbols = list(filterfalse(lambda a: a is None, new_symbols))

return_bsym = cse_trace.bound_symbols[-1]
assert return_bsym.sym.id == prims.PrimIDs.RETURN
# TODO: Remove this and assert that return_bsym is at the end of the trace
# This is a temporary workaround until https://github.com/Lightning-AI/lightning-thunder/issues/2776 is fixed
return_bsym = None
for idx, bsym in enumerate(cse_trace.bound_symbols):
if bsym.sym.id == prims.PrimIDs.RETURN:
return_bsym = cse_trace.bound_symbols.pop(idx)
break
assert return_bsym is not None

trace_output = tree_map(map_redundant, return_bsym.args)
cse_trace.bound_symbols[-1] = prims.python_return.bind(*trace_output, output=None)
cse_trace.bound_symbols.append(prims.python_return.bind(*trace_output, output=None))

end_time_ns = time.perf_counter_ns()
elapsed_time_ns = end_time_ns - start_time_ns
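The nvfuserex_impl.py change stops assuming the RETURN bound symbol is the last entry after CSE; it searches for it, pops it, and appends a rebuilt return at the end. Below is a minimal sketch of that find-pop-append pattern, using plain tuples rather than thunder's trace objects; move_return_to_end and the tuple layout are invented for illustration.

RETURN = "python_return"

def move_return_to_end(bound_symbols):
    # Find the single RETURN entry wherever it sits, remove it, and re-append it last.
    return_entry = None
    for idx, entry in enumerate(bound_symbols):
        if entry[0] == RETURN:
            return_entry = bound_symbols.pop(idx)
            break
    assert return_entry is not None, "expected a RETURN entry in the trace"
    bound_symbols.append(return_entry)
    return bound_symbols

# The RETURN can end up before trailing bookkeeping symbols; it is moved to the end.
print(move_return_to_end([("add", "t0"), (RETURN, "t0"), ("python_del", "t1")]))
# -> [('add', 't0'), ('python_del', 't1'), ('python_return', 't0')]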
58 changes: 54 additions & 4 deletions thunder/tests/test_update_aliases.py
@@ -19,7 +19,6 @@
NOTHING,
TorchExecutor,
TorchCompileExecutor,
nvFuserExecutor,
requiresCUDA,
xfail_if_args_tensor_mask_removed,
)
@@ -361,9 +360,6 @@ def f(x, y, z):
decorators=(pytest.mark.parametrize("cache", ("constant values", "symbolic values")),),
)
def test_write_to_intermediate_result(executor, device, dtype, cache):
if executor == nvFuserExecutor:
pytest.xfail("nvFuser does not support writing to intermediate results")

def fn(x):
y = x.view(-1)
y.add_(1)
@@ -376,6 +372,24 @@ def fn(x):
torch.testing.assert_close(actual, expected)


@instantiate(
dtypes=NOTHING,
decorators=(pytest.mark.parametrize("requires_grad", (False, True)),),
)
def test_write_to_viewed_intermediate(executor, device, dtype, requires_grad):
def fn(a):
b = a * 2
c = b[:]
c.tanh_()
return a * b

a = make_tensor((2, 3), dtype=torch.float32, device=device, requires_grad=requires_grad)
jfn = executor.make_callable(fn, fusion_type="dataflow")
actual = jfn(a)
expected = fn(a)
torch.testing.assert_close(actual, expected)


@instantiate(
dtypes=(dtypes.float32,),
)
@@ -550,3 +564,39 @@ def f(x, y, z):
torch.testing.assert_close(a, a_)
torch.testing.assert_close(b, b_)
torch.testing.assert_close(c, c_)


@instantiate(
dtypes=(dtypes.float32,),
)
def test_update_aliases_count(executor, device, dtype):
def f(x):
x.sin_()
return x * x * x * x

def g(x):
x.sin_()
x.cos_()
return x * x * x * x

def h(x):
y = x[:]
y.sin_()
return x * x * x * x

expected_num_update_aliases = {
f: 1, # before sin_
g: 2, # before sin_ and cos_; the latter is a hack to cause a fusion break
h: 5, # before sin_ and every mul
}

for fn in [f, g, h]:
a = make_tensor((2, 3), dtype=dtypes.to_torch_dtype(dtype), device=device)
a_ = a.clone().detach()
jfn = executor.make_callable(fn)
actual = jfn(a)
expected = fn(a_)
torch.testing.assert_close(actual, expected)
extrace = thunder.last_traces(jfn)[-1]
actual_num_update_aliases = len([bsym for bsym in extrace.bound_symbols if bsym.sym.name == "update_aliases"])
assert actual_num_update_aliases == expected_num_update_aliases[fn]