Commit 3b45992

Add aoti_torch_item_bool and aoti_torch_assign_tensors_out shims
Add two new shim implementations for the CUDA AOTI backend:

1. aoti_torch_item_bool: Extracts a boolean value from a 0D boolean tensor. Handles both CPU and CUDA tensors by using cudaPointerGetAttributes to determine the memory location and copying from device if needed.
2. aoti_torch_assign_tensors_out: Creates a new tensor view that shares the same underlying data as the source tensor. The new tensor has the same shape, strides, and dtype as the source.

Also adds:
- Declaration of aoti_torch_dtype_bool() in common_shims.h
- Unit tests for both new functions
- Update of CMakeLists.txt with the new test targets
- Update of targets.bzl with the new test targets

ghstack-source-id: de89b09
ghstack-comment-id: 3676249127
Pull-Request: #16345
1 parent ec3b3f4 commit 3b45992
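
The commit message notes that aoti_torch_item_bool handles both CPU and CUDA tensors by checking the pointer with cudaPointerGetAttributes before reading the scalar. Below is a minimal standalone sketch of that dispatch pattern (not part of this commit; it uses only the CUDA runtime API, and the read_bool_anywhere helper name is made up for illustration):

#include <cstdio>
#include <cuda_runtime.h>

// Read a single bool through a pointer that may refer to host or device
// memory. Error checking is omitted for brevity; a zero-initialized
// cudaPointerAttributes reports non-device memory for plain host pointers.
static bool read_bool_anywhere(const void* data_ptr) {
  cudaPointerAttributes attributes{};
  cudaPointerGetAttributes(&attributes, data_ptr);
  if (attributes.type == cudaMemoryTypeDevice) {
    bool value = false;
    cudaMemcpy(&value, data_ptr, sizeof(bool), cudaMemcpyDeviceToHost);
    return value;
  }
  // Host (or unregistered) memory: dereference directly.
  return *static_cast<const bool*>(data_ptr);
}

int main() {
  bool host_flag = true;

  bool* device_flag = nullptr;
  cudaMalloc(&device_flag, sizeof(bool));
  cudaMemcpy(device_flag, &host_flag, sizeof(bool), cudaMemcpyHostToDevice);

  printf("host: %d device: %d\n",
         static_cast<int>(read_bool_anywhere(&host_flag)),
         static_cast<int>(read_bool_anywhere(device_flag)));

  cudaFree(device_flag);
  return 0;
}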

7 files changed: +602 −3 lines changed

backends/aoti/common_shims.h

Lines changed: 1 addition & 0 deletions
@@ -64,6 +64,7 @@ AOTI_SHIM_EXPORT int32_t aoti_torch_device_type_cpu();
 AOTI_SHIM_EXPORT int32_t aoti_torch_layout_strided();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_float32();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bfloat16();
+AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bool();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int8();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int16();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int32();

backends/cuda/runtime/shims/memory.cpp

Lines changed: 121 additions & 0 deletions
@@ -24,6 +24,7 @@ namespace executorch::backends::cuda {
 
 using executorch::aten::SizesType;
 using executorch::aten::StridesType;
+using executorch::backends::aoti::aoti_torch_dtype_bool;
 using executorch::backends::aoti::aoti_torch_get_device_index;
 using executorch::backends::aoti::aoti_torch_get_dtype;
 using executorch::backends::aoti::aoti_torch_get_sizes;
@@ -797,6 +798,126 @@ AOTITorchError aoti_torch_new_tensor_handle(
 
   return Error::Ok;
 }
+
+AOTITorchError aoti_torch_item_bool(Tensor* tensor, bool* ret_value) {
+  // Validate input parameters
+  ET_CHECK_OR_RETURN_ERROR(
+      tensor != nullptr,
+      InvalidArgument,
+      "aoti_torch_item_bool failed: tensor is null");
+
+  ET_CHECK_OR_RETURN_ERROR(
+      ret_value != nullptr,
+      InvalidArgument,
+      "aoti_torch_item_bool failed: ret_value is null");
+
+  // Validate that tensor dtype is bool
+  int32_t dtype;
+  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(tensor, &dtype));
+
+  ET_CHECK_OR_RETURN_ERROR(
+      dtype == aoti_torch_dtype_bool(),
+      InvalidArgument,
+      "aoti_torch_item_bool failed: tensor dtype is not bool (got %d)",
+      dtype);
+
+  // Get the data pointer
+  const void* data_ptr = tensor->const_data_ptr();
+  ET_CHECK_OR_RETURN_ERROR(
+      data_ptr != nullptr,
+      InvalidArgument,
+      "aoti_torch_item_bool failed: tensor data pointer is null");
+
+  // Check if tensor is on CUDA or CPU
+  cudaPointerAttributes attributes{};
+  ET_CUDA_CHECK_OR_RETURN_ERROR(
+      cudaPointerGetAttributes(&attributes, data_ptr));
+
+  if (attributes.type == cudaMemoryTypeDevice) {
+    // CUDA memory case: copy from device to host
+    bool device_value;
+    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+        &device_value, data_ptr, sizeof(bool), cudaMemcpyDeviceToHost));
+    *ret_value = device_value;
+  } else {
+    // CPU memory case: direct access
+    const bool* bool_ptr = static_cast<const bool*>(data_ptr);
+    *ret_value = *bool_ptr;
+  }
+
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_assign_tensors_out(Tensor* src, Tensor** ret_dst) {
+  // Validate input parameters
+  ET_CHECK_OR_RETURN_ERROR(
+      src != nullptr,
+      InvalidArgument,
+      "aoti_torch_assign_tensors_out failed: src is null");
+
+  ET_CHECK_OR_RETURN_ERROR(
+      ret_dst != nullptr,
+      InvalidArgument,
+      "aoti_torch_assign_tensors_out failed: ret_dst is null");
+
+  // Get the data pointer from the source tensor
+  void* data_ptr = src->mutable_data_ptr();
+  ET_CHECK_OR_RETURN_ERROR(
+      data_ptr != nullptr,
+      InvalidArgument,
+      "Source tensor has null data pointer");
+
+  // Check if the given memory is in the map, if not return error
+  auto memory_it = memory_to_n_tensor.find(data_ptr);
+  ET_CHECK_OR_RETURN_ERROR(
+      memory_it != memory_to_n_tensor.end(),
+      InvalidArgument,
+      "Memory address %p is not being tracked by reference counting system",
+      data_ptr);
+
+  // Get dtype from source tensor
+  int32_t dtype = 0;
+  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(src, &dtype));
+
+  // Get sizes and strides from source tensor
+  int64_t* sizes_ptr;
+  int64_t* strides_ptr;
+  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_sizes(src, &sizes_ptr));
+  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_strides(src, &strides_ptr));
+
+  int64_t ndim = src->dim();
+
+  // Convert to vectors
+  std::vector<SizesType> sizes = convert_sizes_to_vector(ndim, sizes_ptr);
+  std::vector<StridesType> strides =
+      convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);
+
+  // Create new tensor view that shares the same memory as source tensor
+  std::shared_ptr<Tensor> tensor = make_tensor(
+      sizes,
+      data_ptr, // Share the same memory from source tensor
+      {}, // dim_order (empty, will be auto-generated)
+      strides,
+      dtype_to_scalar_type(dtype));
+
+  ET_CHECK_OR_RETURN_ERROR(
+      tensor != nullptr,
+      InvalidArgument,
+      "Failed to create tensor view in aoti_torch_assign_tensors_out");
+
+  // Store the tensor so it doesn't get destroyed
+  tensors.insert(tensor);
+
+  *ret_dst = tensor.get();
+
+  // Increment the reference count for this memory address only if it is owned
+  // by tensor
+  memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN
+      ? NOT_OWN
+      : memory_to_n_tensor[data_ptr] + 1;
+
+  return Error::Ok;
+}
 } // extern "C"
 
 } // namespace executorch::backends::cuda
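
One detail worth calling out in aoti_torch_assign_tensors_out above is the final memory_to_n_tensor update: an entry marked NOT_OWN stays NOT_OWN (the new view never takes ownership of borrowed memory), while an owned entry gains one more reference. The following small standalone sketch captures just that bookkeeping, with a plain std::unordered_map and an assumed NOT_OWN sentinel of 0 standing in for the backend's real table:

#include <cstdint>
#include <cstdio>
#include <unordered_map>

// Hypothetical sentinel: the memory is referenced but not owned by any tensor.
constexpr uint32_t NOT_OWN = 0;

std::unordered_map<void*, uint32_t> memory_to_n_tensor;

// Register one more tensor view over data_ptr. Borrowed memory stays
// borrowed; owned memory gets its reference count bumped. (The real shim
// first checks that data_ptr is already tracked and errors out otherwise.)
void add_view(void* data_ptr) {
  memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN
      ? NOT_OWN
      : memory_to_n_tensor[data_ptr] + 1;
}

int main() {
  int owned_buffer = 0;
  int borrowed_buffer = 0;
  memory_to_n_tensor[&owned_buffer] = 1;          // one tensor owns this memory
  memory_to_n_tensor[&borrowed_buffer] = NOT_OWN; // memory owned elsewhere

  add_view(&owned_buffer);
  add_view(&borrowed_buffer);

  // Prints "owned: 2 borrowed: 0".
  printf("owned: %u borrowed: %u\n",
         memory_to_n_tensor[&owned_buffer],
         memory_to_n_tensor[&borrowed_buffer]);
  return 0;
}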

backends/cuda/runtime/shims/memory.h

Lines changed: 35 additions & 3 deletions
@@ -161,9 +161,41 @@ aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking);
  * @return Error::Ok on success, appropriate error code on failure:
  *         - Error::InvalidArgument: null pointers or invalid parameters
  */
-AOTITorchError aoti_torch_new_tensor_handle(
-    Tensor* orig_handle,
-    Tensor** new_handle);
+AOTI_SHIM_EXPORT AOTITorchError
+aoti_torch_new_tensor_handle(Tensor* orig_handle, Tensor** new_handle);
+
+/**
+ * Retrieves a boolean value from a 0D boolean tensor.
+ *
+ * This function extracts the scalar boolean value from a tensor that contains
+ * a single boolean element. The tensor can be on either CPU or CUDA device.
+ * For CUDA tensors, the value is copied from device to host memory.
+ *
+ * @param tensor Pointer to a 0D boolean tensor (must not be null)
+ * @param ret_value Output pointer to store the boolean value (must not be null)
+ *
+ * @return Error::Ok on success, appropriate error code on failure:
+ *         - Error::InvalidArgument: null pointers or tensor dtype is not bool
+ */
+AOTI_SHIM_EXPORT AOTITorchError
+aoti_torch_item_bool(Tensor* tensor, bool* ret_value);
+
+/**
+ * Creates a new tensor that shares the same underlying data as the source
+ * tensor.
+ *
+ * This function creates a new tensor view with the same shape, strides, and
+ * dtype as the source tensor, sharing the same underlying memory. The new
+ * tensor handle will be stored in ret_dst.
+ *
+ * @param src The source tensor providing the data and metadata.
+ * @param ret_dst On output, this will point to the new tensor view.
+ *
+ * @return Error::Ok on success, appropriate error code on failure:
+ *         - Error::InvalidArgument: null pointers or memory not tracked
+ */
+AOTI_SHIM_EXPORT AOTITorchError
+aoti_torch_assign_tensors_out(Tensor* src, Tensor** ret_dst);
 
 // Function to clear all tensors from internal storage
 AOTI_SHIM_EXPORT void clear_all_tensors();
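
For orientation, here is a hypothetical caller-side fragment (not from this commit) showing how code targeting these declarations might chain the two new shims. The include path and the read_flag_and_alias helper are assumptions for illustration, and the caller is assumed to already hold a 0D bool Tensor* named pred:

// Hypothetical usage sketch; assumes the header path below resolves inside the
// ExecuTorch source tree and that Tensor, Error, and AOTITorchError come in
// through its includes.
#include <executorch/backends/cuda/runtime/shims/memory.h>

namespace executorch::backends::cuda {

// Read a scalar flag from a 0D bool tensor, then create a second handle that
// aliases the same storage, shape, and strides.
AOTITorchError read_flag_and_alias(
    Tensor* pred,
    bool* flag_out,
    Tensor** alias_out) {
  AOTITorchError err = aoti_torch_item_bool(pred, flag_out);
  if (err != Error::Ok) {
    return err;
  }
  return aoti_torch_assign_tensors_out(pred, alias_out);
}

} // namespace executorch::backends::cuda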

backends/cuda/runtime/shims/tests/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -30,6 +30,8 @@ set(CUDA_SHIM_TESTS
   test_aoti_torch__reinterpret_tensor
   test_aoti_torch_copy_
   test_aoti_torch_new_tensor_handle
+  test_aoti_torch_item_bool
+  test_aoti_torch_assign_tensors_out
 )
 
 enable_testing()

backends/cuda/runtime/shims/tests/targets.bzl

Lines changed: 2 additions & 0 deletions
@@ -35,3 +35,5 @@ def define_common_targets():
     cuda_shim_cpp_unittest("aoti_torch_cuda_guard")
     cuda_shim_cpp_unittest("aoti_torch_cuda__weight_int4pack_mm")
     cuda_shim_cpp_unittest("aoti_torch_new_tensor_handle")
+    cuda_shim_cpp_unittest("aoti_torch_item_bool")
+    cuda_shim_cpp_unittest("aoti_torch_assign_tensors_out")
