-
Notifications
You must be signed in to change notification settings - Fork 111
Open
Description
import torch, thunder
@thunder.jit
def fn(a):
    # Minimal repro: one in-place op followed by a chain of elementwise ops
    # on the same tensor.
    # In-place op: overwrites `a` with tanh(a) (the traces in this report
    # show it lowered to prims.tanh followed by prims.copy_ into the input).
    a.tanh_()
    # Three chained multiplications of the mutated tensor with itself; these
    # are the ops the report expects to land in a single fusion region.
    return a * a * a * a


# Call once on a 5x5 CUDA tensor; the traces printed afterwards come from
# this call.
fn(torch.randn(5, 5, device='cuda'))
print(*thunder.last_traces(fn), sep='\n\n')

Ideally this function should be compiled into a single nvfuser region, but Thunder inserts `update_aliases` before every multiplication, causing fusion breaks.
# ...
# Constructed by Update aliases for in-place ops
import thunder
import thunder.core.prims as prims
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(a):
  # NOTE: trace dump reproduced verbatim; only the lost indentation is
  # restored. Comments prefixed with "NOTE:" are reviewer annotations, every
  # other comment is part of the dump itself.
  # a: "cuda:0 f32[5, 5]"

  # NOTE: update_aliases is applied to the input before the in-place tanh_,
  # rebinding it to a new name (a -> t5).
  (t5,) = prims.update_aliases((a,))

  # /opt/pytorch/lightning-thunder/tmp/main.py:5: a.tanh_()
  t1 = ltorch.tanh_(t5) # t1: "cuda:0 f32[5, 5]"
    # t0 = ltorch.tanh(t5) # t0: "cuda:0 f32[5, 5]"
      # t0 = prims.tanh(t5) # t0: "cuda:0 f32[5, 5]"
    # t1 = prims.copy_(t0, t5, grad_enabled=True) # t1: "cuda:0 f32[5, 5]"

  # NOTE: from here on the same tensor is passed through update_aliases again
  # before each of the three multiplications (t1 -> t6 -> t7 -> t8), although
  # no further in-place op appears in this trace after tanh_.
  (t6,) = prims.update_aliases((t1,))

  # /opt/pytorch/lightning-thunder/tmp/main.py:6: return a * a * a * a
  t2 = ltorch.mul(t6, t6) # t2: "cuda:0 f32[5, 5]"
    # t2 = prims.mul(t6, t6) # t2: "cuda:0 f32[5, 5]"
  (t7,) = prims.update_aliases((t6,))

  # /opt/pytorch/lightning-thunder/tmp/main.py:6: return a * a * a * a
  t3 = ltorch.mul(t2, t7) # t3: "cuda:0 f32[5, 5]"
    # t3 = prims.mul(t2, t7) # t3: "cuda:0 f32[5, 5]"
  (t8,) = prims.update_aliases((t7,))

  # /opt/pytorch/lightning-thunder/tmp/main.py:6: return a * a * a * a
  t4 = ltorch.mul(t3, t8) # t4: "cuda:0 f32[5, 5]"
    # t4 = prims.mul(t3, t8) # t4: "cuda:0 f32[5, 5]"

  # NOTE: the result is returned together with the last alias of the input
  # (t8) under 'flat_args'.
  return {'output': (t4,), 'flat_args': [t8]}
# ...
# Constructed by Unwrap the actual return value
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(a):
  # NOTE: trace dump reproduced verbatim; only the lost indentation is
  # restored. Comments prefixed with "NOTE:" are reviewer annotations, every
  # other comment is part of the dump itself.
  # NOTE: the computation ends up in three separate fusion regions
  # (nvFusion0, nvFusion2, nvFusion3), with update_aliases calls sitting
  # between them.
  # a: "cuda:0 f32[5, 5]"
  (t13,) = update_aliases((a,))
  del a

  # NOTE: region 1 - the in-place tanh (tanh + copy_ back into the input).
  [t1] = nvFusion0(t13)
    # t0 = prims.tanh(t13) # t0: "cuda:0 f32[5, 5]"
    # t1 = prims.copy_(t0, t13, grad_enabled=True) # t1: "cuda:0 f32[5, 5]"
  del t13
  (t14,) = update_aliases((t1,))
  del t1
  (t15,) = update_aliases((t14,))

  # NOTE: region 2 - the first two multiplications; both aliases (t14, t15)
  # are passed in as separate inputs.
  [t3] = nvFusion2(t14, t15)
    # t2 = prims.mul(t14, t14) # t2: "cuda:0 f32[5, 5]"
    # t3 = prims.mul(t2, t15) # t3: "cuda:0 f32[5, 5]"
  del t14
  (t16,) = update_aliases((t15,))
  del t15

  # NOTE: region 3 - the last multiplication, separated from region 2 by the
  # update_aliases call above.
  [t4] = nvFusion3(t3, t16)
    # t4 = prims.mul(t3, t16) # t4: "cuda:0 f32[5, 5]"
  del t3
  return (t4,)