Remove unnecessary repetition of update_aliases
#2772
base: main
Conversation
Pull request overview
This pull request aims to remove unnecessary repetition of `update_aliases` calls in the alias update logic. The changes include:
- Refactoring the condition logic in `insert_alias_updates` to handle inplace operations differently
- Adding special handling for inplace ops to force insertion of `update_aliases` (as a workaround for nvFuser limitations)
- Adding a new test to verify the expected number of `update_aliases` calls
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| thunder/core/update_aliases.py | Refactors the alias update insertion logic, moving the from_bsym_swap_proxies call earlier and adding special handling for inplace operations |
| thunder/tests/test_update_aliases.py | Adds test test_update_aliases_count to verify the expected number of update_aliases calls for functions with different numbers of inplace operations |
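For illustration only, a count-verification test of the kind described above might be shaped like the sketch below. The function bodies, the expected counts, and the assumption that `update_aliases` bound symbols are counted in the last trace via `sym.name` are hypothetical and are not taken from the actual `test_update_aliases_count`:

```python
import pytest
import torch
import thunder


def one_inplace(a):
    b = a * 2
    b.add_(1)  # one in-place op
    return b


def two_inplace(a):
    b = a * 2
    b.add_(1)
    b.mul_(3)  # two in-place ops
    return b


# Hypothetical expected counts; the real test's functions and numbers may differ.
@pytest.mark.parametrize("fn,expected", [(one_inplace, 1), (two_inplace, 2)])
def test_update_aliases_count_sketch(fn, expected):
    jfn = thunder.jit(fn)
    jfn(torch.randn(8))
    trace = thunder.last_traces(jfn)[-1]
    # `sym.name` is assumed here to identify the primitive behind each bound symbol.
    count = sum(bsym.sym.name == "update_aliases" for bsym in trace.bound_symbols)
    assert count == expected
```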
Comments suppressed due to low confidence (1)
thunder/core/update_aliases.py:192
- Potential double-swapping issue: `bsym` is swapped on line 168 with `skip_output=True`, then later on line 192, `from_bsym_swap_proxies(swap_map)` is called again on the same (already-swapped) `bsym` with an updated `swap_map`. This could lead to incorrect proxy substitution. Consider storing the original `bsym` before swapping on line 168, or restructure the logic to avoid calling `from_bsym_swap_proxies` twice on the same symbol.
```python
for bsym in computation_trace.bound_symbols:
    if _is_inplace_op(bsym) or _is_view_creation_op(bsym) or _involves_viewed_args(bsym, viewed):
        bsym = bsym.from_bsym_swap_proxies(swap_map, skip_output=True)
        in_tensors = list(map(variableify, filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_args)))
        if _is_inplace_op(bsym) and in_tensors:
            in_tensors = {in_tensors[0]}
        else:
            in_tensors = set(in_tensors)
        out_tensors = set(map(variableify, filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_outs)))
        encountered.update(in_tensors)
        group = set(reduce(set.union, filter(lambda g: any(g.intersection(in_tensors)), view_groups), set()))
        views_encountered = group.intersection(encountered)
        if _is_inplace_op(bsym):
            # Super-hacky workaround to insert fusion break because nvFuser doesn't support mutation on intermediates
            # See https://github.com/Lightning-AI/lightning-thunder/issues/2768#issuecomment-3581908434
            views_encountered = in_tensors
        if not views_encountered:
            # This is a view creation with operands that are not involved in any inplace ops.
            bsyms.append(bsym)
            continue
        new_aliases = _get_new_aliases(views_encountered, computation_trace)
```
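As a simplified illustration of the double-swapping concern above (plain dicts and strings instead of thunder proxies and `from_bsym_swap_proxies`), applying a substitution map twice after it has been extended can chain replacements:

```python
# Stand-ins for proxies and a bound symbol's arguments; not thunder objects.
swap_map = {"t0": "t1"}
args = ["t0", "x"]

# First swap (analogous to the call with skip_output=True on line 168).
args = [swap_map.get(a, a) for a in args]   # -> ["t1", "x"]

# The swap map is later extended with a new alias for the already-swapped proxy.
swap_map["t1"] = "t2"

# Second swap on the same, already-swapped arguments (analogous to line 192)
# substitutes again, yielding "t2" where "t1" may have been intended.
args = [swap_map.get(a, a) for a in args]   # -> ["t2", "x"]
print(args)
```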
crcrpar left a comment
nvFuser restricts mutation on tensors that are not inputs to the fused region, and we have no mechanism to comply with this rule other than to break fusion for every mutation.
@jjsjann123 would there be a plan to relax this?
mattteochen left a comment
lgtm, thanks!
The test failures are due to #2776.

Will be fixed by #2777.
This is a good change! However, there will be (hopefully resolvable) conflicts with #2769. Which should be merged first? (Ticking the request changes box to forestall merging before this question is answered.)
I would prefer merging this PR first. There appears to be no conflict between these two PRs, but I will take extra care.
I found a regression:

```python
import torch, thunder

def fn(a):
    b = a * 2
    c = b[:]
    c.tanh_()
    return a * b

jfn = thunder.jit(fn)
x = torch.randn(6, device="cpu", requires_grad=True)
y = jfn(x)
y_ref = fn(x)
print(thunder.last_traces(jfn)[-1])
torch.testing.assert_close(y, y_ref)
# AssertionError: Tensor-likes are not close!
```

Since this follows the pattern of "mutate a view -> use its alias", I see this as part of #2766. The only difference is that this repro mutates an intermediate, not an input. I added this as an xfailed test. This will be fixed by @beverlylytle's draft #2769 (or the patch in #2766 (comment)).
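A sketch of how that repro could be recorded as an xfailed test; the test name, marker reason, and placement are assumptions, and the actual test added to thunder/tests/test_update_aliases.py is not reproduced here:

```python
import pytest
import torch
import thunder


@pytest.mark.xfail(reason="in-place op on a view of an intermediate; see #2766/#2769")
def test_inplace_on_view_of_intermediate_sketch():
    def fn(a):
        b = a * 2
        c = b[:]
        c.tanh_()
        return a * b

    x = torch.randn(6, requires_grad=True)
    torch.testing.assert_close(thunder.jit(fn)(x), fn(x))
```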
CI failures at 4cee1ad: This error is meant to be fixed in #2777, but I want this PR to be merged before testing #2777. I implemented a quick workaround that forcefully puts the return bsym at the end of the trace. This PR should be ready to merge once the CI passes.
```python
return_bsym = cse_trace.bound_symbols[-1]
assert return_bsym.sym.id == prims.PrimIDs.RETURN
return_bsym = None
for idx, bsym in enumerate(cse_trace.bound_symbols):
    if bsym.sym.id == prims.PrimIDs.RETURN:
        return_bsym = cse_trace.bound_symbols.pop(idx)
        break
assert return_bsym is not None
```
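A self-contained sketch of the "move the RETURN bound symbol to the end" workaround being discussed, using a plain list of strings in place of `cse_trace.bound_symbols`; the re-append step is assumed from the description above and is not shown in the excerpt:

```python
# Hypothetical, simplified trace where the return is not the last bound symbol.
bound_symbols = ["unpack_trivial", "return", "add", "mul"]

# Find and remove the return wherever it sits, mirroring the loop in the excerpt.
return_idx = next(i for i, b in enumerate(bound_symbols) if b == "return")
return_bsym = bound_symbols.pop(return_idx)

# Force it back to the end of the trace (the "quick workaround" described above).
bound_symbols.append(return_bsym)
assert bound_symbols[-1] == "return"
print(bound_symbols)  # ['unpack_trivial', 'add', 'mul', 'return']
```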
I looked at the code changes first before looking at the discussion and this was very alarming to me. Could you add a TODO comment about this being removed?
Sure, I appreciate that kind of feedback. Yes, it's a rough solution indeed...
Hi @beverlylytle, do you think this PR is good to merge? We must wait for #2805 for CI though.
Yes, I think it is.
@KaelanDt This PR is ready for your stamp. Thank you!
```python
if not group or not (views_encountered := group.intersection(encountered)):
    # If group is empty, this is a view creation with operands that are not involved in any inplace ops.
    bsyms.append(bsym.from_bsym_swap_proxies(swap_map, skip_output=True))
involved_view_groups = [g for g in view_groups if g.intersection(unswapped_in_tensors)]
```
qq: wouldn't this call g.intersection len(view_groups) times?
Fixes #2768. The primary role of `prims.update_aliases` is to establish relative ordering between bsyms involving aliases and mutation. But when relative ordering is already established by functional dependency, inserting `prims.update_aliases` does no good and only causes unnecessary fusion breaks.

Since variable substitution via `swap_map` creates extra functional dependencies, it should be done before deciding whether we need to insert `prims.update_aliases`. This eliminates the unnecessary `prims.update_aliases` calls.

But there is an exception. In the current behavior we insert `prims.update_aliases` before every in-place op, and we must keep doing this regardless. nvFuser restricts mutation on tensors that are not inputs to the fused region, and we have no mechanism to comply with this rule other than to break fusion for every mutation. See #2768 (comment) for details.
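As a rough illustration of the two cases described here, the sketch below defines two toy functions and prints their final traces using the same `thunder.jit` / `thunder.last_traces` pattern as the repro above; where `update_aliases` actually appears should be confirmed against the printed output rather than taken from the comments:

```python
import torch
import thunder


def view_then_use(a):
    # The view's output feeds directly into the next op, so ordering between the
    # view creation and its use is already fixed by dataflow; per this PR, no
    # extra prims.update_aliases should be needed for this chain.
    b = a.view(-1)
    return b + 1


def view_then_mutate(a):
    # An in-place op still gets a prims.update_aliases (and hence a fusion break)
    # inserted before it, because nvFuser cannot mutate non-input tensors.
    b = a.view(-1)
    b.add_(1)
    return b * 2


for fn in (view_then_use, view_then_mutate):
    jfn = thunder.jit(fn)
    jfn(torch.randn(4, 4))
    # Inspect where update_aliases calls (if any) appear in the final trace.
    print(fn.__name__)
    print(thunder.last_traces(jfn)[-1])
```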