
Conversation

@shino16 (Collaborator) commented Nov 26, 2025

Fixes #2768. The primary role of prims.update_aliases is to establish a relative ordering between bsyms that involve aliases and mutation. But when that ordering is already established by a functional dependency, inserting prims.update_aliases does no good and only causes unnecessary fusion breaks.

Since variable substitution via swap_map creates extra functional dependencies, it should be performed before deciding whether prims.update_aliases needs to be inserted. This eliminates the unnecessary prims.update_aliases calls.

There is one exception: in the current behavior we insert prims.update_aliases before every in-place op, and we must keep doing so regardless. nvFuser restricts mutation on tensors that are not inputs to the fused region, and we have no mechanism to comply with this rule other than to break fusion for every mutation. See #2768 (comment) for details.
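A minimal sketch of the distinction (illustrative only; this is not code from this PR):

```python
import torch, thunder

def f(a):
    b = a.clone()
    b.add_(1)      # in-place op: still preceded by update_aliases (nvFuser fusion break)
    return b * 2   # consumes b, so dataflow alone already orders this after add_;
                   # no additional update_aliases is needed between these two bsyms

jf = thunder.jit(f)
jf(torch.randn(4))
print(thunder.last_traces(jf)[-1])  # inspect how many update_aliases remain
```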

@shino16 force-pushed the remove-excess-udpate_aliases branch from 18c18a7 to 92663e5 on November 26, 2025 17:00
Copilot AI (Contributor) left a comment

Pull request overview

This pull request aims to remove unnecessary repetition of update_aliases calls in the alias update logic. The changes include:

  • Refactoring the condition logic in insert_alias_updates to handle inplace operations differently
  • Adding special handling for inplace ops to force insertion of update_aliases (as a workaround for nvFuser limitations)
  • Adding a new test to verify the expected number of update_aliases calls

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| thunder/core/update_aliases.py | Refactors the alias update insertion logic, moving the from_bsym_swap_proxies call earlier and adding special handling for inplace operations |
| thunder/tests/test_update_aliases.py | Adds test test_update_aliases_count to verify the expected number of update_aliases calls for functions with different numbers of inplace operations |
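
A hypothetical shape for such a count-checking test (the function, the counting method, and the expected count below are illustrative assumptions, not the PR's actual test code):

```python
import torch
import thunder

def test_update_aliases_count():
    def f(x):
        x.add_(1)
        x.mul_(2)
        return x

    jf = thunder.jit(f)
    jf(torch.ones(3))
    trace = thunder.last_traces(jf)[-1]
    # Count the update_aliases bound symbols remaining in the final trace.
    n = sum(1 for bsym in trace.bound_symbols if bsym.sym.name == "update_aliases")
    assert n == 2  # assumed: exactly one fusion break per in-place op
```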
Comments suppressed due to low confidence (1)

thunder/core/update_aliases.py:192

  • Potential double-swapping issue: bsym is swapped on line 168 with skip_output=True, then later on line 192, from_bsym_swap_proxies(swap_map) is called again on the same (already-swapped) bsym with an updated swap_map. This could lead to incorrect proxy substitution. Consider storing the original bsym before swapping on line 168, or restructuring the logic to avoid calling from_bsym_swap_proxies twice on the same symbol. (A possible restructuring is sketched after the excerpt below.)
```python
    for bsym in computation_trace.bound_symbols:
        if _is_inplace_op(bsym) or _is_view_creation_op(bsym) or _involves_viewed_args(bsym, viewed):
            bsym = bsym.from_bsym_swap_proxies(swap_map, skip_output=True)
            in_tensors = list(map(variableify, filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_args)))
            if _is_inplace_op(bsym) and in_tensors:
                in_tensors = {in_tensors[0]}
            else:
                in_tensors = set(in_tensors)
            out_tensors = set(map(variableify, filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_outs)))
            encountered.update(in_tensors)
            group = set(reduce(set.union, filter(lambda g: any(g.intersection(in_tensors)), view_groups), set()))
            views_encountered = group.intersection(encountered)

            if _is_inplace_op(bsym):
                # Super-hacky workaround to insert fusion break because nvFuser doesn't support mutation on intermediates
                # See https://github.com/Lightning-AI/lightning-thunder/issues/2768#issuecomment-3581908434
                views_encountered = in_tensors

            if not views_encountered:
                # This is a view creation with operands that are not involved in any inplace ops.
                bsyms.append(bsym)
                continue

            new_aliases = _get_new_aliases(views_encountered, computation_trace)
```
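
One way to address this (a hypothetical sketch along the lines of the reviewer's suggestion, not necessarily the fix that landed) is to keep the unswapped symbol around and derive every swapped version from it, so proxies are substituted exactly once:

```python
# Hypothetical restructuring, reusing names from the excerpt above.
orig_bsym = bsym  # keep the original, never-swapped symbol
bsym = orig_bsym.from_bsym_swap_proxies(swap_map, skip_output=True)
# ... compute in_tensors / out_tensors, update swap_map with the new aliases ...
# Swap from the original symbol instead of re-swapping the already-swapped one:
bsyms.append(orig_bsym.from_bsym_swap_proxies(swap_map))
```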



@shino16 marked this pull request as ready for review on November 26, 2025 18:15
@shino16 changed the title from "[WIP] Remove unnecessary repetition of update_aliases" to "Remove unnecessary repetition of update_aliases" on Nov 26, 2025
@crcrpar (Collaborator) left a comment

> nvFuser restricts mutation on tensors that are not inputs to the fused region, and we have no mechanism to comply with this rule other than to break fusion for every mutation.

@jjsjann123 is there a plan to relax this?

@mattteochen (Collaborator) left a comment

lgtm, thanks!

@shino16 (Author) commented Nov 27, 2025

The test failures are due to #2776.

@shino16 (Author) commented Nov 27, 2025

They will be fixed by #2777.

@beverlylytle (Collaborator) left a comment

This is a good change! However, there will be (hopefully resolvable) conflicts with #2769. Which should be merged first? (Ticking the request changes box to forestall merging before this question is answered.)

@shino16 (Author) commented Dec 4, 2025

I would prefer merging this PR first. There appears to be no conflict between these two PRs, but I will take extra care.

@shino16 (Author) commented Dec 5, 2025

I found a regression:

```python
import torch, thunder

def fn(a):
    b = a * 2
    c = b[:]     # view of the intermediate b
    c.tanh_()    # in-place op on the view should also be reflected in b
    return a * b

jfn = thunder.jit(fn)
x = torch.randn(6, device="cpu", requires_grad=True)
y = jfn(x)
y_ref = fn(x)
print(thunder.last_traces(jfn)[-1])
torch.testing.assert_close(y, y_ref)
# AssertionError: Tensor-likes are not close!
```

Since this follows the pattern of "mutate a view -> use its alias", I see this as part of #2766. The only difference is that this repro mutates an intermediate, not an input. I added this as an xfailed test.

This will be fixed by @beverlylytle's draft #2769 (or the patch #2766 (comment)).

@shino16 (Author) commented Dec 12, 2025

CI failures at 4cee1ad:

```text
____________________ test_complex_backward_custom_autograd _____________________

...

        jf = thunder_jit(f, fusion_type="dataflow")

        x = torch.ones(2, 3, device="cuda", requires_grad=True)

        # This should not raise an error about variables referenced before assignment.
>       jf(x)

thunder/tests/test_jit_general.py:1233:

...

        cse_trace.bound_symbols = list(filterfalse(lambda a: a is None, new_symbols))

        return_bsym = cse_trace.bound_symbols[-1]
>       assert return_bsym.sym.id == prims.PrimIDs.RETURN
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E       AssertionError

thunder/executors/nvfuserex_impl.py:821: AssertionError
```

This error is meant to be fixed by #2777, but I want this PR to be merged before testing #2777.

I implemented a quick workaround that forcefully puts the return bsym at the end of the trace.

This PR should be ready to merge once the CI passes.

Comment on lines 820 to 825 in thunder/executors/nvfuserex_impl.py:

```diff
-    return_bsym = cse_trace.bound_symbols[-1]
-    assert return_bsym.sym.id == prims.PrimIDs.RETURN
+    return_bsym = None
+    for idx, bsym in enumerate(cse_trace.bound_symbols):
+        if bsym.sym.id == prims.PrimIDs.RETURN:
+            return_bsym = cse_trace.bound_symbols.pop(idx)
+            break
+    assert return_bsym is not None
```
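
Presumably the popped bound symbol is then re-appended so that RETURN ends up last (an assumption based on the "put return bsym at the end of the trace" description above; the re-append is not shown in this hunk):

```python
cse_trace.bound_symbols.append(return_bsym)  # assumed follow-up: put RETURN back at the end
```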
Collaborator left a comment

I looked at the code changes before reading the discussion, and this was very alarming to me. Could you add a TODO comment noting that this should be removed?

@shino16 (Author) replied

Sure, I appreciate that kind of feedback. Yes, it's a rough solution indeed...

@shino16 (Author) commented Dec 16, 2025

Hi @beverlylytle, do you think this PR is good to merge? We must wait for #2805 for CI though.

@beverlylytle (Collaborator)

Yes, I think it is.

@shino16 (Author) commented Dec 17, 2025

@KaelanDt This PR is ready for your stamp. Thank you!

```python
if not group or not (views_encountered := group.intersection(encountered)):
    # If group is empty, this is a view creation with operands that are not involved in any inplace ops.
    bsyms.append(bsym.from_bsym_swap_proxies(swap_map, skip_output=True))
involved_view_groups = [g for g in view_groups if g.intersection(unswapped_in_tensors)]
```
Collaborator left a comment

qq: wouldn't this call g.intersection len(view_groups) times?
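
For reference, a hypothetical way to avoid intersecting with every group (illustrative only; whether the repeated intersections matter in practice is not established here) would be to build a tensor-to-group index once and look the inputs up directly:

```python
# Hypothetical alternative, reusing names from the hunk above.
group_of = {t: i for i, g in enumerate(view_groups) for t in g}  # tensor -> group index
hit = {group_of[t] for t in unswapped_in_tensors if t in group_of}
involved_view_groups = [view_groups[i] for i in sorted(hit)]  # preserve original group order
```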

Successfully merging this pull request may close these issues:

Too many update_aliases after in-place op