diff --git a/backends/aoti/common_shims.h b/backends/aoti/common_shims.h
index 675a9864e74..3fc414fb669 100644
--- a/backends/aoti/common_shims.h
+++ b/backends/aoti/common_shims.h
@@ -64,6 +64,7 @@ AOTI_SHIM_EXPORT int32_t aoti_torch_device_type_cpu();
 AOTI_SHIM_EXPORT int32_t aoti_torch_layout_strided();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_float32();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bfloat16();
+AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bool();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int8();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int16();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int32();
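The declaration above needs a matching definition in backends/aoti/common_shims.cpp, which this diff does not include. A minimal sketch of what that definition presumably looks like, mirroring the neighboring dtype shims and assuming they return PyTorch's numeric ScalarType codes:

    int32_t aoti_torch_dtype_bool() {
      // Assumption: like the other dtype shims, this returns the raw
      // c10::ScalarType code so dtypes can be compared across the C ABI;
      // Bool is code 11 in PyTorch's ScalarType enumeration.
      return 11;
    }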
counting system", + data_ptr); + + // Get dtype from source tensor + int32_t dtype = 0; + ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(src, &dtype)); + + // Get sizes and strides from source tensor + int64_t* sizes_ptr; + int64_t* strides_ptr; + ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_sizes(src, &sizes_ptr)); + ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_strides(src, &strides_ptr)); + + int64_t ndim = src->dim(); + + // Convert to vectors + std::vector sizes = convert_sizes_to_vector(ndim, sizes_ptr); + std::vector strides = + convert_strides_to_vector(ndim, sizes_ptr, strides_ptr); + + // Create new tensor view that shares the same memory as source tensor + std::shared_ptr tensor = make_tensor( + sizes, + data_ptr, // Share the same memory from source tensor + {}, // dim_order (empty, will be auto-generated) + strides, + dtype_to_scalar_type(dtype)); + + ET_CHECK_OR_RETURN_ERROR( + tensor != nullptr, + InvalidArgument, + "Failed to create tensor view in aoti_torch_assign_tensors_out"); + + // Store the tensor so it doesn't get destroyed + tensors.insert(tensor); + + *ret_dst = tensor.get(); + + // Increment the reference count for this memory address only if it is owned + // by tensor + memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN + ? NOT_OWN + : memory_to_n_tensor[data_ptr] + 1; + + return Error::Ok; +} } // extern "C" } // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/memory.h b/backends/cuda/runtime/shims/memory.h index 935df853748..34b781a5270 100644 --- a/backends/cuda/runtime/shims/memory.h +++ b/backends/cuda/runtime/shims/memory.h @@ -161,9 +161,41 @@ aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking); * @return Error::Ok on success, appropriate error code on failure: * - Error::InvalidArgument: null pointers or invalid parameters */ -AOTITorchError aoti_torch_new_tensor_handle( - Tensor* orig_handle, - Tensor** new_handle); +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_new_tensor_handle(Tensor* orig_handle, Tensor** new_handle); + +/** + * Retrieves a boolean value from a 0D boolean tensor. + * + * This function extracts the scalar boolean value from a tensor that contains + * a single boolean element. The tensor can be on either CPU or CUDA device. + * For CUDA tensors, the value is copied from device to host memory. + * + * @param tensor Pointer to a 0D boolean tensor (must not be null) + * @param ret_value Output pointer to store the boolean value (must not be null) + * + * @return Error::Ok on success, appropriate error code on failure: + * - Error::InvalidArgument: null pointers or tensor dtype is not bool + */ +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_item_bool(Tensor* tensor, bool* ret_value); + +/** + * Creates a new tensor that shares the same underlying data as the source + * tensor. + * + * This function creates a new tensor view with the same shape, strides, and + * dtype as the source tensor, sharing the same underlying memory. The new + * tensor handle will be stored in ret_dst. + * + * @param src The source tensor providing the data and metadata. + * @param ret_dst On output, this will point to the new tensor view. 
diff --git a/backends/cuda/runtime/shims/memory.h b/backends/cuda/runtime/shims/memory.h
index 935df853748..34b781a5270 100644
--- a/backends/cuda/runtime/shims/memory.h
+++ b/backends/cuda/runtime/shims/memory.h
@@ -161,9 +161,41 @@ aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking);
  * @return Error::Ok on success, appropriate error code on failure:
  *         - Error::InvalidArgument: null pointers or invalid parameters
  */
-AOTITorchError aoti_torch_new_tensor_handle(
-    Tensor* orig_handle,
-    Tensor** new_handle);
+AOTI_SHIM_EXPORT AOTITorchError
+aoti_torch_new_tensor_handle(Tensor* orig_handle, Tensor** new_handle);
+
+/**
+ * Retrieves a boolean value from a 0D boolean tensor.
+ *
+ * This function extracts the scalar boolean value from a tensor that contains
+ * a single boolean element. The tensor can be on either the CPU or a CUDA
+ * device. For CUDA tensors, the value is copied from device to host memory.
+ *
+ * @param tensor Pointer to a 0D boolean tensor (must not be null)
+ * @param ret_value Output pointer to store the boolean value (must not be
+ *                  null)
+ *
+ * @return Error::Ok on success, appropriate error code on failure:
+ *         - Error::InvalidArgument: null pointers or tensor dtype is not bool
+ */
+AOTI_SHIM_EXPORT AOTITorchError
+aoti_torch_item_bool(Tensor* tensor, bool* ret_value);
+
+/**
+ * Creates a new tensor that shares the same underlying data as the source
+ * tensor.
+ *
+ * This function creates a new tensor view with the same shape, strides, and
+ * dtype as the source tensor, sharing the same underlying memory. The new
+ * tensor handle will be stored in ret_dst.
+ *
+ * @param src The source tensor providing the data and metadata.
+ * @param ret_dst On output, this will point to the new tensor view.
+ *
+ * @return Error::Ok on success, appropriate error code on failure:
+ *         - Error::InvalidArgument: null pointers or memory not tracked
+ */
+AOTI_SHIM_EXPORT AOTITorchError
+aoti_torch_assign_tensors_out(Tensor* src, Tensor** ret_dst);
 
 // Function to clear all tensors from internal storage
 AOTI_SHIM_EXPORT void clear_all_tensors();
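Likewise, a hypothetical usage sketch for aoti_torch_assign_tensors_out: `result` is assumed to be a tensor whose storage was allocated through these shims, since the implementation rejects memory that the memory_to_n_tensor map does not track. The returned handle aliases the source's storage instead of copying it, which is how generated code can bind an internally computed tensor to a caller-visible output slot:

    // Hypothetical caller; `result` must be a shim-tracked tensor.
    AOTITorchError bind_output(Tensor* result, Tensor** out_slot) {
      Tensor* out = nullptr;
      AOTITorchError err = aoti_torch_assign_tensors_out(result, &out);
      if (err != Error::Ok) {
        return err; // e.g. storage not tracked by memory_to_n_tensor
      }
      // `out` aliases `result`: same data pointer, sizes, and strides.
      *out_slot = out;
      return Error::Ok;
    }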
diff --git a/backends/cuda/runtime/shims/tests/CMakeLists.txt b/backends/cuda/runtime/shims/tests/CMakeLists.txt
index a7df6075c37..204c08688c4 100644
--- a/backends/cuda/runtime/shims/tests/CMakeLists.txt
+++ b/backends/cuda/runtime/shims/tests/CMakeLists.txt
@@ -37,9 +37,14 @@ find_package(executorch CONFIG REQUIRED HINTS ${CMAKE_INSTALL_PREFIX})
 
 # List of test files
 set(CUDA_SHIM_TESTS
-    test_aoti_torch_create_tensor_from_blob_v2 test_aoti_torch_empty_strided
-    test_aoti_torch_delete_tensor_object test_aoti_torch__reinterpret_tensor
-    test_aoti_torch_copy_ test_aoti_torch_new_tensor_handle
+    test_aoti_torch_create_tensor_from_blob_v2
+    test_aoti_torch_empty_strided
+    test_aoti_torch_delete_tensor_object
+    test_aoti_torch__reinterpret_tensor
+    test_aoti_torch_copy_
+    test_aoti_torch_new_tensor_handle
+    test_aoti_torch_item_bool
+    test_aoti_torch_assign_tensors_out
 )
 
 enable_testing()
diff --git a/backends/cuda/runtime/shims/tests/targets.bzl b/backends/cuda/runtime/shims/tests/targets.bzl
index b274ecf3675..7736624c02a 100644
--- a/backends/cuda/runtime/shims/tests/targets.bzl
+++ b/backends/cuda/runtime/shims/tests/targets.bzl
@@ -35,3 +35,5 @@ def define_common_targets():
     cuda_shim_cpp_unittest("aoti_torch_cuda_guard")
     cuda_shim_cpp_unittest("aoti_torch_cuda__weight_int4pack_mm")
     cuda_shim_cpp_unittest("aoti_torch_new_tensor_handle")
+    cuda_shim_cpp_unittest("aoti_torch_item_bool")
+    cuda_shim_cpp_unittest("aoti_torch_assign_tensors_out")
diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_assign_tensors_out.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_assign_tensors_out.cpp
new file mode 100644
index 00000000000..d5e1bcb2547
--- /dev/null
+++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_assign_tensors_out.cpp
@@ -0,0 +1,245 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <cuda_runtime.h>
+#include <gtest/gtest.h>
+#include <cstdint>
+#include <vector>
+#include <executorch/backends/aoti/common_shims.h>
+#include <executorch/backends/cuda/runtime/shims/memory.h>
+#include <executorch/backends/cuda/runtime/shims/utils.h>
+#include <executorch/runtime/core/error.h>
+#include <executorch/runtime/platform/platform.h>
+
+using namespace executorch::backends::aoti;
+using namespace executorch::backends::cuda;
+using namespace executorch::runtime;
+using executorch::runtime::etensor::Tensor;
+
+// Test fixture for aoti_torch_assign_tensors_out tests
+class AOTITorchAssignTensorsOutTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    // Initialize ExecuTorch Platform Abstraction Layer
+    et_pal_init();
+
+    // Check if CUDA is available
+    int device_count = 0;
+    cudaError_t err = cudaGetDeviceCount(&device_count);
+    if (err != cudaSuccess || device_count == 0) {
+      GTEST_SKIP() << "CUDA not available, skipping CUDA tests";
+    }
+
+    // Clean up any existing cached metadata before each test
+    cleanup_tensor_metadata();
+
+    // Clear any remaining tensors from previous tests
+    clear_all_tensors();
+  }
+
+  void TearDown() override {
+    // Clean up metadata
+    cleanup_tensor_metadata();
+
+    // Clear the global tensor storage using the provided function
+    clear_all_tensors();
+  }
+
+  // Helper to create a test tensor
+  Tensor* create_test_tensor(
+      const std::vector<int64_t>& sizes,
+      int32_t dtype = static_cast<int32_t>(SupportedDTypes::FLOAT32),
+      int32_t device_type = static_cast<int32_t>(SupportedDevices::CUDA)) {
+    std::vector<int64_t> strides;
+    // Calculate contiguous strides
+    if (!sizes.empty()) {
+      strides.resize(sizes.size());
+      strides[sizes.size() - 1] = 1;
+      for (int64_t i = static_cast<int64_t>(sizes.size()) - 2; i >= 0; i--) {
+        strides[i] = strides[i + 1] * sizes[i + 1];
+      }
+    }
+
+    Tensor* tensor;
+    const int64_t* strides_ptr = strides.empty() ? nullptr : strides.data();
+
+    AOTITorchError error = aoti_torch_empty_strided(
+        sizes.size(),
+        sizes.data(),
+        strides_ptr,
+        dtype,
+        device_type,
+        0,
+        &tensor);
+
+    return (error == Error::Ok) ? tensor : nullptr;
+  }
+};
+
+// Test basic functionality
+TEST_F(AOTITorchAssignTensorsOutTest, BasicFunctionality) {
+  // Create a source tensor
+  std::vector<int64_t> sizes = {2, 3};
+  Tensor* src = create_test_tensor(sizes);
+  ASSERT_NE(src, nullptr);
+
+  // Create output tensor handle
+  Tensor* dst = nullptr;
+  AOTITorchError error = aoti_torch_assign_tensors_out(src, &dst);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(dst, nullptr);
+
+  // Verify the output tensor has the same properties as the source
+  EXPECT_EQ(dst->dim(), src->dim());
+  EXPECT_EQ(dst->size(0), src->size(0));
+  EXPECT_EQ(dst->size(1), src->size(1));
+  EXPECT_EQ(dst->numel(), src->numel());
+
+  // Verify they share the same memory
+  EXPECT_EQ(dst->mutable_data_ptr(), src->mutable_data_ptr());
+}
+
+// Test with 1D tensor
+TEST_F(AOTITorchAssignTensorsOutTest, OneDimensionalTensor) {
+  std::vector<int64_t> sizes = {10};
+  Tensor* src = create_test_tensor(sizes);
+  ASSERT_NE(src, nullptr);
+
+  Tensor* dst = nullptr;
+  AOTITorchError error = aoti_torch_assign_tensors_out(src, &dst);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(dst, nullptr);
+  EXPECT_EQ(dst->dim(), 1);
+  EXPECT_EQ(dst->size(0), 10);
+  EXPECT_EQ(dst->mutable_data_ptr(), src->mutable_data_ptr());
+}
+
+// Test with 3D tensor
+TEST_F(AOTITorchAssignTensorsOutTest, ThreeDimensionalTensor) {
+  std::vector<int64_t> sizes = {2, 3, 4};
+  Tensor* src = create_test_tensor(sizes);
+  ASSERT_NE(src, nullptr);
+
+  Tensor* dst = nullptr;
+  AOTITorchError error = aoti_torch_assign_tensors_out(src, &dst);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(dst, nullptr);
+  EXPECT_EQ(dst->dim(), 3);
+  EXPECT_EQ(dst->size(0), 2);
+  EXPECT_EQ(dst->size(1), 3);
+  EXPECT_EQ(dst->size(2), 4);
+  EXPECT_EQ(dst->mutable_data_ptr(), src->mutable_data_ptr());
+}
+
+// Test with scalar (0D) tensor
+TEST_F(AOTITorchAssignTensorsOutTest, ScalarTensor) {
+  std::vector<int64_t> sizes = {};
+  Tensor* src = create_test_tensor(sizes);
+  ASSERT_NE(src, nullptr);
+
+  Tensor* dst = nullptr;
+  AOTITorchError error = aoti_torch_assign_tensors_out(src, &dst);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(dst, nullptr);
+  EXPECT_EQ(dst->dim(), 0);
+  EXPECT_EQ(dst->mutable_data_ptr(), src->mutable_data_ptr());
+}
+
+// Test with null source pointer
+TEST_F(AOTITorchAssignTensorsOutTest, NullSourcePointer) {
+  Tensor* dst = nullptr;
+  AOTITorchError error = aoti_torch_assign_tensors_out(nullptr, &dst);
+  EXPECT_EQ(error, Error::InvalidArgument);
+}
+
+// Test with null destination pointer
+TEST_F(AOTITorchAssignTensorsOutTest, NullDestinationPointer) {
+  std::vector<int64_t> sizes = {2, 3};
+  Tensor* src = create_test_tensor(sizes);
+  ASSERT_NE(src, nullptr);
+
+  AOTITorchError error = aoti_torch_assign_tensors_out(src, nullptr);
+  EXPECT_EQ(error, Error::InvalidArgument);
+}
+
+// Test that strides are preserved
+TEST_F(AOTITorchAssignTensorsOutTest, StridesPreserved) {
+  std::vector<int64_t> sizes = {2, 3};
+  Tensor* src = create_test_tensor(sizes);
+  ASSERT_NE(src, nullptr);
+
+  Tensor* dst = nullptr;
+  AOTITorchError error = aoti_torch_assign_tensors_out(src, &dst);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(dst, nullptr);
+
+  // Get strides from both tensors
+  int64_t* src_strides;
+  int64_t* dst_strides;
+  aoti_torch_get_strides(src, &src_strides);
+  aoti_torch_get_strides(dst, &dst_strides);
+
+  // Verify strides match
+  for (int64_t i = 0; i < src->dim(); i++) {
+    EXPECT_EQ(src_strides[i], dst_strides[i]);
+  }
+}
+
+// Test with CPU tensor
+TEST_F(AOTITorchAssignTensorsOutTest, CPUTensor) {
+  std::vector<int64_t> sizes = {2, 3};
+  Tensor* src = create_test_tensor(
+      sizes,
+      static_cast<int32_t>(SupportedDTypes::FLOAT32),
+      static_cast<int32_t>(SupportedDevices::CPU));
+  ASSERT_NE(src, nullptr);
+
+  Tensor* dst = nullptr;
+  AOTITorchError error = aoti_torch_assign_tensors_out(src, &dst);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(dst, nullptr);
+  EXPECT_EQ(dst->mutable_data_ptr(), src->mutable_data_ptr());
+}
+
+// Test dtype is preserved
+TEST_F(AOTITorchAssignTensorsOutTest, DtypePreserved) {
+  // Test with different dtypes
+  std::vector<int32_t> dtypes = {
+      static_cast<int32_t>(SupportedDTypes::FLOAT32),
+      static_cast<int32_t>(SupportedDTypes::INT32),
+      static_cast<int32_t>(SupportedDTypes::INT64),
+  };
+
+  for (int32_t dtype : dtypes) {
+    cleanup_tensor_metadata();
+    clear_all_tensors();
+
+    std::vector<int64_t> sizes = {2, 3};
+    Tensor* src = create_test_tensor(sizes, dtype);
+    ASSERT_NE(src, nullptr);
+
+    Tensor* dst = nullptr;
+    AOTITorchError error = aoti_torch_assign_tensors_out(src, &dst);
+
+    EXPECT_EQ(error, Error::Ok);
+    EXPECT_NE(dst, nullptr);
+
+    // Verify dtype is preserved
+    int32_t src_dtype, dst_dtype;
+    aoti_torch_get_dtype(src, &src_dtype);
+    aoti_torch_get_dtype(dst, &dst_dtype);
+    EXPECT_EQ(src_dtype, dst_dtype)
+        << "Dtype mismatch for dtype code: " << dtype;
+  }
+}
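As an aside, the create_test_tensor helper above derives row-major (contiguous) strides from the requested sizes. Worked through for a concrete shape, purely as illustration:

    // For sizes {2, 3, 4} the helper's loop yields strides {12, 4, 1}:
    //   strides[2] = 1
    //   strides[1] = strides[2] * sizes[2] = 1 * 4 = 4
    //   strides[0] = strides[1] * sizes[1] = 4 * 3 = 12
    // so element (i, j, k) sits at linear offset i*12 + j*4 + k.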
diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_item_bool.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_item_bool.cpp
new file mode 100644
index 00000000000..8e6bcbbfad6
--- /dev/null
+++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_item_bool.cpp
@@ -0,0 +1,203 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <cuda_runtime.h>
+#include <gtest/gtest.h>
+#include <cstdint>
+#include <vector>
+#include <executorch/backends/aoti/common_shims.h>
+#include <executorch/backends/cuda/runtime/shims/memory.h>
+#include <executorch/backends/cuda/runtime/shims/utils.h>
+#include <executorch/runtime/core/error.h>
+#include <executorch/runtime/platform/platform.h>
+
+using namespace executorch::backends::aoti;
+using namespace executorch::backends::cuda;
+using namespace executorch::runtime;
+using executorch::runtime::etensor::Tensor;
+
+// Test fixture for aoti_torch_item_bool tests
+class AOTITorchItemBoolTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    // Initialize ExecuTorch Platform Abstraction Layer
+    et_pal_init();
+
+    // Check if CUDA is available
+    int device_count = 0;
+    cudaError_t err = cudaGetDeviceCount(&device_count);
+    if (err != cudaSuccess || device_count == 0) {
+      GTEST_SKIP() << "CUDA not available, skipping CUDA tests";
+    }
+
+    // Clean up any existing cached metadata before each test
+    cleanup_tensor_metadata();
+
+    // Clear any remaining tensors from previous tests
+    clear_all_tensors();
+  }
+
+  void TearDown() override {
+    // Clean up metadata
+    cleanup_tensor_metadata();
+
+    // Clear the global tensor storage using the provided function
+    clear_all_tensors();
+  }
+
+  // Helper to create a bool tensor on CUDA with a specific value
+  Tensor* create_cuda_bool_tensor(bool value) {
+    // Create a 0D (scalar) bool tensor
+    std::vector<int64_t> sizes = {}; // 0D tensor
+    std::vector<int64_t> strides = {}; // Empty strides for scalar
+    Tensor* tensor;
+
+    AOTITorchError error = aoti_torch_empty_strided(
+        sizes.size(),
+        sizes.data(),
+        strides.data(),
+        static_cast<int32_t>(SupportedDTypes::BOOL),
+        static_cast<int32_t>(SupportedDevices::CUDA),
+        0,
+        &tensor);
+
+    if (error != Error::Ok || tensor == nullptr) {
+      return nullptr;
+    }
+
+    // Set the value
+    bool host_value = value;
+    cudaError_t cuda_err = cudaMemcpy(
+        tensor->mutable_data_ptr(),
+        &host_value,
+        sizeof(bool),
+        cudaMemcpyHostToDevice);
+
+    if (cuda_err != cudaSuccess) {
+      aoti_torch_delete_tensor_object(tensor);
+      return nullptr;
+    }
+
+    return tensor;
+  }
+
+  // Helper to create a bool tensor on CPU with a specific value
+  Tensor* create_cpu_bool_tensor(bool value) {
+    // Create a 0D (scalar) bool tensor
+    std::vector<int64_t> sizes = {}; // 0D tensor
+    std::vector<int64_t> strides = {}; // Empty strides for scalar
+    Tensor* tensor;
+
+    AOTITorchError error = aoti_torch_empty_strided(
+        sizes.size(),
+        sizes.data(),
+        strides.data(),
+        static_cast<int32_t>(SupportedDTypes::BOOL),
+        static_cast<int32_t>(SupportedDevices::CPU),
+        0,
+        &tensor);
+
+    if (error != Error::Ok || tensor == nullptr) {
+      return nullptr;
+    }
+
+    // Set the value directly
+    bool* data_ptr = static_cast<bool*>(tensor->mutable_data_ptr());
+    *data_ptr = value;
+
+    return tensor;
+  }
+};
+
+// Test extracting true value from CUDA bool tensor
+TEST_F(AOTITorchItemBoolTest, CUDATensorTrueValue) {
+  Tensor* tensor = create_cuda_bool_tensor(true);
+  ASSERT_NE(tensor, nullptr);
+
+  bool result = false;
+  AOTITorchError error = aoti_torch_item_bool(tensor, &result);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_TRUE(result);
+}
+
+// Test extracting false value from CUDA bool tensor
+TEST_F(AOTITorchItemBoolTest, CUDATensorFalseValue) {
+  Tensor* tensor = create_cuda_bool_tensor(false);
+  ASSERT_NE(tensor, nullptr);
+
+  bool result = true;
+  AOTITorchError error = aoti_torch_item_bool(tensor, &result);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_FALSE(result);
+}
+
+// Test extracting true value from CPU bool tensor
+TEST_F(AOTITorchItemBoolTest, CPUTensorTrueValue) {
+  Tensor* tensor = create_cpu_bool_tensor(true);
+  ASSERT_NE(tensor, nullptr);
+
+  bool result = false;
+  AOTITorchError error = aoti_torch_item_bool(tensor, &result);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_TRUE(result);
+}
+
+// Test extracting false value from CPU bool tensor
+TEST_F(AOTITorchItemBoolTest, CPUTensorFalseValue) {
+  Tensor* tensor = create_cpu_bool_tensor(false);
+  ASSERT_NE(tensor, nullptr);
+
+  bool result = true;
+  AOTITorchError error = aoti_torch_item_bool(tensor, &result);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_FALSE(result);
+}
+
+// Test with null tensor pointer
+TEST_F(AOTITorchItemBoolTest, NullTensorPointer) {
+  bool result;
+  AOTITorchError error = aoti_torch_item_bool(nullptr, &result);
+  EXPECT_EQ(error, Error::InvalidArgument);
+}
+
+// Test with null result pointer
+TEST_F(AOTITorchItemBoolTest, NullResultPointer) {
+  Tensor* tensor = create_cuda_bool_tensor(true);
+  ASSERT_NE(tensor, nullptr);
+
+  AOTITorchError error = aoti_torch_item_bool(tensor, nullptr);
+  EXPECT_EQ(error, Error::InvalidArgument);
+}
+
+// Test with non-bool dtype (should fail)
+TEST_F(AOTITorchItemBoolTest, NonBoolDtype) {
+  // Create a float tensor
+  std::vector<int64_t> sizes = {};
+  std::vector<int64_t> strides = {};
+  Tensor* tensor;
+
+  AOTITorchError error = aoti_torch_empty_strided(
+      sizes.size(),
+      sizes.data(),
+      strides.data(),
+      static_cast<int32_t>(SupportedDTypes::FLOAT32), // Not bool
+      static_cast<int32_t>(SupportedDevices::CUDA),
+      0,
+      &tensor);
+
+  ASSERT_EQ(error, Error::Ok);
+  ASSERT_NE(tensor, nullptr);
+
+  bool result;
+  error = aoti_torch_item_bool(tensor, &result);
+  EXPECT_EQ(error, Error::InvalidArgument);
+}