Skip to content

Commit dd33ac3

Browse files
committed
Changed the how unown_tensor attribute is set on TRT mod
1 parent 6446085 commit dd33ac3

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -949,7 +949,7 @@ def preserve_module_specs(
949949
for attr in dir(gm):
950950
if attr.startswith("_frozen_param"):
951951
delattr(gm, attr)
952-
trt_module = None
952+
953953
for name, _ in partitioned_module.named_children():
954954
submodule = getattr(partitioned_module, name)
955955
# filter on the GraphModule
@@ -1082,8 +1082,12 @@ def preserve_module_specs(
10821082
trt_module = getattr(partitioned_module, name)
10831083
trt_module.setup_engine()
10841084

1085-
if trt_module:
1086-
trt_module.set_output_tensors_as_unowned(True)
1085+
output_node = list(partitioned_module.graph.nodes)[-1]
1086+
for arg in output_node.args:
1087+
target = arg[0].target
1088+
if "acc" not in target:
1089+
continue
1090+
getattr(partitioned_module, target).set_output_tensors_as_unowned(True)
10871091

10881092
# Reset settings object to user specification after fallback to global partitioning mode
10891093
if fast_partitioner_failed:

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -807,3 +807,6 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
807807
return True
808808

809809
return False
810+
811+
def are_output_tensors_unowned(self) -> bool:
812+
return self.output_tensors_are_unowned

0 commit comments

Comments
 (0)