
Commit 61f33b1

Changed how the unown_tensor attribute is set on the TRT module
1 parent 6446085 commit 61f33b1

4 files changed: +28, -18 lines

core/runtime/execute_engine.cpp

Lines changed: 6 additions & 4 deletions
@@ -96,7 +96,7 @@ void setup_input_tensors(
     std::vector<at::Tensor> inputs,
     c10::intrusive_ptr<TRTEngine> compiled_engine,
     bool cudagraphs_enabled,
-    bool need_cudagraphs_record) {
+    bool shape_changed) {
   // this is a buffer to store shape tensor input addresses throughout the runtime scope
   std::list<std::vector<int64_t>> inputShapeTensorValues;
   std::list<at::Tensor> formatted_inputs(compiled_engine->num_io.first);
@@ -140,12 +140,14 @@ void setup_input_tensors(
     } else {
       at::Tensor contig_input = inputs[i].view(shape).contiguous();
       formatted_inputs.emplace_back(std::move(contig_input));
-
+      bool need_cudagraphs_record = cudagraphs_enabled &&
+          (!compiled_engine->runtime_states.old_cudagraphs || shape_changed ||
+           compiled_engine->runtime_states.context_changed);
       if (need_cudagraphs_record) {
         // Create a new persistent input buffer
         compiled_engine->input_buffers[i] = std::move(formatted_inputs.back().clone());
       }
-      if (need_cudagraphs_record or compiled_engine->allocated_outputs.size() == 0) {
+      if (shape_changed || compiled_engine->allocated_outputs.size() == 0) {
        TORCHTRT_CHECK(
            compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape");
      }
@@ -226,7 +228,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
     input_profiler_guard =
         std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path);
   }
-  setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record);
+  setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, shape_changed);
   // Check if input shapes can be inferred.
   int32_t const io_size{compiled_engine->io_size};
   std::vector<char const*> names(io_size);
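
Taken together, these hunks move the CUDA graph recording decision out of the caller: execute_engine now passes shape_changed down, and setup_input_tensors derives need_cudagraphs_record itself from the cudagraphs flag and the engine's runtime state. A minimal sketch of that predicate, written here in Python with hypothetical stand-in parameters that mirror the fields used in the diff:

def needs_cudagraphs_record(
    cudagraphs_enabled: bool,
    old_cudagraphs: bool,
    shape_changed: bool,
    context_changed: bool,
) -> bool:
    # A previously captured CUDA graph must be re-recorded only when
    # cudagraphs are enabled and something invalidated the old capture:
    # cudagraphs were just switched on, the input shapes changed, or the
    # TensorRT execution context changed.
    return cudagraphs_enabled and (
        not old_cudagraphs or shape_changed or context_changed
    )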

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 7 additions & 3 deletions
@@ -949,7 +949,7 @@ def preserve_module_specs(
     for attr in dir(gm):
         if attr.startswith("_frozen_param"):
             delattr(gm, attr)
-    trt_module = None
+
     for name, _ in partitioned_module.named_children():
         submodule = getattr(partitioned_module, name)
         # filter on the GraphModule
@@ -1082,8 +1082,12 @@ def preserve_module_specs(
         trt_module = getattr(partitioned_module, name)
         trt_module.setup_engine()

-    if trt_module:
-        trt_module.set_output_tensors_as_unowned(True)
+    output_node = list(partitioned_module.graph.nodes)[-1]
+    for arg in output_node.args:
+        target = arg[0].target
+        if "_run_on_acc" not in str(target):
+            continue
+        getattr(partitioned_module, target).set_output_tensors_as_unowned(True)

     # Reset settings object to user specification after fallback to global partitioning mode
     if fast_partitioner_failed:
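
Instead of flagging whichever TRT module happened to be built last, the compiler now inspects the partitioned graph's output node and marks the accelerated ("_run_on_acc") submodules that feed the graph outputs. A hedged sketch of that traversal as a standalone helper; the function name is hypothetical, and the body mirrors the lines added above:

from torch import fx


def mark_output_trt_modules_unowned(partitioned_module: fx.GraphModule) -> None:
    # The last node of an fx graph is its output node; its args hold the
    # values the graph returns.
    output_node = list(partitioned_module.graph.nodes)[-1]
    for arg in output_node.args:
        # arg[0] is the node producing the returned value; skip anything not
        # produced by an accelerated ("_run_on_acc_*") TRT submodule.
        target = arg[0].target
        if "_run_on_acc" not in str(target):
            continue
        # Mark the TRT submodule's output tensors as unowned.
        getattr(partitioned_module, target).set_output_tensors_as_unowned(True)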

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 11 additions & 7 deletions
@@ -383,8 +383,13 @@ def setup_input_tensors(
         self,
         contiguous_inputs: List[torch.Tensor],
         cudagraphs_enabled: bool,
-        need_cudagraphs_record: bool,
+        shape_changed: bool = True,
     ) -> None:
+        need_cudagraphs_record = cudagraphs_enabled and (
+            not self.runtime_states.old_cudagraphs
+            or shape_changed
+            or self.runtime_states.context_changed
+        )
         for i, input_name in enumerate(self.input_names):
             if not contiguous_inputs[i].is_cuda:
                 logger.warning(
@@ -417,9 +422,7 @@ def setup_input_tensors(
                 inputs_cpu = contiguous_inputs[i].cpu().to(torch.int64).numpy().copy()
                 self.context.set_tensor_address(input_name, inputs_cpu.ctypes.data)
             else:
-                if (
-                    need_cudagraphs_record or self.output_tensors is None
-                ):  # First time execution:
+                if shape_changed or self.output_tensors is None:
                     self.context.set_input_shape(
                         input_name, tuple(contiguous_inputs[i].shape)
                     )
@@ -490,9 +493,7 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
             ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}."

             self.setup_input_tensors(
-                contiguous_inputs,
-                self.cudagraphs_enabled,
-                need_cudagraphs_record,
+                contiguous_inputs, self.cudagraphs_enabled, shape_changed
             )

             if shape_changed:
@@ -807,3 +808,6 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
                 return True

         return False
+
+    def are_output_tensors_unowned(self) -> bool:
+        return self.output_tensors_are_unowned
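
The Python runtime gets the same refactor: setup_input_tensors now takes shape_changed (defaulting to True) and computes need_cudagraphs_record internally, and a small accessor exposes the unowned-output flag. A brief usage sketch, assuming trt_module is an already constructed PythonTorchTensorRTModule and that set_output_tensors_as_unowned is available on it, as the compiler-side call above suggests:

# trt_module is assumed to be a PythonTorchTensorRTModule built by the compiler.
trt_module.setup_engine()

# Mark the module's outputs as unowned, then confirm via the new accessor.
trt_module.set_output_tensors_as_unowned(True)
assert trt_module.are_output_tensors_unowned()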

setup.py

Lines changed: 4 additions & 4 deletions
@@ -195,10 +195,10 @@ def build_libtorchtrt_cxx11_abi(
     else:
         cmd.append("//:libtorchtrt")

-    # if develop:
-    #     cmd.append("--compilation_mode=dbg")
-    # else:
-    cmd.append("--compilation_mode=opt")
+    if develop:
+        cmd.append("--compilation_mode=dbg")
+    else:
+        cmd.append("--compilation_mode=opt")
     if use_dist_dir:
         if IS_AARCH64:
             cmd.append("--distdir=third_party/dist_dir/aarch64-linux-gnu")
