Skip to content

Too many update_aliases after in-place op #2768

@shino16

Description

@shino16
import torch, thunder

@thunder.jit
def fn(a):
    a.tanh_()  # in-place op on the input; nothing after this line mutates `a`
    return a * a * a * a  # three out-of-place multiplications on the same tensor

fn(torch.randn(5, 5, device='cuda'))  # needs a CUDA device
print(*thunder.last_traces(fn), sep='\n\n')  # print each trace recorded for fn, separated by a blank line

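Ideally this function would be compiled into a single nvFuser region, but Thunder inserts an `update_aliases` call before every multiplication, even though nothing mutates `a` after `tanh_`. Each of these calls breaks fusion, so the final trace contains three separate fusions (nvFusion0, nvFusion2, nvFusion3) instead of one.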

# ...

# Constructed by Update aliases for in-place ops
import thunder
import thunder.core.prims as prims
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a):
  # a: "cuda:0 f32[5, 5]"
  (t5,) = prims.update_aliases((a,))

  # /opt/pytorch/lightning-thunder/tmp/main.py:5: 	    a.tanh_()
  t1 = ltorch.tanh_(t5)  # t1: "cuda:0 f32[5, 5]"
    # t0 = ltorch.tanh(t5)  # t0: "cuda:0 f32[5, 5]"
      # t0 = prims.tanh(t5)  # t0: "cuda:0 f32[5, 5]"
    # t1 = prims.copy_(t0, t5, grad_enabled=True)  # t1: "cuda:0 f32[5, 5]"
  (t6,) = prims.update_aliases((t1,))

  # /opt/pytorch/lightning-thunder/tmp/main.py:6: 	    return a * a * a * a
  t2 = ltorch.mul(t6, t6)  # t2: "cuda:0 f32[5, 5]"
    # t2 = prims.mul(t6, t6)  # t2: "cuda:0 f32[5, 5]"
  (t7,) = prims.update_aliases((t6,))

  # /opt/pytorch/lightning-thunder/tmp/main.py:6: 	    return a * a * a * a
  t3 = ltorch.mul(t2, t7)  # t3: "cuda:0 f32[5, 5]"
    # t3 = prims.mul(t2, t7)  # t3: "cuda:0 f32[5, 5]"
  (t8,) = prims.update_aliases((t7,))

  # /opt/pytorch/lightning-thunder/tmp/main.py:6: 	    return a * a * a * a
  t4 = ltorch.mul(t3, t8)  # t4: "cuda:0 f32[5, 5]"
    # t4 = prims.mul(t3, t8)  # t4: "cuda:0 f32[5, 5]"
  return {'output': (t4,), 'flat_args': [t8]}

# ...

# Constructed by Unwrap the actual return value
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a):
  # a: "cuda:0 f32[5, 5]"
  (t13,) = update_aliases((a,))
  del a
  [t1] = nvFusion0(t13)
    # t0 = prims.tanh(t13)  # t0: "cuda:0 f32[5, 5]"
    # t1 = prims.copy_(t0, t13, grad_enabled=True)  # t1: "cuda:0 f32[5, 5]"
  del t13
  (t14,) = update_aliases((t1,))
  del t1
  (t15,) = update_aliases((t14,))
  [t3] = nvFusion2(t14, t15)
    # t2 = prims.mul(t14, t14)  # t2: "cuda:0 f32[5, 5]"
    # t3 = prims.mul(t2, t15)  # t3: "cuda:0 f32[5, 5]"
  del t14
  (t16,) = update_aliases((t15,))
  del t15
  [t4] = nvFusion3(t3, t16)
    # t4 = prims.mul(t3, t16)  # t4: "cuda:0 f32[5, 5]"
  del t3
  return (t4,)

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions