Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/aoti/common_shims.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ AOTI_SHIM_EXPORT int32_t aoti_torch_device_type_cpu();
AOTI_SHIM_EXPORT int32_t aoti_torch_layout_strided();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_float32();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bfloat16();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bool();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int8();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int16();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int32();
Expand Down
121 changes: 121 additions & 0 deletions backends/cuda/runtime/shims/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ namespace executorch::backends::cuda {

using executorch::aten::SizesType;
using executorch::aten::StridesType;
using executorch::backends::aoti::aoti_torch_dtype_bool;
using executorch::backends::aoti::aoti_torch_get_device_index;
using executorch::backends::aoti::aoti_torch_get_dtype;
using executorch::backends::aoti::aoti_torch_get_sizes;
Expand Down Expand Up @@ -797,6 +798,126 @@ AOTITorchError aoti_torch_new_tensor_handle(

return Error::Ok;
}

AOTITorchError aoti_torch_item_bool(Tensor* tensor, bool* ret_value) {
  // Extracts the scalar boolean from a single-element bool tensor.
  // Works for both host- and device-resident data; device-resident values
  // are copied back to the host with cudaMemcpy.
  //
  // @param tensor    Single-element (e.g. 0-D) bool tensor; must not be null.
  // @param ret_value Receives the extracted value; must not be null.
  // @return Error::Ok on success; Error::InvalidArgument on null pointers,
  //         non-bool dtype, multi-element tensor, or null data pointer.

  // Validate input parameters
  ET_CHECK_OR_RETURN_ERROR(
      tensor != nullptr,
      InvalidArgument,
      "aoti_torch_item_bool failed: tensor is null");

  ET_CHECK_OR_RETURN_ERROR(
      ret_value != nullptr,
      InvalidArgument,
      "aoti_torch_item_bool failed: ret_value is null");

  // Validate that tensor dtype is bool
  int32_t dtype;
  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(tensor, &dtype));

  ET_CHECK_OR_RETURN_ERROR(
      dtype == aoti_torch_dtype_bool(),
      InvalidArgument,
      "aoti_torch_item_bool failed: tensor dtype is not bool (got %d)",
      dtype);

  // item() semantics only make sense for exactly one element; reject larger
  // tensors instead of silently reading element 0.
  ET_CHECK_OR_RETURN_ERROR(
      tensor->numel() == 1,
      InvalidArgument,
      "aoti_torch_item_bool failed: tensor must contain exactly one element");

  // Get the data pointer
  const void* data_ptr = tensor->const_data_ptr();
  ET_CHECK_OR_RETURN_ERROR(
      data_ptr != nullptr,
      InvalidArgument,
      "aoti_torch_item_bool failed: tensor data pointer is null");

  // Ask the CUDA runtime where this pointer lives so we know whether a
  // device-to-host copy is required.
  cudaPointerAttributes attributes{};
  ET_CUDA_CHECK_OR_RETURN_ERROR(
      cudaPointerGetAttributes(&attributes, data_ptr));

  if (attributes.type == cudaMemoryTypeDevice) {
    // CUDA memory case: copy the single byte-sized value to the host
    bool device_value;
    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
        &device_value, data_ptr, sizeof(bool), cudaMemcpyDeviceToHost));
    *ret_value = device_value;
  } else {
    // Host-accessible memory (plain host, registered, or managed):
    // direct access
    const bool* bool_ptr = static_cast<const bool*>(data_ptr);
    *ret_value = *bool_ptr;
  }

  return Error::Ok;
}

AOTITorchError aoti_torch_assign_tensors_out(Tensor* src, Tensor** ret_dst) {
  // Creates a new tensor view over src's memory, mirroring its shape,
  // strides, and dtype, and registers the view with the reference-counting
  // bookkeeping. The view is returned through ret_dst.
  //
  // @param src     Source tensor whose memory and metadata are shared;
  //                must not be null and its memory must already be tracked.
  // @param ret_dst Receives a pointer to the new tensor view; must not be
  //                null.
  // @return Error::Ok on success; Error::InvalidArgument on null pointers,
  //         untracked memory, or failure to create the view.

  // Validate input parameters
  ET_CHECK_OR_RETURN_ERROR(
      src != nullptr,
      InvalidArgument,
      "aoti_torch_assign_tensors_out failed: src is null");

  ET_CHECK_OR_RETURN_ERROR(
      ret_dst != nullptr,
      InvalidArgument,
      "aoti_torch_assign_tensors_out failed: ret_dst is null");

  // Get the data pointer from the source tensor
  void* data_ptr = src->mutable_data_ptr();
  ET_CHECK_OR_RETURN_ERROR(
      data_ptr != nullptr,
      InvalidArgument,
      "Source tensor has null data pointer");

  // Check if the given memory is in the map, if not return error
  auto memory_it = memory_to_n_tensor.find(data_ptr);
  ET_CHECK_OR_RETURN_ERROR(
      memory_it != memory_to_n_tensor.end(),
      InvalidArgument,
      "Memory address %p is not being tracked by reference counting system",
      data_ptr);

  // Get dtype from source tensor
  int32_t dtype = 0;
  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(src, &dtype));

  // Get sizes and strides from source tensor
  int64_t* sizes_ptr;
  int64_t* strides_ptr;
  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_sizes(src, &sizes_ptr));
  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_strides(src, &strides_ptr));

  int64_t ndim = src->dim();

  // Convert to vectors
  std::vector<SizesType> sizes = convert_sizes_to_vector(ndim, sizes_ptr);
  std::vector<StridesType> strides =
      convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);

  // Create new tensor view that shares the same memory as source tensor
  std::shared_ptr<Tensor> tensor = make_tensor(
      sizes,
      data_ptr, // Share the same memory from source tensor
      {}, // dim_order (empty, will be auto-generated)
      strides,
      dtype_to_scalar_type(dtype));

  ET_CHECK_OR_RETURN_ERROR(
      tensor != nullptr,
      InvalidArgument,
      "Failed to create tensor view in aoti_torch_assign_tensors_out");

  // Store the tensor so it doesn't get destroyed
  tensors.insert(tensor);

  *ret_dst = tensor.get();

  // Increment the reference count for this memory address only if the
  // memory is owned by a tensor; NOT_OWN (externally owned) stays as-is.
  // Reuse the iterator from the lookup above — the map has not been
  // modified since, so it is still valid, and this avoids re-hashing
  // the key three times.
  if (memory_it->second != NOT_OWN) {
    ++memory_it->second;
  }

  return Error::Ok;
}
} // extern "C"

} // namespace executorch::backends::cuda
38 changes: 35 additions & 3 deletions backends/cuda/runtime/shims/memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,41 @@ aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking);
* @return Error::Ok on success, appropriate error code on failure:
* - Error::InvalidArgument: null pointers or invalid parameters
*/
AOTITorchError aoti_torch_new_tensor_handle(
Tensor* orig_handle,
Tensor** new_handle);
AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_new_tensor_handle(Tensor* orig_handle, Tensor** new_handle);

/**
* Retrieves a boolean value from a 0D boolean tensor.
*
* This function extracts the scalar boolean value from a tensor that contains
* a single boolean element. The tensor can be on either CPU or CUDA device.
* For CUDA tensors, the value is copied from device to host memory.
*
* @param tensor Pointer to a 0D boolean tensor (must not be null)
* @param ret_value Output pointer to store the boolean value (must not be null)
*
* @return Error::Ok on success, appropriate error code on failure:
* - Error::InvalidArgument: null pointers or tensor dtype is not bool
*/
AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_item_bool(Tensor* tensor, bool* ret_value);

/**
* Creates a new tensor that shares the same underlying data as the source
* tensor.
*
* This function creates a new tensor view with the same shape, strides, and
* dtype as the source tensor, sharing the same underlying memory. The new
* tensor handle will be stored in ret_dst.
*
* @param src The source tensor providing the data and metadata.
* @param ret_dst On output, this will point to the new tensor view.
*
* @return Error::Ok on success, appropriate error code on failure:
* - Error::InvalidArgument: null pointers or memory not tracked
*/
AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_assign_tensors_out(Tensor* src, Tensor** ret_dst);

// Function to clear all tensors from internal storage
AOTI_SHIM_EXPORT void clear_all_tensors();
Expand Down
2 changes: 2 additions & 0 deletions backends/cuda/runtime/shims/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ set(CUDA_SHIM_TESTS
test_aoti_torch__reinterpret_tensor
test_aoti_torch_copy_
test_aoti_torch_new_tensor_handle
test_aoti_torch_item_bool
test_aoti_torch_assign_tensors_out
)

enable_testing()
Expand Down
2 changes: 2 additions & 0 deletions backends/cuda/runtime/shims/tests/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,5 @@ def define_common_targets():
cuda_shim_cpp_unittest("aoti_torch_cuda_guard")
cuda_shim_cpp_unittest("aoti_torch_cuda__weight_int4pack_mm")
cuda_shim_cpp_unittest("aoti_torch_new_tensor_handle")
cuda_shim_cpp_unittest("aoti_torch_item_bool")
cuda_shim_cpp_unittest("aoti_torch_assign_tensors_out")
Loading
Loading