diff --git a/backends/aoti/slim/CMakeLists.txt b/backends/aoti/slim/CMakeLists.txt new file mode 100644 index 00000000000..b14d47f15c8 --- /dev/null +++ b/backends/aoti/slim/CMakeLists.txt @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.19) + +# SlimTensor library for ExecuTorch CUDA backend A lightweight tensor +# implementation for AOTI (Ahead-of-Time Inference) + +# C10 core headers +set(SLIM_C10_HEADERS + c10/core/Device.h c10/core/DeviceType.h c10/Contiguity.h c10/MemoryFormat.h + c10/SizesAndStrides.h c10/WrapDimMinimal.h +) + +# Utility headers +set(SLIM_UTIL_HEADERS util/SharedPtr.h util/SizeUtil.h util/type_convert.h) + +# Core SlimTensor headers +set(SLIM_CORE_HEADERS core/SlimTensor.h core/SlimTensorResize-incl.h + core/SlimTensorView-incl.h core/Storage.h +) + +# Factory headers +set(SLIM_FACTORY_HEADERS factory/Empty.h factory/Factory.h factory/FromBlob.h + factory/FromScalar.h factory/Pad.h +) + +# CUDA headers +set(SLIM_CUDA_HEADERS cuda/Exception.h cuda/Guard.h) + +# All headers combined +set(SLIM_TENSOR_HEADERS + ${SLIM_C10_HEADERS} ${SLIM_UTIL_HEADERS} ${SLIM_CORE_HEADERS} + ${SLIM_FACTORY_HEADERS} ${SLIM_CUDA_HEADERS} +) + +# Header-only interface library for SlimTensor +add_library(slim_tensor INTERFACE) +target_include_directories( + slim_tensor INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}/../../../.. +) + +# Link to ExecuTorch dependencies +target_link_libraries( + slim_tensor INTERFACE executorch_core extension_data_loader +) + +# CUDA support (if available) +if(EXECUTORCH_BUILD_CUDA) + find_package(CUDAToolkit REQUIRED) + target_link_libraries(slim_tensor INTERFACE CUDA::cudart) +endif() + +# Installation +install(FILES ${SLIM_C10_HEADERS} + DESTINATION include/executorch/backends/aoti/slim/c10/core +) +install(FILES c10/Contiguity.h c10/MemoryFormat.h c10/SizesAndStrides.h + c10/WrapDimMinimal.h + DESTINATION include/executorch/backends/aoti/slim/c10 +) +install(FILES ${SLIM_UTIL_HEADERS} + DESTINATION include/executorch/backends/aoti/slim/util +) +install(FILES ${SLIM_CORE_HEADERS} + DESTINATION include/executorch/backends/aoti/slim/core +) +install(FILES ${SLIM_FACTORY_HEADERS} + DESTINATION include/executorch/backends/aoti/slim/factory +) +install(FILES ${SLIM_CUDA_HEADERS} + DESTINATION include/executorch/backends/aoti/slim/cuda +) + +# Tests (if building tests) +if(EXECUTORCH_BUILD_TESTS) + enable_testing() + + # Basic SlimTensor tests + add_executable(test_slim_tensor_basic tests/test_slim_tensor_basic.cpp) + target_link_libraries( + test_slim_tensor_basic PRIVATE slim_tensor gtest gtest_main + ) + add_test(NAME test_slim_tensor_basic COMMAND test_slim_tensor_basic) + + # Type conversion tests + add_executable(test_type_convert tests/test_type_convert.cpp) + target_link_libraries(test_type_convert PRIVATE slim_tensor gtest gtest_main) + add_test(NAME test_type_convert COMMAND test_type_convert) + + # CUDA tests (if CUDA is enabled) + if(EXECUTORCH_BUILD_CUDA) + add_executable(test_slim_tensor_cuda tests/test_slim_tensor_cuda.cpp) + target_link_libraries( + test_slim_tensor_cuda PRIVATE slim_tensor gtest gtest_main CUDA::cudart + ) + add_test(NAME test_slim_tensor_cuda COMMAND test_slim_tensor_cuda) + endif() +endif() diff --git a/backends/aoti/slim/TARGETS b/backends/aoti/slim/TARGETS new file mode 100644 index 00000000000..0a42614a385 --- /dev/null 
+++ b/backends/aoti/slim/TARGETS @@ -0,0 +1,5 @@ +load(":targets.bzl", "define_common_targets") + +oncall("executorch") + +define_common_targets() diff --git a/backends/aoti/slim/c10/TARGETS b/backends/aoti/slim/c10/TARGETS new file mode 100644 index 00000000000..0a42614a385 --- /dev/null +++ b/backends/aoti/slim/c10/TARGETS @@ -0,0 +1,5 @@ +load(":targets.bzl", "define_common_targets") + +oncall("executorch") + +define_common_targets() diff --git a/backends/aoti/slim/c10/WrapDimMinimal.h b/backends/aoti/slim/c10/WrapDimMinimal.h new file mode 100644 index 00000000000..d0b51ff762b --- /dev/null +++ b/backends/aoti/slim/c10/WrapDimMinimal.h @@ -0,0 +1,80 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include + +// Different from the original implementation in c10, we don't need +// to support SymInt here. +namespace c10 { +namespace detail { +template +T maybe_wrap_dim_slow(T dim, T dim_post_expr, bool wrap_scalar); +} + +template +T _maybe_wrap_dim(T dim, T dim_post_expr, bool wrap_scalar = true) { + // Inline the fast paths + if (C10_LIKELY(dim_post_expr * -1 <= dim && dim < dim_post_expr)) { + // For SymInts, we want an explicit control flow to trigger a guard, so we + // may as well branch too. + if (dim < 0) { + return dim + dim_post_expr; + } + return dim; + } + // Check edge-cases out-of-line (wrapping scalars and out-of-bounds errors) + return c10::detail::maybe_wrap_dim_slow( + std::move(dim), std::move(dim_post_expr), wrap_scalar); +} + +inline int64_t +maybe_wrap_dim(int64_t dim, int64_t dim_post_expr, bool wrap_scalar = true) { + return _maybe_wrap_dim(dim, dim_post_expr, wrap_scalar); +} + +namespace detail { +// This template can only be specialized at int64_t and c10::SymInt; +// you'll get linker errors otherwise +template +T maybe_wrap_dim_slow(T dim, T dim_post_expr, bool wrap_scalar) { + ET_CHECK_MSG( + dim_post_expr >= 0, + "Rank cannot be negative but got %lld", + static_cast(dim_post_expr)); + + if (dim_post_expr == 0) { + ET_CHECK_MSG( + wrap_scalar, + "Dimension specified as %lld but tensor has no dimensions", + static_cast(dim)); + return c10::maybe_wrap_dim( + std::move(dim), /*dim_post_expr=*/1, /*wrap_scalar=*/false); + } + + T min = dim_post_expr * -1; + T max = dim_post_expr - 1; + ET_CHECK_MSG( + min <= dim && dim <= max, + "Dimension out of range (expected to be in range of [%lld" + ", %lld], but got %lld)", + static_cast(min), + static_cast(max), + static_cast(dim)); + + ET_DCHECK_MSG( + false, "should never reach here as dim should be out-of-bounds"); + return dim; // unreachable, but needed to suppress compiler warnings +} +} // namespace detail +} // namespace c10 diff --git a/backends/aoti/slim/c10/core/Contiguity.h b/backends/aoti/slim/c10/core/Contiguity.h new file mode 100644 index 00000000000..d5ff49561ab --- /dev/null +++ b/backends/aoti/slim/c10/core/Contiguity.h @@ -0,0 +1,151 @@ +#pragma once + +#include +#include + +#include +#include + +namespace standalone::c10 { + +template +bool _compute_contiguous(ArrayRef sizes, ArrayRef strides, T numel) { + if (numel == 0) { + return true; + } + + T expected_stride = 1; + // NB: make sure we do signed arithmetic + for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) { + const auto& size_d = sizes[d]; + if (size_d == 1) { + continue; + } + + if (strides[d] != 
expected_stride) { + return false; + } + expected_stride *= size_d; + } + return true; +} + +// This function will return True if the tensor is contiguous, and False if the +// its not or if we can't determine if it is contiguous due to unbacked symbols +// (it could be either in that case based on the actual runtime data). +template +bool definitely_contiguous(ArrayRef sizes, ArrayRef strides, T numel) { + if (numel == 0) { + return true; + } + + T expected_stride = 1; + // NB: make sure we do signed arithmetic + for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) { + const auto& size_d = sizes[d]; + if (size_d == 1) { + continue; + } + + if (strides[d] != expected_stride) { + return false; + } + expected_stride *= size_d; + } + return true; +} + +template +bool _compute_channels_last_contiguous_2d( + ArrayRef sizes, + ArrayRef strides) { + // Please don't combine these code, constant array is used here to let + // compiler fully unroll the loop to get better performance + switch (sizes.size()) { + case 4: { + T expected = 1; + for (auto& d : {1, 3, 2, 0}) { + const auto& size_d = sizes[d]; + if (size_d != 1) { + if (strides[d] != expected) { + return false; + } + expected *= size_d; + } + } + return true; + } + // NOLINTNEXTLINE(bugprone-branch-clone) + case 3: + // TODO dim == 3 case will be enabled once it is fully tested + return false; + default: + return false; + } +} + +template +bool _compute_channels_last_contiguous_3d( + ArrayRef sizes, + ArrayRef strides) { + // Please don't combine these code, constant array is used here to let + // compiler fully unroll the loop to get better performance + switch (sizes.size()) { + case 5: { + T expected = 1; + for (auto& d : {1, 4, 3, 2, 0}) { + const auto& size_d = sizes[d]; + if (size_d != 1) { + if (strides[d] != expected) { + return false; + } + expected *= size_d; + } + } + return true; + } + // NOLINTNEXTLINE(bugprone-branch-clone) + case 4: + // TODO dim == 4 case will be enabled once it is fully tested + return false; + default: + return false; + } +} + +template +bool _compute_non_overlapping_and_dense( + ArrayRef sizes, + ArrayRef strides) { + auto dim = sizes.size(); + if (dim == 1) { + return sizes[0] < 2 || strides[0] == 1; + } + std::vector perm(dim); + for (const auto i : irange(dim)) { + perm[i] = i; + } + // Sort by strides, leaving 0 and 1 sized dims at the end of the array + std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) { + if (sizes[a] < 2) { + return false; + } else if (sizes[b] < 2) { + return true; + } + return strides[a] < strides[b]; + }); + T require_stride = 1; + for (const auto i : irange(dim)) { + const auto& size_perm_i = sizes[perm[i]]; + if (size_perm_i < 2) { + return true; + } + if (strides[perm[i]] != require_stride) { + return false; + } + require_stride *= size_perm_i; + } + return true; +} + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/core/Device.h b/backends/aoti/slim/c10/core/Device.h new file mode 100644 index 00000000000..a9a6d3a8136 --- /dev/null +++ b/backends/aoti/slim/c10/core/Device.h @@ -0,0 +1,372 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Copied from c10/core/DeviceType.h with some modifications + +namespace standalone::c10 { +namespace detail { +enum class DeviceStringParsingState { + kSTART, + kINDEX_START, + kINDEX_REST, + kERROR +}; + +inline DeviceType parse_type(const std::string& device_string) { + static const std::array< + 
std::pair, + static_cast(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)> + types = {{ + {"cpu", DeviceType::CPU}, + {"cuda", DeviceType::CUDA}, + {"ipu", DeviceType::IPU}, + {"xpu", DeviceType::XPU}, + {"mkldnn", DeviceType::MKLDNN}, + {"opengl", DeviceType::OPENGL}, + {"opencl", DeviceType::OPENCL}, + {"ideep", DeviceType::IDEEP}, + {"hip", DeviceType::HIP}, + {"ve", DeviceType::VE}, + {"fpga", DeviceType::FPGA}, + {"maia", DeviceType::MAIA}, + {"xla", DeviceType::XLA}, + {"lazy", DeviceType::Lazy}, + {"vulkan", DeviceType::Vulkan}, + {"mps", DeviceType::MPS}, + {"meta", DeviceType::Meta}, + {"hpu", DeviceType::HPU}, + {"mtia", DeviceType::MTIA}, + {"privateuseone", DeviceType::PrivateUse1}, + }}; + auto device = std::find_if( + types.begin(), + types.end(), + [&device_string](const std::pair& p) { + return p.first && p.first == device_string; + }); + if (device != types.end()) { + return device->second; + } + if (device_string == get_privateuse1_backend()) { + return DeviceType::PrivateUse1; + } + std::vector device_names; + for (const auto& it : types) { + if (it.first) { + device_names.push_back(it.first); + } + } + STANDALONE_CHECK( + false, + "Expected one of ", + Join(", ", device_names), + " device type at start of device string: ", + device_string); +} +} // namespace detail + +/// An index representing a specific device; e.g., the 1 in GPU 1. +/// A DeviceIndex is not independently meaningful without knowing +/// the DeviceType it is associated; try to use Device rather than +/// DeviceIndex directly. +using DeviceIndex = int8_t; + +/// Represents a compute device on which a tensor is located. A device is +/// uniquely identified by a type, which specifies the type of machine it is +/// (e.g. CPU or CUDA GPU), and a device index or ordinal, which identifies the +/// specific compute device when there is more than one of a certain type. The +/// device index is optional, and in its defaulted state represents (abstractly) +/// "the current device". Further, there are two constraints on the value of the +/// device index, if one is explicitly stored: +/// 1. A negative index represents the current device, a non-negative index +/// represents a specific, concrete device, +/// 2. When the device type is CPU, the device index must be zero. +struct Device final { + using Type = DeviceType; + + /// Constructs a new `Device` from a `DeviceType` and an optional device + /// index. + /* implicit */ + Device(DeviceType type, DeviceIndex index = -1) : type_(type), index_(index) { + validate(); + } + + /// Constructs a `Device` from a string description, for convenience. + /// The string supplied must follow the following schema: + /// `(cpu|cuda)[:]` + /// where `cpu` or `cuda` specifies the device type, and + /// `:` optionally specifies a device index. + /* implicit */ Device(const std::string& device_string) : Device(Type::CPU) { + STANDALONE_CHECK(!device_string.empty(), "Device string must not be empty"); + + std::string device_name, device_index_str; + detail::DeviceStringParsingState pstate = + detail::DeviceStringParsingState::kSTART; + + // The code below tries to match the string in the variable + // device_string against the regular expression: + // ([a-zA-Z_]+)(?::([1-9]\\d*|0))? 
+ for (size_t i = 0; pstate != detail::DeviceStringParsingState::kERROR && + i < device_string.size(); + ++i) { + const char ch = device_string.at(i); + const unsigned char uch = static_cast(ch); + switch (pstate) { + case detail::DeviceStringParsingState::kSTART: + if (ch != ':') { + if (std::isalpha(uch) || ch == '_') { + device_name.push_back(ch); + } else { + pstate = detail::DeviceStringParsingState::kERROR; + } + } else { + pstate = detail::DeviceStringParsingState::kINDEX_START; + } + break; + + case detail::DeviceStringParsingState::kINDEX_START: + if (std::isdigit(uch)) { + device_index_str.push_back(ch); + pstate = detail::DeviceStringParsingState::kINDEX_REST; + } else { + pstate = detail::DeviceStringParsingState::kERROR; + } + break; + + case detail::DeviceStringParsingState::kINDEX_REST: + if (device_index_str.at(0) == '0') { + pstate = detail::DeviceStringParsingState::kERROR; + break; + } + if (std::isdigit(uch)) { + device_index_str.push_back(ch); + } else { + pstate = detail::DeviceStringParsingState::kERROR; + } + break; + + case detail::DeviceStringParsingState::kERROR: + // Execution won't reach here. + break; + } + } + + const bool has_error = device_name.empty() || + pstate == detail::DeviceStringParsingState::kERROR || + (pstate == detail::DeviceStringParsingState::kINDEX_START && + device_index_str.empty()); + + STANDALONE_CHECK( + !has_error, "Invalid device string: '", device_string, "'"); + + try { + if (!device_index_str.empty()) { + index_ = static_cast(std::stoi(device_index_str)); + } + } catch (const std::exception&) { + STANDALONE_CHECK( + false, + "Could not parse device index '", + device_index_str, + "' in device string '", + device_string, + "'"); + } + type_ = detail::parse_type(device_name); + validate(); + } + + /// Returns true if the type and index of this `Device` matches that of + /// `other`. + bool operator==(const Device& other) const noexcept { + return this->type_ == other.type_ && this->index_ == other.index_; + } + + /// Returns true if the type or index of this `Device` differs from that of + /// `other`. + bool operator!=(const Device& other) const noexcept { + return !(*this == other); + } + + /// Sets the device index. + void set_index(DeviceIndex index) { + index_ = index; + } + + /// Returns the type of device this is. + DeviceType type() const noexcept { + return type_; + } + + /// Returns the optional index. + DeviceIndex index() const noexcept { + return index_; + } + + /// Returns true if the device has a non-default index. + bool has_index() const noexcept { + return index_ != -1; + } + + /// Return true if the device is of CUDA type. + bool is_cuda() const noexcept { + return type_ == DeviceType::CUDA; + } + + /// Return true if the device is of PrivateUse1 type. + bool is_privateuseone() const noexcept { + return type_ == DeviceType::PrivateUse1; + } + + /// Return true if the device is of MPS type. + bool is_mps() const noexcept { + return type_ == DeviceType::MPS; + } + + /// Return true if the device is of HIP type. + bool is_hip() const noexcept { + return type_ == DeviceType::HIP; + } + + /// Return true if the device is of VE type. + bool is_ve() const noexcept { + return type_ == DeviceType::VE; + } + + /// Return true if the device is of XPU type. + bool is_xpu() const noexcept { + return type_ == DeviceType::XPU; + } + + /// Return true if the device is of IPU type. + bool is_ipu() const noexcept { + return type_ == DeviceType::IPU; + } + + /// Return true if the device is of XLA type. 
+ bool is_xla() const noexcept { + return type_ == DeviceType::XLA; + } + + /// Return true if the device is of MTIA type. + bool is_mtia() const noexcept { + return type_ == DeviceType::MTIA; + } + + /// Return true if the device is of HPU type. + bool is_hpu() const noexcept { + return type_ == DeviceType::HPU; + } + + /// Return true if the device is of Lazy type. + bool is_lazy() const noexcept { + return type_ == DeviceType::Lazy; + } + + /// Return true if the device is of Vulkan type. + bool is_vulkan() const noexcept { + return type_ == DeviceType::Vulkan; + } + + /// Return true if the device is of Metal type. + bool is_metal() const noexcept { + return type_ == DeviceType::Metal; + } + + /// Return true if the device is of MAIA type. + bool is_maia() const noexcept { + return type_ == DeviceType::MAIA; + } + + /// Return true if the device is of META type. + bool is_meta() const noexcept { + return type_ == DeviceType::Meta; + } + + /// Return true if the device is of CPU type. + bool is_cpu() const noexcept { + return type_ == DeviceType::CPU; + } + + /// Return true if the device supports arbitrary strides. + bool supports_as_strided() const noexcept { + return type_ != DeviceType::IPU && type_ != DeviceType::XLA && + type_ != DeviceType::Lazy && type_ != DeviceType::MTIA; + } + + /// Same string as returned from operator<<. + std::string str() const { + std::string str = DeviceTypeName(type(), /* lower case */ true); + if (has_index()) { + str.push_back(':'); + str.append(std::to_string(index())); + } + return str; + } + + private: + DeviceType type_; + DeviceIndex index_ = -1; + void validate() { + // Removing these checks in release builds noticeably improves + // performance in micro-benchmarks. + // This is safe to do, because backends that use the DeviceIndex + // have a later check when we actually try to switch to that device. + STANDALONE_INTERNAL_ASSERT_DEBUG_ONLY( + index_ >= -1, + "Device index must be -1 or non-negative, got ", + static_cast(index_)); + STANDALONE_INTERNAL_ASSERT_DEBUG_ONLY( + !is_cpu() || index_ <= 0, + "CPU device index must be -1 or zero, got ", + static_cast(index_)); + } +}; + +inline std::ostream& operator<<(std::ostream& stream, const Device& device) { + stream << device.str(); + return stream; +} +} // namespace standalone::c10 + +namespace std { +template <> +struct hash { + size_t operator()(standalone::c10::Device d) const noexcept { + // Are you here because this static assert failed? Make sure you ensure + // that the bitmasking code below is updated accordingly! + static_assert( + sizeof(standalone::c10::DeviceType) == 1, "DeviceType is not 8-bit"); + static_assert( + sizeof(standalone::c10::DeviceIndex) == 1, "DeviceIndex is not 8-bit"); + // Note [Hazard when concatenating signed integers] + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // We must first convert to a same-sized unsigned type, before promoting to + // the result type, to prevent sign extension when any of the values is -1. + // If sign extension occurs, you'll clobber all of the values in the MSB + // half of the resulting integer. + // + // Technically, by C/C++ integer promotion rules, we only need one of the + // uint32_t casts to the result type, but we put in both for explicitness's + // sake. 
+ uint32_t bits = static_cast(static_cast(d.type())) + << 16 | + static_cast(static_cast(d.index())); + return std::hash{}(bits); + } +}; +} // namespace std diff --git a/backends/aoti/slim/c10/core/DeviceType.h b/backends/aoti/slim/c10/core/DeviceType.h new file mode 100644 index 00000000000..f2631a48f2d --- /dev/null +++ b/backends/aoti/slim/c10/core/DeviceType.h @@ -0,0 +1,134 @@ +#pragma once + +// Copied from c10/core/DeviceType.h with some modifications: +// * enum values are kept the same as c10 and guarded by device_type_test +// * Make the implementaion header-only +// * Simplify some implementation +// * Disable PrivateUse1 name registration + +#include +#include +#include +#include +#include +#include + +#include + +namespace standalone::c10 { +enum class DeviceType : int8_t { + CPU = 0, + CUDA = 1, // CUDA. + MKLDNN = 2, // Reserved for explicit MKLDNN + OPENGL = 3, // OpenGL + OPENCL = 4, // OpenCL + IDEEP = 5, // IDEEP. + HIP = 6, // AMD HIP + FPGA = 7, // FPGA + MAIA = 8, // ONNX Runtime / Microsoft + XLA = 9, // XLA / TPU + Vulkan = 10, // Vulkan + Metal = 11, // Metal + XPU = 12, // XPU + MPS = 13, // MPS + Meta = 14, // Meta (tensors with no data) + HPU = 15, // HPU / HABANA + VE = 16, // SX-Aurora / NEC + Lazy = 17, // Lazy Tensors + IPU = 18, // Graphcore IPU + MTIA = 19, // Meta training and inference devices + PrivateUse1 = 20, // PrivateUse1 device + // NB: If you add more devices: + // - Change the implementations of DeviceTypeName and isValidDeviceType + // - Change the number below + COMPILE_TIME_MAX_DEVICE_TYPES = 21, +}; + +constexpr DeviceType kCPU = DeviceType::CPU; +constexpr DeviceType kCUDA = DeviceType::CUDA; +constexpr DeviceType kMKLDNN = DeviceType::MKLDNN; +constexpr DeviceType kOPENGL = DeviceType::OPENGL; +constexpr DeviceType kOPENCL = DeviceType::OPENCL; +constexpr DeviceType kIDEEP = DeviceType::IDEEP; +constexpr DeviceType kHIP = DeviceType::HIP; +constexpr DeviceType kFPGA = DeviceType::FPGA; +constexpr DeviceType kMAIA = DeviceType::MAIA; +constexpr DeviceType kXLA = DeviceType::XLA; +constexpr DeviceType kVulkan = DeviceType::Vulkan; +constexpr DeviceType kMetal = DeviceType::Metal; +constexpr DeviceType kXPU = DeviceType::XPU; +constexpr DeviceType kMPS = DeviceType::MPS; +constexpr DeviceType kMeta = DeviceType::Meta; +constexpr DeviceType kHPU = DeviceType::HPU; +constexpr DeviceType kVE = DeviceType::VE; +constexpr DeviceType kLazy = DeviceType::Lazy; +constexpr DeviceType kIPU = DeviceType::IPU; +constexpr DeviceType kMTIA = DeviceType::MTIA; +constexpr DeviceType kPrivateUse1 = DeviceType::PrivateUse1; + +// define explicit int constant +constexpr int COMPILE_TIME_MAX_DEVICE_TYPES = + static_cast(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES); + +static_assert( + COMPILE_TIME_MAX_DEVICE_TYPES <= 21, + "Hey! You seem to be adding a lot of new DeviceTypes. The intent was " + "for this constant to reflect the actual number of DeviceTypes we support " + "in PyTorch; it's important that this number is not too large as we " + "use this to allocate stack arrays in some places in our code. If you " + "are indeed just adding the 20th device type, feel free to change " + "the check to 32; but if you are adding some sort of extensible device " + "types registration, please be aware that you are affecting code that " + "this number is small. Try auditing uses of this constant."); + +// Doesn't support PrivateUse1 name registration in standalone +inline std::string get_privateuse1_backend(bool lower_case = true) { + return lower_case ? 
"privateuse1" : "PrivateUse1"; +} + +inline std::string DeviceTypeName(DeviceType d, bool lower_case = false) { + static const std::string device_names[] = { + "CPU", "CUDA", "MKLDNN", "OPENGL", "OPENCL", "IDEEP", "HIP", + "FPGA", "MAIA", "XLA", "VULKAN", "METAL", "XPU", "MPS", + "META", "HPU", "VE", "LAZY", "IPU", "MTIA"}; + + int idx = static_cast(d); + if (idx < 0 || idx >= COMPILE_TIME_MAX_DEVICE_TYPES) { + STANDALONE_CHECK(false, "Unknown device: ", static_cast(d)); + } + if (d == DeviceType::PrivateUse1) { + return get_privateuse1_backend(lower_case); + } + std::string name = device_names[idx]; + if (lower_case) { + std::transform(name.begin(), name.end(), name.begin(), ::tolower); + } + return name; +} + +// NB: Per the C++ standard (e.g., +// https://stackoverflow.com/questions/18195312/what-happens-if-you-static-cast-invalid-value-to-enum-class) +// as long as you cast from the same underlying type, it is always valid to cast +// into an enum class (even if the value would be invalid by the enum.) Thus, +// the caller is allowed to cast a possibly invalid int16_t to DeviceType and +// then pass it to this function. (I considered making this function take an +// int16_t directly, but that just seemed weird.) +inline bool isValidDeviceType(DeviceType d) { + int idx = static_cast(d); + return idx >= 0 && idx < COMPILE_TIME_MAX_DEVICE_TYPES; +} + +inline std::ostream& operator<<(std::ostream& stream, DeviceType type) { + stream << DeviceTypeName(type, /* lower case */ true); + return stream; +} +} // namespace standalone::c10 + +namespace std { +template <> +struct hash { + std::size_t operator()(standalone::c10::DeviceType k) const { + return std::hash()(static_cast(k)); + } +}; +} // namespace std diff --git a/backends/aoti/slim/c10/core/Layout.h b/backends/aoti/slim/c10/core/Layout.h new file mode 100644 index 00000000000..79230f23bb7 --- /dev/null +++ b/backends/aoti/slim/c10/core/Layout.h @@ -0,0 +1,53 @@ +#pragma once + +#include + +#include +#include + +namespace standalone::c10 { +enum class Layout : int8_t { + Strided, + Sparse, + SparseCsr, + Mkldnn, + SparseCsc, + SparseBsr, + SparseBsc, + Jagged, + NumOptions +}; + +constexpr auto kStrided = Layout::Strided; +constexpr auto kSparse = Layout::Sparse; +constexpr auto kSparseCsr = Layout::SparseCsr; +constexpr auto kMkldnn = Layout::Mkldnn; +constexpr auto kSparseCsc = Layout::SparseCsc; +constexpr auto kSparseBsr = Layout::SparseBsr; +constexpr auto kSparseBsc = Layout::SparseBsc; +constexpr auto kJagged = Layout::Jagged; + +inline std::ostream& operator<<(std::ostream& stream, c10::Layout layout) { + switch (layout) { + case c10::kStrided: + return stream << "Strided"; + case c10::kSparse: + return stream << "Sparse"; + case c10::kSparseCsr: + return stream << "SparseCsr"; + case c10::kSparseCsc: + return stream << "SparseCsc"; + case c10::kSparseBsr: + return stream << "SparseBsr"; + case c10::kSparseBsc: + return stream << "SparseBsc"; + case c10::kMkldnn: + return stream << "Mkldnn"; + case c10::kJagged: + return stream << "Jagged"; + default: + STANDALONE_CHECK(false, "Unknown layout"); + } +} + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/core/MemoryFormat.h b/backends/aoti/slim/c10/core/MemoryFormat.h new file mode 100644 index 00000000000..756caf64f26 --- /dev/null +++ b/backends/aoti/slim/c10/core/MemoryFormat.h @@ -0,0 +1,291 @@ +#pragma once + +#include +#include + +#include +#include +#include + +// Memory format is not the property of a Tensor. 
It is the way to tell an +// operator how the result should be organized in memory and nothing more. That +// means memory format should never be used as return value for any tensor state +// interrogation functions (internally and externally). +// +// Possible options are: +// Preserve: +// If any of the input tensors is in channels_last format, operator output +// should be in channels_last format +// +// Contiguous: +// Regardless of input tensors format, the output should be contiguous +// Tensor. +// +// ChannelsLast: +// Regardless of input tensors format, the output should be in channels_last +// format. + +namespace standalone::c10 { +enum class MemoryFormat : int8_t { + Contiguous, + Preserve, + ChannelsLast, + ChannelsLast3d, + NumOptions +}; + +// If you are seeing this, it means that this call site was not checked if +// the memory format could be preserved, and it was switched to old default +// behaviour of contiguous +#define LEGACY_CONTIGUOUS_MEMORY_FORMAT \ + ::standalone::c10::get_contiguous_memory_format() + +inline MemoryFormat get_contiguous_memory_format() { + return MemoryFormat::Contiguous; +} + +inline std::ostream& operator<<( + std::ostream& stream, + MemoryFormat memory_format) { + switch (memory_format) { + case MemoryFormat::Preserve: + return stream << "Preserve"; + case MemoryFormat::Contiguous: + return stream << "Contiguous"; + case MemoryFormat::ChannelsLast: + return stream << "ChannelsLast"; + case MemoryFormat::ChannelsLast3d: + return stream << "ChannelsLast3d"; + default: + STANDALONE_CHECK(false, "Unknown memory format ", memory_format); + } +} + +// Note: Hardcoded the channel last stride indices here to get better +// performance +template +inline std::vector get_channels_last_strides_2d(ArrayRef sizes) { + std::vector strides(sizes.size()); + switch (sizes.size()) { + case 4: + strides[1] = 1; + strides[3] = sizes[1]; + strides[2] = strides[3] * sizes[3]; + strides[0] = strides[2] * sizes[2]; + return strides; + case 3: + strides[0] = 1; + strides[2] = sizes[0]; + strides[1] = strides[2] * sizes[2]; + return strides; + default: + STANDALONE_INTERNAL_ASSERT( + false, "ChannelsLast2d doesn't support size ", sizes.size()); + } +} + +inline std::vector get_channels_last_strides_2d(IntArrayRef sizes) { + return get_channels_last_strides_2d(sizes); +} + +template +std::vector get_channels_last_strides_3d(ArrayRef sizes) { + std::vector strides(sizes.size()); + switch (sizes.size()) { + case 5: + strides[1] = 1; + strides[4] = sizes[1]; + strides[3] = strides[4] * sizes[4]; + strides[2] = strides[3] * sizes[3]; + strides[0] = strides[2] * sizes[2]; + return strides; + case 4: + strides[0] = 1; + strides[3] = sizes[0]; + strides[2] = strides[3] * sizes[3]; + strides[1] = strides[2] * sizes[2]; + return strides; + default: + STANDALONE_INTERNAL_ASSERT( + false, "ChannelsLast3d doesn't support size ", sizes.size()); + } +} + +inline std::vector get_channels_last_strides_3d(IntArrayRef sizes) { + return get_channels_last_strides_3d(sizes); +} + +// NOTE: +// Below are Helper functions for is_channels_last_strides_xd. +// 1. Please do not combine these helper functions, each helper function handles +// exactly one case of sizes + memory_format, by doing this, the strides indices +// will be a constant array and we can access it using constant index number, +// the compiler will fully unroll the loop on strides indices to gain a better +// performance. +// 2. No error check in helper function, caller ensures the correctness of the +// input +// 3. 
All helper functions have similar comments, only 1st helper function is +// commented here. +template +inline bool is_channels_last_strides_2d_s4( + const ArrayRef sizes, + const ArrayRef strides) { + T min = 0; + // special case for trivial C dimension. default to NCHW + if (strides[1] == 0) { + return false; + } + // loop strides indices + for (auto& d : {1, 3, 2, 0}) { + if (sizes[d] == 0) { + return false; + } + if (strides[d] < min) { + return false; + } + // Fallback to NCHW as default layout for ambiguous cases + // This is the flaw of implicit memory_format from strides. + // N111 tensor with identical strides for size 1 dimension; + // Two cases could lead us here: + // a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1]) + // b. N11W contiguous Tensor sliced on the W-dimension. + // ([N,1,1,1]@[W,W,W,W]) + if (d == 0 && min == strides[1]) { + return false; + } + // This is necessary to: + // 1. distinguish the memory_format of N1H1; + // [H, 1, 1, 1] channels_last stride + // [H, H, 1, 1] contiguous stride + // 2. permutation of 1C1W: + // [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3) + // [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as channels_last + min = strides[d]; + if (sizes[d] > 1) { + min *= sizes[d]; + } + } + return true; +} + +template +inline bool is_channels_last_strides_3d_s5( + const ArrayRef sizes, + const ArrayRef strides) { + T min = 0; + if (strides[1] == 0) { + return false; + } + for (auto& d : {1, 4, 3, 2, 0}) { + if (sizes[d] == 0) { + return false; + } + if (strides[d] < min) { + return false; + } + if (d == 0 && min == strides[1]) { + return false; + } + min = strides[d]; + if (sizes[d] > 1) { + min *= sizes[d]; + } + } + return true; +} + +// Note [Ambiguous is_channels_last_strides_xd] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// The flaw of carrying memory_format implicitly through strides is very hard +// to WAR properly. issue #24090 +// Without the history of permutation, we can't infer the memory_format of a +// tensor from the snapshot of its size & stride +// e.g. +// +// 1. We can NOT specify the memory_format of N111 tensor through strides in a +// meaningful way; +// +// 2. Two path that ended up with identical size/stride +// N11W contiguous tensor sliced at w-dimension becomes [N,1,1,1]@[W,W,W,W] +// NC11 channels_last tensor sliced at c-dimension becomes [N,1,1,1]@[C,C,C,C] +// So if we see a tensor [N,1,1,1]@[X,X,X,X], there's no way for us to infer +// the memory_format of the original tensor. +// +// Due to the limitations, our temporary WAR `is_channels_last_strides` does the +// best effort to infer whether the original memory_format of a tensor is +// MemoryFormat::ChannelsLast. The two objectives of this function (ordered +// by their importance): +// 1. Ensure that normal shape manipulation does not accidentally change the +// MemoryFormat of an existing tensor. +// 2. Allows user to mark MemoryFormat::ChannelsLast to tensors; +// +// The function does so via checking strides of the tensor, including strides of +// size-1 dimensions. Although conventionally PyTorch implies no restriction on +// trivial stride (stride for size-1 dimension). +// +// Note that this approach is a compromise. We did not solve the problem +// completely. Many cases we will not be able to infer the correct memory +// format. +// The implementation of `is_channels_last_strides` is to serve the objectives: +// MemoryFormat::ChannelsLast has to be explicitly opted-in (no accidental +// conversion); Best effort to maintain the ChannelsLast flag. 
+// +// Due to the fact that this is not a bulletproof solution, through testing +// (aten/src/ATen/test/memory_format_test.cpp) +// a. we ensure that the common tasks are supported; +// a. we identify corner cases where the implementation compromises on. +// +// By the time accumulated permutation is enabled to replace implicit +// memory_format through strides, we should be updating our tests and fix the +// issues in our tests. +// +// We use Channels Last 2d as an example above. +// This is a general problem for all the is_channels_last_strides_xd +// implementation. Please check the helper functions +// (is_channels_last_strides_*d_s*) for more details. + +template +inline bool is_channels_last_strides_2d( + const ArrayRef sizes, + const ArrayRef strides) { + switch (sizes.size()) { + case 4: + return is_channels_last_strides_2d_s4(sizes, strides); + // NOLINTNEXTLINE(bugprone-branch-clone) + case 3: + // TODO dim == 3 case will be enabled once it is fully tested + return false; + default: + return false; + } +} + +template +inline bool is_channels_last_strides_3d( + const ArrayRef sizes, + const ArrayRef strides) { + switch (sizes.size()) { + case 5: + return is_channels_last_strides_3d_s5(sizes, strides); + // NOLINTNEXTLINE(bugprone-branch-clone) + case 4: + // TODO dim == 4 case will be enabled once it is fully tested + return false; + default: + return false; + } +} + +inline bool is_channels_last_strides_2d( + const IntArrayRef sizes, + const IntArrayRef strides) { + return is_channels_last_strides_2d(sizes, strides); +} + +inline bool is_channels_last_strides_3d( + const IntArrayRef sizes, + const IntArrayRef strides) { + return is_channels_last_strides_3d(sizes, strides); +} + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/core/Scalar.h b/backends/aoti/slim/c10/core/Scalar.h new file mode 100644 index 00000000000..1c61ecb4704 --- /dev/null +++ b/backends/aoti/slim/c10/core/Scalar.h @@ -0,0 +1,360 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +// Copy-pasted from c10/core/Scalar.h, but dropping SymScalar support + +namespace standalone::c10 { + +/** + * Scalar represents a 0-dimensional tensor which contains a single element. + * Unlike a tensor, numeric literals (in C++) are implicitly convertible to + * Scalar (which is why, for example, we provide both add(Tensor) and + * add(Scalar) overloads for many operations). It may also be used in + * circumstances where you statically know a tensor is 0-dim and single size, + * but don't know its type. 
+ */ +class Scalar { + public: + Scalar() : Scalar(int64_t(0)) {} + +#define DEFINE_IMPLICIT_CTOR(type, name) \ + Scalar(type vv) : Scalar(vv, true) {} + + AT_FORALL_SCALAR_TYPES_AND3(Half, BFloat16, ComplexHalf, DEFINE_IMPLICIT_CTOR) + AT_FORALL_COMPLEX_TYPES(DEFINE_IMPLICIT_CTOR) + AT_FORALL_FLOAT8_TYPES(DEFINE_IMPLICIT_CTOR) + + // Helper constructors to allow Scalar creation from long and long long types + // As std::is_same_v is false(except Android), one needs to + // provide a constructor from either long or long long in addition to one from + // int64_t +#if defined(__APPLE__) || defined(__MACOSX) + static_assert( + std::is_same_v, + "int64_t is the same as long long on MacOS"); + Scalar(long vv) : Scalar(vv, true) {} +#endif +#if defined(_MSC_VER) + static_assert( + std::is_same_v, + "int64_t is the same as long long on Windows"); + Scalar(long vv) : Scalar(vv, true) {} +#endif +#if defined(__linux__) && !defined(__ANDROID__) + static_assert( + sizeof(void*) != 8 || std::is_same_v, + "int64_t is the same as long on 64 bit Linux"); +#if LONG_MAX != INT_MAX + Scalar(long long vv) : Scalar(vv, true) {} +#endif /* not 32-bit system */ +#endif + + Scalar(uint16_t vv) : Scalar(vv, true) {} + Scalar(uint32_t vv) : Scalar(vv, true) {} + Scalar(uint64_t vv) { + if (vv > static_cast(INT64_MAX)) { + tag = Tag::HAS_u; + v.u = vv; + } else { + tag = Tag::HAS_i; + // NB: no need to use convert, we've already tested convertibility + v.i = static_cast(vv); + } + } + +#undef DEFINE_IMPLICIT_CTOR + + // Value* is both implicitly convertible to SymbolicVariable and bool which + // causes ambiguity error. Specialized constructor for bool resolves this + // problem. + template < + typename T, + typename std::enable_if_t, bool>* = nullptr> + Scalar(T vv) : tag(Tag::HAS_b) { + v.i = convert(vv); + } + +#define DEFINE_ACCESSOR(type, name) \ + type to##name() const { \ + if (Tag::HAS_d == tag) { \ + return checked_convert(v.d, #type); \ + } else if (Tag::HAS_z == tag) { \ + return checked_convert>( \ + v.z, #type); \ + } \ + if (Tag::HAS_b == tag) { \ + return checked_convert(v.i, #type); \ + } else if (Tag::HAS_i == tag) { \ + return checked_convert(v.i, #type); \ + } else if (Tag::HAS_u == tag) { \ + return checked_convert(v.u, #type); \ + } \ + STANDALONE_CHECK(false) \ + } + + // TODO: Support ComplexHalf accessor + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ACCESSOR) + DEFINE_ACCESSOR(uint16_t, UInt16) + DEFINE_ACCESSOR(uint32_t, UInt32) + DEFINE_ACCESSOR(uint64_t, UInt64) + +#undef DEFINE_ACCESSOR + + // also support scalar.to(); + // Deleted for unsupported types, but specialized below for supported types + template + T to() const = delete; + + // audit uses of data_ptr + const void* data_ptr() const { + return static_cast(&v); + } + + bool isFloatingPoint() const { + return Tag::HAS_d == tag; + } + + bool isIntegral(bool includeBool) const { + return Tag::HAS_i == tag || Tag::HAS_u == tag || + (includeBool && isBoolean()); + } + + bool isComplex() const { + return Tag::HAS_z == tag; + } + bool isBoolean() const { + return Tag::HAS_b == tag; + } + + STANDALONE_ALWAYS_INLINE Scalar& operator=(Scalar&& other) noexcept { + if (&other == this) { + return *this; + } + + moveFrom(std::move(other)); + return *this; + } + + STANDALONE_ALWAYS_INLINE Scalar& operator=(const Scalar& other) { + if (&other == this) { + return *this; + } + + *this = Scalar(other); + return *this; + } + + Scalar operator-() const { + STANDALONE_CHECK( + !isBoolean(), + "torch boolean negative, the `-` operator, is not 
supported."); + if (isFloatingPoint()) { + return Scalar(-v.d); + } else if (isComplex()) { + return Scalar(-v.z); + } else if (isIntegral(false)) { + return Scalar(-v.i); + } + STANDALONE_INTERNAL_ASSERT( + false, "unknown ivalue tag ", static_cast(tag)); + } + + Scalar conj() const { + if (isComplex()) { + return Scalar(std::conj(v.z)); + } else { + return *this; + } + } + + Scalar log() const { + if (isComplex()) { + return std::log(v.z); + } else if (isFloatingPoint()) { + return std::log(v.d); + } else if (isIntegral(false)) { + return std::log(v.i); + } + STANDALONE_INTERNAL_ASSERT( + false, "unknown ivalue tag ", static_cast(tag)); + } + + template < + typename T, + typename std::enable_if_t::value, int> = + 0> + bool equal(T num) const { + if (isComplex()) { + auto val = v.z; + return (val.real() == num) && (val.imag() == T()); + } else if (isFloatingPoint()) { + return toDouble() == num; + } else if (tag == Tag::HAS_i) { + if (overflows(v.i, /* strict_unsigned */ true)) { + return false; + } else { + return static_cast(v.i) == num; + } + } else if (tag == Tag::HAS_u) { + if (overflows(v.u, /* strict_unsigned */ true)) { + return false; + } else { + return static_cast(v.u) == num; + } + } else if (isBoolean()) { + // boolean scalar does not equal to a non boolean value + return false; + } else { + STANDALONE_INTERNAL_ASSERT(false); + } + } + + template < + typename T, + typename std::enable_if_t::value, int> = 0> + bool equal(T num) const { + if (isComplex()) { + return v.z == num; + } else if (isFloatingPoint()) { + return (toDouble() == num.real()) && (num.imag() == T()); + } else if (tag == Tag::HAS_i) { + if (overflows(v.i, /* strict_unsigned */ true)) { + return false; + } else { + return static_cast(v.i) == num.real() && num.imag() == T(); + } + } else if (tag == Tag::HAS_u) { + if (overflows(v.u, /* strict_unsigned */ true)) { + return false; + } else { + return static_cast(v.u) == num.real() && num.imag() == T(); + } + } else if (isBoolean()) { + // boolean scalar does not equal to a non boolean value + return false; + } else { + STANDALONE_INTERNAL_ASSERT(false); + } + } + + bool equal(bool num) const { + if (isBoolean()) { + return static_cast(v.i) == num; + } else { + return false; + } + } + + standalone::c10::ScalarType type() const { + if (isComplex()) { + return standalone::c10::ScalarType::ComplexDouble; + } else if (isFloatingPoint()) { + return standalone::c10::ScalarType::Double; + } else if (isIntegral(/*includeBool=*/false)) { + // Represent all integers as long, UNLESS it is unsigned and therefore + // unrepresentable as long + if (Tag::HAS_u == tag) { + return standalone::c10::ScalarType::UInt64; + } + return standalone::c10::ScalarType::Long; + } else if (isBoolean()) { + return standalone::c10::ScalarType::Bool; + } else { + throw std::runtime_error("Unknown scalar type."); + } + } + + Scalar(Scalar&& rhs) noexcept : tag(rhs.tag) { + moveFrom(std::move(rhs)); + } + + Scalar(const Scalar& rhs) : tag(rhs.tag), v(rhs.v) {} + + // We can't set v in the initializer list using the + // syntax v{ .member = ... } because it doesn't work on MSVC + private: + enum class Tag { HAS_d, HAS_i, HAS_u, HAS_z, HAS_b }; + + // Note [Meaning of HAS_u] + // ~~~~~~~~~~~~~~~~~~~~~~~ + // HAS_u is a bit special. On its face, it just means that we + // are holding an unsigned integer. 
However, we generally don't + // distinguish between different bit sizes in Scalar (e.g., we represent + // float as double), instead, it represents a mathematical notion + // of some quantity (integral versus floating point). So actually, + // HAS_u is used solely to represent unsigned integers that could + // not be represented as a signed integer. That means only uint64_t + // potentially can get this tag; smaller types like uint8_t fits into a + // regular int and so for BC reasons we keep as an int. + + // NB: assumes that self has already been cleared + // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) + STANDALONE_ALWAYS_INLINE void moveFrom(Scalar&& rhs) noexcept { + v = rhs.v; + tag = rhs.tag; + } + + Tag tag; + + union v_t { + double d{}; + int64_t i; + // See Note [Meaning of HAS_u] + uint64_t u; + standalone::c10::complex z; + // NOLINTNEXTLINE(modernize-use-equals-default) + v_t() {} // default constructor + } v; + + template < + typename T, + typename std::enable_if_t< + std::is_integral_v && !std::is_same_v, + bool>* = nullptr> + Scalar(T vv, bool) : tag(Tag::HAS_i) { + v.i = convert(vv); + } + + template < + typename T, + typename std::enable_if_t< + !std::is_integral_v && !standalone::c10::is_complex::value, + bool>* = nullptr> + Scalar(T vv, bool) : tag(Tag::HAS_d) { + v.d = convert(vv); + } + + template < + typename T, + typename std::enable_if_t::value, bool>* = + nullptr> + Scalar(T vv, bool) : tag(Tag::HAS_z) { + v.z = convert(vv); + } +}; + +// define the scalar.to() specializations +#define DEFINE_TO(T, name) \ + template <> \ + inline T Scalar::to() const { \ + return to##name(); \ + } +AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_TO) +DEFINE_TO(uint16_t, UInt16) +DEFINE_TO(uint32_t, UInt32) +DEFINE_TO(uint64_t, UInt64) +#undef DEFINE_TO + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/core/ScalarType.h b/backends/aoti/slim/c10/core/ScalarType.h new file mode 100644 index 00000000000..6daeaad5f2c --- /dev/null +++ b/backends/aoti/slim/c10/core/ScalarType.h @@ -0,0 +1,735 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace standalone::c10 { + +// dummy struct for uint1 to uint7, actual functionality +// of these dtypes will be implemented in python with Tensor subclass +template +struct dummy_uint1_7_t {}; + +// dummy struct for int1 to int7, actual functionality +// of these dtypes will be implemented in python with Tensor subclass +template +struct dummy_int1_7_t {}; + +// For the macros below: +// +// For users: If you want to macro some code for all non-QInt scalar types +// (i.e. types with complete information, you probably want one of the +// AT_FORALL_SCALAR_TYPES / AT_FORALL_SCALAR_TYPES_AND macros below, which are +// designed to behave similarly to the Dispatch macros with the same name. +// +// For adding a new dtype: In the beginning, we had an idea that there was a +// list of all scalar types, and you could use AT_FORALL_SCALAR_TYPES to +// iterate over them. But over the years we added weird types which couldn't +// be handled uniformly everywhere and so in the end we ended up with some +// mish-mosh of some helper macros, but mostly use sites making a call about +// what dtypes they can or can't support. 
So if you want to add a new dtype, +// the preferred resolution is to find a dtype similar to what you want, +// grep for it and edit all the sites you find this way. If you need to add +// a completely new kind of dtype, you're going to have to laboriously audit +// all of the sites everywhere to figure out how it should work. Consulting +// some old PRs where we added new dtypes (check history of this file) can +// help give you an idea where to start. + +// NB: Order matters for this macro; it is relied upon in +// _promoteTypesLookup and the serialization format. +#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \ + _(uint8_t, Byte) /* 0 */ \ + _(int8_t, Char) /* 1 */ \ + _(int16_t, Short) /* 2 */ \ + _(int, Int) /* 3 */ \ + _(int64_t, Long) /* 4 */ \ + _(standalone::c10::Half, Half) /* 5 */ \ + _(float, Float) /* 6 */ \ + _(double, Double) /* 7 */ \ + _(standalone::c10::complex, ComplexHalf) /* 8 */ \ + _(standalone::c10::complex, ComplexFloat) /* 9 */ \ + _(standalone::c10::complex, ComplexDouble) /* 10 */ \ + _(bool, Bool) /* 11 */ \ + _(standalone::c10::qint8, QInt8) /* 12 */ \ + _(standalone::c10::quint8, QUInt8) /* 13 */ \ + _(standalone::c10::qint32, QInt32) /* 14 */ \ + _(standalone::c10::BFloat16, BFloat16) /* 15 */ \ + _(standalone::c10::quint4x2, QUInt4x2) /* 16 */ \ + _(standalone::c10::quint2x4, QUInt2x4) /* 17 */ \ + _(standalone::c10::bits1x8, Bits1x8) /* 18 */ \ + _(standalone::c10::bits2x4, Bits2x4) /* 19 */ \ + _(standalone::c10::bits4x2, Bits4x2) /* 20 */ \ + _(standalone::c10::bits8, Bits8) /* 21 */ \ + _(standalone::c10::bits16, Bits16) /* 22 */ \ + _(standalone::c10::Float8_e5m2, Float8_e5m2) /* 23 */ \ + _(standalone::c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */ \ + _(standalone::c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \ + _(standalone::c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */ \ + _(uint16_t, UInt16) /* 27 */ \ + _(uint32_t, UInt32) /* 28 */ \ + _(uint64_t, UInt64) /* 29 */ \ + _(standalone::c10::dummy_uint1_7_t<1>, UInt1) /* 30 */ \ + _(standalone::c10::dummy_uint1_7_t<2>, UInt2) /* 31 */ \ + _(standalone::c10::dummy_uint1_7_t<3>, UInt3) /* 32 */ \ + _(standalone::c10::dummy_uint1_7_t<4>, UInt4) /* 33 */ \ + _(standalone::c10::dummy_uint1_7_t<5>, UInt5) /* 34 */ \ + _(standalone::c10::dummy_uint1_7_t<6>, UInt6) /* 35 */ \ + _(standalone::c10::dummy_uint1_7_t<7>, UInt7) /* 36 */ \ + _(standalone::c10::dummy_int1_7_t<1>, Int1) /* 37 */ \ + _(standalone::c10::dummy_int1_7_t<2>, Int2) /* 38 */ \ + _(standalone::c10::dummy_int1_7_t<3>, Int3) /* 39 */ \ + _(standalone::c10::dummy_int1_7_t<4>, Int4) /* 40 */ \ + _(standalone::c10::dummy_int1_7_t<5>, Int5) /* 41 */ \ + _(standalone::c10::dummy_int1_7_t<6>, Int6) /* 42 */ \ + _(standalone::c10::dummy_int1_7_t<7>, Int7) /* 43 */ \ + _(standalone::c10::Float8_e8m0fnu, Float8_e8m0fnu) /* 44 */ \ + _(standalone::c10::Float4_e2m1fn_x2, Float4_e2m1fn_x2) /* 45 */ + +// If you want to support ComplexHalf for real, add ComplexHalf +// into this macro (and change the name). But beware: convert() +// doesn't work for all the conversions you need... +// +// TODO: To add unsigned int types here, we must define accumulate type. +// But uint8 currently accumulates into int64, so we would have to make +// an inconsistent choice for the larger types. Difficult. 
+#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ(_) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(standalone::c10::Half, Half) \ + _(float, Float) \ + _(double, Double) \ + _(standalone::c10::complex, ComplexFloat) \ + _(standalone::c10::complex, ComplexDouble) \ + _(bool, Bool) \ + _(standalone::c10::BFloat16, BFloat16) \ + _(standalone::c10::Float8_e5m2, Float8_e5m2) \ + _(standalone::c10::Float8_e4m3fn, Float8_e4m3fn) + +// This macro controls many of our C++ APIs, including constructors +// for Scalar as well as the data() and item() accessors on Tensor +#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(standalone::c10::Half, Half) \ + _(float, Float) \ + _(double, Double) \ + _(standalone::c10::complex, ComplexHalf) \ + _(standalone::c10::complex, ComplexFloat) \ + _(standalone::c10::complex, ComplexDouble) \ + _(bool, Bool) \ + _(standalone::c10::BFloat16, BFloat16) \ + _(standalone::c10::Float8_e5m2, Float8_e5m2) \ + _(standalone::c10::Float8_e4m3fn, Float8_e4m3fn) \ + _(standalone::c10::Float8_e5m2fnuz, Float8_e5m2fnuz) \ + _(standalone::c10::Float8_e4m3fnuz, Float8_e4m3fnuz) \ + _(standalone::c10::Float8_e8m0fnu, Float8_e8m0fnu) + +enum class ScalarType : int8_t { +#define DEFINE_ST_ENUM_VAL_(_1, n) n, + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ST_ENUM_VAL_) +#undef DEFINE_ENUM_ST_ENUM_VAL_ + Undefined, + NumOptions +}; + +constexpr uint16_t NumScalarTypes = + static_cast(ScalarType::NumOptions); + +namespace impl { + +// These are used to map ScalarTypes to C++ types. + +template +struct ScalarTypeToCPPType; + +#define SPECIALIZE_ScalarTypeToCPPType(cpp_type, scalar_type) \ + template <> \ + struct ScalarTypeToCPPType { \ + using type = cpp_type; \ + \ + /* This is a workaround for the CUDA bug which prevents */ \ + /* ::detail::ScalarTypeToCType::type being used directly due to */ \ + /* ambiguous reference which can't to be resolved. For some reason it */ \ + /* can't pick between standalone::c10::detail and \ + * standalone::c10::cuda::detail. */ \ + /* For repro example, please see: */ \ + /* https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba */ \ + /* TODO: remove once the bug is fixed. */ \ + static type t; \ + }; + +AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType) + +#undef SPECIALIZE_ScalarTypeToCPPType + +template +using ScalarTypeToCPPTypeT = typename ScalarTypeToCPPType::type; + +} // namespace impl + +template +struct CppTypeToScalarType; + +#define SPECIALIZE_CppTypeToScalarType(cpp_type, scalar_type) \ + template <> \ + struct CppTypeToScalarType \ + : std::integral_constant< \ + standalone::c10::ScalarType, \ + standalone::c10::ScalarType::scalar_type> {}; + +AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType) + +#undef SPECIALIZE_CppTypeToScalarType + +// NB: despite its generic sounding name, the macros that don't take _AND +// are mostly only used by tensorexpr +#define AT_FORALL_INT_TYPES(_) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) + +#define AT_FORALL_SCALAR_TYPES(_) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) + +// These macros are often controlling how many template instantiations we +// create for kernels. 
It is typically inappropriate to add new dtypes here, +// instead, new types should be added to use sites on a case-by-case basis. +// We generally are not accepting new dtypes due to binary size concerns. + +#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) \ + _(decltype(standalone::c10::impl::ScalarTypeToCPPType< \ + standalone::c10::ScalarType::SCALARTYPE>::t), \ + SCALARTYPE) + +#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) \ + _(decltype(standalone::c10::impl::ScalarTypeToCPPType< \ + standalone::c10::ScalarType::SCALARTYPE1>::t), \ + SCALARTYPE1) \ + _(decltype(standalone::c10::impl::ScalarTypeToCPPType< \ + standalone::c10::ScalarType::SCALARTYPE2>::t), \ + SCALARTYPE2) + +#define AT_FORALL_SCALAR_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) \ + _(decltype(standalone::c10::impl::ScalarTypeToCPPType< \ + standalone::c10::ScalarType::SCALARTYPE1>::t), \ + SCALARTYPE1) \ + _(decltype(standalone::c10::impl::ScalarTypeToCPPType< \ + standalone::c10::ScalarType::SCALARTYPE2>::t), \ + SCALARTYPE2) \ + _(decltype(standalone::c10::impl::ScalarTypeToCPPType< \ + standalone::c10::ScalarType::SCALARTYPE3>::t), \ + SCALARTYPE3) + +#define AT_FORALL_SCALAR_TYPES_AND7( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + SCALARTYPE6, \ + SCALARTYPE7, \ + _) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) \ + _(decltype(standalone::c10::impl::ScalarTypeToCPPType< \ + standalone::c10::ScalarType::SCALARTYPE1>::t), \ + SCALARTYPE1) \ + _(decltype(standalone::c10::impl::ScalarTypeToCPPType< \ + standalone::c10::ScalarType::SCALARTYPE2>::t), \ + SCALARTYPE2) \ + _(decltype(standalone::c10::impl::ScalarTypeToCPPType< \ + standalone::c10::ScalarType::SCALARTYPE3>::t), \ + SCALARTYPE3) \ + _(decltype(standalone::c10::impl::ScalarTypeToCPPType< \ + standalone::c10::ScalarType::SCALARTYPE4>::t), \ + SCALARTYPE4) \ + _(decltype(standalone::c10::impl::ScalarTypeToCPPType< \ + standalone::c10::ScalarType::SCALARTYPE5>::t), \ + SCALARTYPE5) \ + _(decltype(standalone::c10::impl::ScalarTypeToCPPType< \ + standalone::c10::ScalarType::SCALARTYPE6>::t), \ + SCALARTYPE6) \ + _(decltype(standalone::c10::impl::ScalarTypeToCPPType< \ + standalone::c10::ScalarType::SCALARTYPE7>::t), \ + SCALARTYPE7) + +#define AT_FORALL_QINT_TYPES(_) \ + _(standalone::c10::qint8, QInt8) \ + _(standalone::c10::quint8, QUInt8) \ + _(standalone::c10::qint32, QInt32) \ + _(standalone::c10::quint4x2, QUInt4x2) \ + _(standalone::c10::quint2x4, QUInt2x4) + +#define AT_FORALL_FLOAT8_TYPES(_) \ + _(standalone::c10::Float8_e5m2, Float8_e5m2) \ + _(standalone::c10::Float8_e4m3fn, Float8_e4m3fn) \ + _(standalone::c10::Float8_e5m2fnuz, Float8_e5m2fnuz) \ + _(standalone::c10::Float8_e4m3fnuz, Float8_e4m3fnuz) \ + _(standalone::c10::Float8_e8m0fnu, Float8_e8m0fnu) + +#define AT_FORALL_COMPLEX_TYPES(_) \ + _(standalone::c10::complex, ComplexFloat) \ + _(standalone::c10::complex, ComplexDouble) + +#define DEFINE_CONSTANT(_, name) \ + constexpr ScalarType k##name = ScalarType::name; 
+ +// NOLINTNEXTLINE(clang-diagnostic-unused-const-variable) +AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT) +#undef DEFINE_CONSTANT + +inline const char* toString(ScalarType t) { +#define DEFINE_CASE(_, name) \ + case ScalarType::name: \ + return #name; + + switch (t) { + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE) + default: + return "UNKNOWN_SCALAR"; + } +#undef DEFINE_CASE +} + +inline size_t elementSize(ScalarType t) { +#define CASE_ELEMENTSIZE_CASE(ctype, name) \ + case ScalarType::name: \ + return sizeof(ctype); + + switch (t) { + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CASE_ELEMENTSIZE_CASE) + default: + STANDALONE_CHECK(false, "Unknown ScalarType"); + } +#undef CASE_ELEMENTSIZE_CASE +} + +inline bool isIntegralType(ScalarType t, bool includeBool) { + bool isIntegral = + (t == ScalarType::Byte || t == ScalarType::Char || t == ScalarType::Int || + t == ScalarType::Long || t == ScalarType::Short || + t == ScalarType::UInt16 || t == ScalarType::UInt32 || + t == ScalarType::UInt64); + + return isIntegral || (includeBool && t == ScalarType::Bool); +} + +inline bool isFloat8Type(ScalarType t) { + return t == ScalarType::Float8_e5m2 || t == ScalarType::Float8_e5m2fnuz || + t == ScalarType::Float8_e4m3fn || t == ScalarType::Float8_e4m3fnuz || + t == ScalarType::Float8_e8m0fnu; +} + +inline bool isReducedFloatingType(ScalarType t) { + return t == ScalarType::Half || t == ScalarType::BFloat16 || + isFloat8Type(t) || t == ScalarType::Float4_e2m1fn_x2; +} + +inline bool isFloatingType(ScalarType t) { + return t == ScalarType::Double || t == ScalarType::Float || + isReducedFloatingType(t); +} + +inline bool isComplexType(ScalarType t) { + return ( + t == ScalarType::ComplexHalf || t == ScalarType::ComplexFloat || + t == ScalarType::ComplexDouble); +} + +inline bool isQIntType(ScalarType t) { + // Don't forget to extend this when adding new QInt types + return t == ScalarType::QInt8 || t == ScalarType::QUInt8 || + t == ScalarType::QInt32 || t == ScalarType::QUInt4x2 || + t == ScalarType::QUInt2x4; +} + +inline bool isBitsType(ScalarType t) { + return t == ScalarType::Bits1x8 || t == ScalarType::Bits2x4 || + t == ScalarType::Bits4x2 || t == ScalarType::Bits8 || + t == ScalarType::Bits16; +} + +inline bool isBarebonesUnsignedType(ScalarType t) { + return t == ScalarType::UInt1 || t == ScalarType::UInt2 || + t == ScalarType::UInt3 || t == ScalarType::UInt4 || + t == ScalarType::UInt5 || t == ScalarType::UInt6 || + t == ScalarType::UInt7 || t == ScalarType::UInt16 || + t == ScalarType::UInt32 || t == ScalarType::UInt64; +} + +inline ScalarType toQIntType(ScalarType t) { + switch (t) { + case ScalarType::Byte: + return ScalarType::QUInt8; + case ScalarType::Char: + return ScalarType::QInt8; + case ScalarType::Int: + return ScalarType::QInt32; + default: + return t; + } +} + +inline ScalarType toUnderlying(ScalarType t) { + switch (t) { + case ScalarType::QUInt8: + case ScalarType::QUInt4x2: + [[fallthrough]]; + case ScalarType::QUInt2x4: + return ScalarType::Byte; + case ScalarType::QInt8: + return ScalarType::Char; + case ScalarType::QInt32: + return ScalarType::Int; + default: + return t; + } +} + +inline bool isSignedType(ScalarType t) { +#define CASE_ISSIGNED(name) \ + case ScalarType::name: \ + return std::numeric_limits<::standalone::c10::impl::ScalarTypeToCPPTypeT< \ + ScalarType::name>>::is_signed; + + // TODO(#146647): If we expect to have numeric_limits for everything, + // let's just have a big macro for the whole thing. 
+ // If we're hardcoding it, let's just use the macro and a "true"/"false" + // below? + switch (t) { + case ScalarType::QInt8: + case ScalarType::QUInt8: + case ScalarType::QInt32: + case ScalarType::QUInt4x2: + case ScalarType::QUInt2x4: + STANDALONE_CHECK(false, "isSignedType not supported for quantized types"); + case ScalarType::Bits1x8: + case ScalarType::Bits2x4: + case ScalarType::Bits4x2: + case ScalarType::Bits8: + case ScalarType::Bits16: + STANDALONE_CHECK(false, "Bits types are undefined"); + CASE_ISSIGNED(UInt16); + CASE_ISSIGNED(UInt32); + CASE_ISSIGNED(UInt64); + CASE_ISSIGNED(BFloat16); + CASE_ISSIGNED(Float8_e5m2); + CASE_ISSIGNED(Float8_e5m2fnuz); + CASE_ISSIGNED(Float8_e4m3fn); + CASE_ISSIGNED(Float8_e4m3fnuz); + CASE_ISSIGNED(Float8_e8m0fnu); + CASE_ISSIGNED(Byte); + CASE_ISSIGNED(Char); + CASE_ISSIGNED(Short); + CASE_ISSIGNED(Int); + CASE_ISSIGNED(Long); + CASE_ISSIGNED(Half); + CASE_ISSIGNED(Float); + CASE_ISSIGNED(Double); + CASE_ISSIGNED(ComplexHalf); + CASE_ISSIGNED(ComplexFloat); + CASE_ISSIGNED(ComplexDouble); + CASE_ISSIGNED(Bool); + case ScalarType::Int1: + case ScalarType::Int2: + case ScalarType::Int3: + case ScalarType::Int4: + case ScalarType::Int5: + case ScalarType::Int6: + case ScalarType::Int7: + case ScalarType::Float4_e2m1fn_x2: + return true; + case ScalarType::UInt1: + case ScalarType::UInt2: + case ScalarType::UInt3: + case ScalarType::UInt4: + case ScalarType::UInt5: + case ScalarType::UInt6: + case ScalarType::UInt7: + return false; + case ScalarType::Undefined: + case ScalarType::NumOptions: + break; + // Do not add default here, but rather define behavior of every new entry + // here. `-Wswitch-enum` would raise a warning in those cases. + } + STANDALONE_CHECK(false, "Unknown ScalarType ", t); +#undef CASE_ISSIGNED +} + +inline bool isUnderlying(ScalarType type, ScalarType qtype) { + return type == toUnderlying(qtype); +} + +inline ScalarType toRealValueType(ScalarType t) { + switch (t) { + case ScalarType::ComplexHalf: + return ScalarType::Half; + case ScalarType::ComplexFloat: + return ScalarType::Float; + case ScalarType::ComplexDouble: + return ScalarType::Double; + default: + return t; + } +} + +inline ScalarType toComplexType(ScalarType t) { + switch (t) { + case ScalarType::BFloat16: + // BFloat16 has range equivalent to Float, + // so we map it to ComplexFloat. + return ScalarType::ComplexFloat; + case ScalarType::Half: + return ScalarType::ComplexHalf; + case ScalarType::Float: + return ScalarType::ComplexFloat; + case ScalarType::Double: + return ScalarType::ComplexDouble; + case ScalarType::ComplexHalf: + return ScalarType::ComplexHalf; + case ScalarType::ComplexFloat: + return ScalarType::ComplexFloat; + case ScalarType::ComplexDouble: + return ScalarType::ComplexDouble; + default: + STANDALONE_CHECK(false, "Unknown Complex ScalarType for ", t); + } +} + +// see tensor_attributes.rst for detailed explanation and examples +// of casting rules. +inline bool canCast(const ScalarType from, const ScalarType to) { + // We disallow complex -> non complex, e.g., float_tensor *= complex is + // disallowed. + if (isComplexType(from) && !isComplexType(to)) { + return false; + } + // We disallow float -> integral, e.g., int_tensor *= float is disallowed. + if (isFloatingType(from) && isIntegralType(to, false)) { + return false; + } + + // Treat bool as a distinct "category," to be consistent with type promotion + // rules (e.g. `bool_tensor + 5 -> int64_tensor`). If `5` was in the same + // category as `bool_tensor`, we would not promote. 
Differing categories + // implies `bool_tensor += 5` is disallowed. + // + // NB: numpy distinguishes "unsigned" as a category to get the desired + // `bool_tensor + 5 -> int64_tensor` behavior. We don't, because: + // * We don't want the performance hit of checking the runtime sign of + // Scalars. + // * `uint8_tensor + 5 -> int64_tensor` would be undesirable. + if (from != ScalarType::Bool && to == ScalarType::Bool) { + return false; + } + return true; +} + +namespace detail { +constexpr auto u1 = ScalarType::Byte; +constexpr auto i1 = ScalarType::Char; +constexpr auto i2 = ScalarType::Short; +constexpr auto i4 = ScalarType::Int; +constexpr auto i8 = ScalarType::Long; +constexpr auto f2 = ScalarType::Half; +constexpr auto f4 = ScalarType::Float; +constexpr auto f8 = ScalarType::Double; +constexpr auto c2 = ScalarType::ComplexHalf; +constexpr auto c4 = ScalarType::ComplexFloat; +constexpr auto c8 = ScalarType::ComplexDouble; +constexpr auto b1 = ScalarType::Bool; +constexpr auto bf = ScalarType::BFloat16; +constexpr auto ud = ScalarType::Undefined; + +constexpr auto index2dtype = array_of( + u1, + i1, + i2, + i4, + i8, + f2, + f4, + f8, + c2, + c4, + c8, + b1, + bf); + +constexpr std::array(ScalarType::NumOptions)> +calculate_dtype2index() { + std::array(ScalarType::NumOptions)> inverse = {}; + for (int64_t i = 0; i < static_cast(ScalarType::NumOptions); i++) { + inverse[i] = -1; + } + for (int64_t i = 0; i < static_cast(index2dtype.size()); i++) { + inverse[static_cast(index2dtype[i])] = i; + } + return inverse; +} + +constexpr auto dtype2index = calculate_dtype2index(); +} // namespace detail + +inline ScalarType promoteTypes(ScalarType a, ScalarType b) { + using namespace detail; + + // This is generated according to NumPy's promote_types + if (a == ud || b == ud) { + return ScalarType::Undefined; + } + + // If the two types are equal, return that type + if (a == b) { + return a; + } + + // Handle identically equal types + if (isQIntType(a) || isQIntType(b)) { + STANDALONE_CHECK( + false, + "promoteTypes with quantized numbers is not handled yet; figure out " + "what the correct rules should be, offending types: ", + toString(a), + " ", + toString(b)); + } + + if (isBitsType(a) || isBitsType(b)) { + return ScalarType::Undefined; + } + + if (isFloat8Type(a) || isFloat8Type(b)) { + STANDALONE_CHECK( + false, + "Promotion for Float8 Types is not supported, attempted to promote ", + toString(a), + " and ", + toString(b)); + } + + if (isBarebonesUnsignedType(a) || isBarebonesUnsignedType(b)) { + // There are two problems with promotion here: + // + // - Our promotion rule for uint8 is inconsistent with Numpy; Numpy + // promotes to uint64, but since we never had uint64 for the longest + // time, we promote to int64. Changing this is BC-breaking + // + // - We must not promote uint64 to int64 because this will overflow. + // + // It'll be a bit of work to fix it, so we're punting on it for now. + // However, float promotion is fine, so we handle that. 
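+ // Illustrative examples (comment only, not in the original diff):
+ //   promoteTypes(UInt32, Float) -> Float (handled by the two checks below)
+ //   promoteTypes(UInt64, Long)  -> hard error (falls through to the CHECK)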
+ if (isFloatingType(a)) { + return a; + } + if (isFloatingType(b)) { + return b; + } + STANDALONE_CHECK( + false, + "Promotion for uint16, uint32, uint64 types is not supported, " + "attempted to promote ", + toString(a), + " and ", + toString(b)); + } + auto ix_a = dtype2index[static_cast(a)]; + STANDALONE_INTERNAL_ASSERT(ix_a != -1); + auto ix_b = dtype2index[static_cast(b)]; + STANDALONE_INTERNAL_ASSERT(ix_b != -1); + + // This table axes must be consistent with index2dtype + // clang-format off + static constexpr std:: + array, index2dtype.size()> + _promoteTypesLookup = {{ + /* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 bf*/ + /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, bf}, + /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, bf}, + /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, i2, bf}, + /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c2, c4, c8, i4, bf}, + /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c2, c4, c8, i8, bf}, + /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c2, c4, c8, f2, f4}, + /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c4, c8, f4, f4}, + /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, c8, f8, f8}, + /* c2 */ {c2, c2, c2, c2, c2, c2, c4, c8, c2, c4, c8, c2, c4}, + /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, c4}, + /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8}, + /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, bf}, + /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, bf}, + }}; + // clang-format on + return _promoteTypesLookup[ix_a][ix_b]; +} + +inline std::ostream& operator<<( + std::ostream& stream, + standalone::c10::ScalarType scalar_type) { + return stream << toString(scalar_type); +} + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/core/SizesAndStrides.h b/backends/aoti/slim/c10/core/SizesAndStrides.h new file mode 100644 index 00000000000..aef0ddab171 --- /dev/null +++ b/backends/aoti/slim/c10/core/SizesAndStrides.h @@ -0,0 +1,402 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +#define STANDALONE_SIZES_AND_STRIDES_MAX_INLINE_SIZE 5 + +namespace standalone::c10 { + +// Packed container for TensorImpl sizes and strides. +// This design improves on the previous approach of using a pair of +// c10::SmallVector by specializing for the operations we +// actually use and enforcing that the number of sizes is the same as +// the number of strides. The memory layout is as follows: +// +// 1 size_t for the size +// 5 eightbytes of inline sizes and 5 eightbytes of inline strides, OR pointer +// to out-of-line array +class SizesAndStrides { + public: + // TODO: different iterator types for sizes & strides to prevent + // mixing the two accidentally. + using sizes_iterator = int64_t*; + using sizes_const_iterator = const int64_t*; + using strides_iterator = int64_t*; + using strides_const_iterator = const int64_t*; + + SizesAndStrides() { + size_at_unchecked(0) = 0; + stride_at_unchecked(0) = 1; + } + + ~SizesAndStrides() { + if (STANDALONE_UNLIKELY(!isInline())) { + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + free(outOfLineStorage_); + } + } + + SizesAndStrides(const SizesAndStrides& rhs) : size_(rhs.size_) { + if (STANDALONE_LIKELY(rhs.isInline())) { + copyDataInline(rhs); + } else { + allocateOutOfLineStorage(size_); + copyDataOutline(rhs); + } + } + + bool operator==(const SizesAndStrides& other) const { + if (size_ != other.size_) { + return false; + } + return !( + isInline() + ? 
std::memcmp( + inlineStorage_, other.inlineStorage_, sizeof(inlineStorage_)) + : std::memcmp( + outOfLineStorage_, + other.outOfLineStorage_, + storageBytes(size_))); + } + + SizesAndStrides& operator=(const SizesAndStrides& rhs) { + if (this == &rhs) { + return *this; + } + if (STANDALONE_LIKELY(rhs.isInline())) { + if (STANDALONE_UNLIKELY(!isInline())) { + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + free(outOfLineStorage_); + } + copyDataInline(rhs); + } else { + if (isInline()) { + allocateOutOfLineStorage(rhs.size_); + } else { + resizeOutOfLineStorage(rhs.size_); + } + copyDataOutline(rhs); + } + size_ = rhs.size_; + return *this; + } + + // Move from rhs. rhs.size() == 0 afterwards. + SizesAndStrides(SizesAndStrides&& rhs) noexcept : size_(rhs.size_) { + if (STANDALONE_LIKELY(isInline())) { + memcpy(inlineStorage_, rhs.inlineStorage_, sizeof(inlineStorage_)); + } else { + outOfLineStorage_ = rhs.outOfLineStorage_; + rhs.outOfLineStorage_ = nullptr; + } + + rhs.size_ = 0; + } + + // Move from rhs. rhs.size() == 0 afterwards. + SizesAndStrides& operator=(SizesAndStrides&& rhs) noexcept { + if (this == &rhs) { + return *this; + } + if (STANDALONE_LIKELY(rhs.isInline())) { + if (STANDALONE_UNLIKELY(!isInline())) { + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + free(outOfLineStorage_); + } + copyDataInline(rhs); + } else { + // They're outline. We're going to steal their vector. + if (!isInline()) { + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + free(outOfLineStorage_); + } + outOfLineStorage_ = rhs.outOfLineStorage_; + rhs.outOfLineStorage_ = nullptr; + } + size_ = rhs.size_; + rhs.size_ = 0; + + return *this; + } + + size_t size() const noexcept { + return size_; + } + + const int64_t* sizes_data() const noexcept { + if (STANDALONE_LIKELY(isInline())) { + return &inlineStorage_[0]; + } else { + return &outOfLineStorage_[0]; + } + } + + int64_t* sizes_data() noexcept { + if (STANDALONE_LIKELY(isInline())) { + return &inlineStorage_[0]; + } else { + return &outOfLineStorage_[0]; + } + } + + sizes_const_iterator sizes_begin() const noexcept { + return sizes_data(); + } + + sizes_iterator sizes_begin() noexcept { + return sizes_data(); + } + + sizes_const_iterator sizes_end() const noexcept { + return sizes_begin() + size(); + } + + sizes_iterator sizes_end() noexcept { + return sizes_begin() + size(); + } + + IntArrayRef sizes_arrayref() const noexcept { + return IntArrayRef{sizes_data(), size()}; + } + + void set_sizes(IntArrayRef newSizes) { + resize(newSizes.size()); + std::copy(newSizes.begin(), newSizes.end(), sizes_begin()); + } + + void set_strides(IntArrayRef strides) { + STANDALONE_INTERNAL_ASSERT(strides.size() == size()); + std::copy(strides.begin(), strides.end(), strides_begin()); + } + + const int64_t* strides_data() const noexcept { + if (STANDALONE_LIKELY(isInline())) { + return &inlineStorage_[STANDALONE_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; + } else { + return &outOfLineStorage_[size()]; + } + } + + int64_t* strides_data() noexcept { + if (STANDALONE_LIKELY(isInline())) { + return &inlineStorage_[STANDALONE_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; + } else { + return &outOfLineStorage_[size()]; + } + } + + strides_const_iterator strides_begin() const noexcept { + if (STANDALONE_LIKELY(isInline())) { + return &inlineStorage_[STANDALONE_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; + } else { + return &outOfLineStorage_[size()]; + } + } + + strides_iterator strides_begin() noexcept { + if (STANDALONE_LIKELY(isInline())) { + return 
&inlineStorage_[STANDALONE_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; + } else { + return &outOfLineStorage_[size()]; + } + } + + strides_const_iterator strides_end() const noexcept { + return strides_begin() + size(); + } + + strides_iterator strides_end() noexcept { + return strides_begin() + size(); + } + + IntArrayRef strides_arrayref() const noexcept { + return IntArrayRef{strides_data(), size()}; + } + + // Size accessors. + int64_t size_at(size_t idx) const noexcept { + assert(idx < size()); + return sizes_data()[idx]; + } + + int64_t& size_at(size_t idx) noexcept { + assert(idx < size()); + return sizes_data()[idx]; + } + + int64_t size_at_unchecked(size_t idx) const noexcept { + return sizes_data()[idx]; + } + + int64_t& size_at_unchecked(size_t idx) noexcept { + return sizes_data()[idx]; + } + + // Size accessors. + int64_t stride_at(size_t idx) const noexcept { + assert(idx < size()); + return strides_data()[idx]; + } + + int64_t& stride_at(size_t idx) noexcept { + assert(idx < size()); + return strides_data()[idx]; + } + + int64_t stride_at_unchecked(size_t idx) const noexcept { + return strides_data()[idx]; + } + + int64_t& stride_at_unchecked(size_t idx) noexcept { + return strides_data()[idx]; + } + + void resize(size_t newSize) { + const auto oldSize = size(); + if (newSize == oldSize) { + return; + } + if (STANDALONE_LIKELY( + newSize <= STANDALONE_SIZES_AND_STRIDES_MAX_INLINE_SIZE && + isInline())) { + if (oldSize < newSize) { + const auto bytesToZero = + (newSize - oldSize) * sizeof(inlineStorage_[0]); + memset(&inlineStorage_[oldSize], 0, bytesToZero); + memset( + &inlineStorage_ + [STANDALONE_SIZES_AND_STRIDES_MAX_INLINE_SIZE + oldSize], + 0, + bytesToZero); + } + size_ = newSize; + } else { + resizeSlowPath(newSize, oldSize); + } + } + + private: + void resizeSlowPath(size_t newSize, size_t oldSize) { + if (newSize <= STANDALONE_SIZES_AND_STRIDES_MAX_INLINE_SIZE) { + STANDALONE_INTERNAL_ASSERT_DEBUG_ONLY( + !isInline(), + "resizeSlowPath called when fast path should have been hit!"); + int64_t* tempStorage = outOfLineStorage_; + memcpy( + &inlineStorage_[0], + &tempStorage[0], + STANDALONE_SIZES_AND_STRIDES_MAX_INLINE_SIZE * + sizeof(inlineStorage_[0])); + memcpy( + &inlineStorage_[STANDALONE_SIZES_AND_STRIDES_MAX_INLINE_SIZE], + &tempStorage[oldSize], + STANDALONE_SIZES_AND_STRIDES_MAX_INLINE_SIZE * + sizeof(inlineStorage_[0])); + // CANNOT USE freeOutOfLineStorage() HERE! outOfLineStorage_ + // HAS BEEN OVERWRITTEN! + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + free(tempStorage); + } else { + if (isInline()) { + // CANNOT USE allocateOutOfLineStorage(newSize) HERE! WOULD + // OVERWRITE inlineStorage_! + int64_t* tempStorage = + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + static_cast(malloc(storageBytes(newSize))); + STANDALONE_CHECK( + tempStorage, + "Could not allocate memory to change Tensor SizesAndStrides!"); + const auto bytesToCopy = oldSize * sizeof(inlineStorage_[0]); + const auto bytesToZero = (newSize > oldSize) + ? (newSize - oldSize) * sizeof(tempStorage[0]) + : 0; + memcpy(&tempStorage[0], &inlineStorage_[0], bytesToCopy); + if (bytesToZero) { + memset(&tempStorage[oldSize], 0, bytesToZero); + } + memcpy( + &tempStorage[newSize], + &inlineStorage_[STANDALONE_SIZES_AND_STRIDES_MAX_INLINE_SIZE], + bytesToCopy); + if (bytesToZero) { + memset(&tempStorage[newSize + oldSize], 0, bytesToZero); + } + outOfLineStorage_ = tempStorage; + } else { + const bool isGrowing = oldSize < newSize; + if (isGrowing) { + // Resize before shifting so that we have room. 
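+ // (Illustrative note, not in the original diff: the out-of-line buffer
+ // packs sizes[0..size) immediately followed by strides[0..size), so
+ // growing e.g. from 6 to 8 dims reallocates to 8 * 2 int64_t slots and
+ // the strides block must then move from offset 6 to offset 8 -- which is
+ // exactly what the memmove below does.)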
+ resizeOutOfLineStorage(newSize); + } + // Shift the old strides to their new starting point. Note + // that this does not occur in the inline path above because + // the stride starting point is not moving. + memmove( + outOfLineStorage_ + newSize, + outOfLineStorage_ + oldSize, + std::min(oldSize, newSize) * sizeof(outOfLineStorage_[0])); + if (!isGrowing) { + // Resize after shifting so that we don't lose data. + resizeOutOfLineStorage(newSize); + } else { + // Zero the end of the sizes portion. + const auto bytesToZero = + (newSize - oldSize) * sizeof(outOfLineStorage_[0]); + memset(&outOfLineStorage_[oldSize], 0, bytesToZero); + memset(&outOfLineStorage_[newSize + oldSize], 0, bytesToZero); + } + } + } + size_ = newSize; + } + + bool isInline() const noexcept { + return size_ <= STANDALONE_SIZES_AND_STRIDES_MAX_INLINE_SIZE; + } + + void copyDataInline(const SizesAndStrides& rhs) { + STANDALONE_INTERNAL_ASSERT_DEBUG_ONLY(rhs.isInline()); + memcpy(inlineStorage_, rhs.inlineStorage_, sizeof(inlineStorage_)); + } + + static size_t storageBytes(size_t size) noexcept { + return size * 2 * sizeof(int64_t); + } + + void allocateOutOfLineStorage(size_t size) { + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + outOfLineStorage_ = static_cast(malloc(storageBytes(size))); + STANDALONE_CHECK( + outOfLineStorage_, + "Could not allocate memory for Tensor SizesAndStrides!"); + } + + void resizeOutOfLineStorage(size_t newSize) { + STANDALONE_INTERNAL_ASSERT_DEBUG_ONLY(!isInline()); + outOfLineStorage_ = static_cast( + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + realloc(outOfLineStorage_, storageBytes(newSize))); + STANDALONE_CHECK( + outOfLineStorage_, + "Could not allocate memory for Tensor SizesAndStrides!"); + } + + void copyDataOutline(const SizesAndStrides& rhs) noexcept { + memcpy(outOfLineStorage_, rhs.outOfLineStorage_, storageBytes(rhs.size_)); + } + + size_t size_{1}; + union { + int64_t* outOfLineStorage_; + // NOLINTNEXTLINE(*c-array*) + int64_t inlineStorage_[STANDALONE_SIZES_AND_STRIDES_MAX_INLINE_SIZE * 2]{}; + }; +}; + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/core/WrapDimMinimal.h b/backends/aoti/slim/c10/core/WrapDimMinimal.h new file mode 100644 index 00000000000..651421e6d89 --- /dev/null +++ b/backends/aoti/slim/c10/core/WrapDimMinimal.h @@ -0,0 +1,73 @@ +#pragma once + +#include +#include + +#include + +// Different from the original implementation in c10, we don't need +// to support SymInt here. +namespace standalone::c10 { +namespace detail { +template +T maybe_wrap_dim_slow(T dim, T dim_post_expr, bool wrap_scalar); +} + +template +T _maybe_wrap_dim(T dim, T dim_post_expr, bool wrap_scalar = true) { + // Inline the fast paths + if (STANDALONE_LIKELY(dim_post_expr * -1 <= dim && dim < dim_post_expr)) { + // For SymInts, we want an explicit control flow to trigger a guard, so we + // may as well branch too. 
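+ // Illustrative (comment only, not in the original diff): for a rank-4
+ // tensor, maybe_wrap_dim(-1, 4) takes this fast path and returns 3, while
+ // maybe_wrap_dim(5, 4) falls through to maybe_wrap_dim_slow() below and
+ // reports a "Dimension out of range" error.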
+ if (dim < 0) { + return dim + dim_post_expr; + } + return dim; + } + // Check edge-cases out-of-line (wrapping scalars and out-of-bounds errors) + return standalone::c10::detail::maybe_wrap_dim_slow( + std::move(dim), std::move(dim_post_expr), wrap_scalar); +} + +inline int64_t +maybe_wrap_dim(int64_t dim, int64_t dim_post_expr, bool wrap_scalar = true) { + return _maybe_wrap_dim(dim, dim_post_expr, wrap_scalar); +} + +namespace detail { +// This template can only be specialized at int64_t and c10::SymInt; +// you'll get linker errors otherwise +template +T maybe_wrap_dim_slow(T dim, T dim_post_expr, bool wrap_scalar) { + STANDALONE_CHECK( + dim_post_expr >= 0, "Rank cannot be negative but got ", dim_post_expr); + + if (dim_post_expr == 0) { + STANDALONE_CHECK( + wrap_scalar, + "Dimension specified as ", + dim, + " but tensor has no dimensions"); + return standalone::c10::maybe_wrap_dim( + std::move(dim), + /*dim_post_expr=*/1, + /*wrap_scalar=*/false); + } + + T min = dim_post_expr * -1; + T max = dim_post_expr - 1; + STANDALONE_CHECK( + min <= dim && dim <= max, + "Dimension out of range (expected to be in range of [", + min, + ", ", + max, + "], but got ", + dim, + ")"); + + STANDALONE_INTERNAL_ASSERT( + false, "should never reach here as dim should be out-of-bounds"); +} +} // namespace detail +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/cuda/Exception.h b/backends/aoti/slim/c10/cuda/Exception.h new file mode 100644 index 00000000000..bd972c1652d --- /dev/null +++ b/backends/aoti/slim/c10/cuda/Exception.h @@ -0,0 +1,29 @@ +#pragma once +#ifdef USE_CUDA + +#include +#include +#include + +#include +#include +#include + +#include + +#define STANDALONE_CUDA_CHECK(EXPR) \ + do { \ + const cudaError_t __err = EXPR; \ + STANDALONE_CHECK(__err == cudaSuccess, cudaGetErrorString(__err)); \ + } while (0) + +#define STANDALONE_CUDA_CHECK_WARN(EXPR) \ + do { \ + const cudaError_t __err = EXPR; \ + if (STANDALONE_UNLIKELY(__err != cudaSuccess)) { \ + [[maybe_unused]] auto error_unused = cudaGetLastError(); \ + STANDALONE_WARN("CUDA warning: ", cudaGetErrorString(__err)); \ + } \ + } while (0) + +#endif diff --git a/backends/aoti/slim/c10/macros/Macros.h b/backends/aoti/slim/c10/macros/Macros.h new file mode 100644 index 00000000000..aa8329263fe --- /dev/null +++ b/backends/aoti/slim/c10/macros/Macros.h @@ -0,0 +1,219 @@ +#pragma once + +#include + +// UBSan (Undefined Behavior Sanitizer) macros +#if defined(__clang__) +#define __ubsan_ignore_float_divide_by_zero__ \ + __attribute__((no_sanitize("float-divide-by-zero"))) +#define __ubsan_ignore_undefined__ __attribute__((no_sanitize("undefined"))) +#define __ubsan_ignore_signed_int_overflow__ \ + __attribute__((no_sanitize("signed-integer-overflow"))) +#define __ubsan_ignore_pointer_overflow__ \ + __attribute__((no_sanitize("pointer-overflow"))) +#define __ubsan_ignore_function__ __attribute__((no_sanitize("function"))) +#define __ubsan_ignore_float_cast_overflow__ \ + __attribute__((no_sanitize("float-cast-overflow"))) +#else +#define __ubsan_ignore_float_divide_by_zero__ +#define __ubsan_ignore_undefined__ +#define __ubsan_ignore_signed_int_overflow__ +#define __ubsan_ignore_pointer_overflow__ +#define __ubsan_ignore_function__ +#define __ubsan_ignore_float_cast_overflow__ +#endif + +// STANDALONE_LIKELY/STANDALONE_UNLIKELY +// +// These macros provide parentheses, so you can use these macros as: +// +// if STANDALONE_LIKELY(some_expr) { +// ... 
+// } +// +// NB: static_cast to boolean is mandatory in C++, because __builtin_expect +// takes a long argument, which means you may trigger the wrong conversion +// without it. +// +#if defined(__GNUC__) || defined(__ICL) || defined(__clang__) +#define STANDALONE_LIKELY(expr) (__builtin_expect(static_cast(expr), 1)) +#define STANDALONE_UNLIKELY(expr) (__builtin_expect(static_cast(expr), 0)) +#else +#define STANDALONE_LIKELY(expr) (expr) +#define STANDALONE_UNLIKELY(expr) (expr) +#endif + +// On nvcc, STANDALONE_UNLIKELY thwarts missing return statement analysis. In +// cases where the unlikely expression may be a constant, use this macro to +// ensure return statement analysis keeps working (at the cost of not getting +// the likely/unlikely annotation on nvcc). +// https://github.com/pytorch/pytorch/issues/21418 +// +// Currently, this is only used in the error reporting macros below. If you +// want to use it more generally, move me to Macros.h +// +// TODO: Brian Vaughan observed that we might be able to get this to work on +// nvcc by writing some sort of C++ overload that distinguishes constexpr inputs +// from non-constexpr. Since there isn't any evidence that losing +// STANDALONE_UNLIKELY in nvcc is causing us perf problems, this is not yet +// implemented, but this might be an interesting piece of C++ code for an +// intrepid bootcamper to write. +#if defined(__CUDACC__) +#define STANDALONE_UNLIKELY_OR_CONST(e) e +#else +#define STANDALONE_UNLIKELY_OR_CONST(e) STANDALONE_UNLIKELY(e) +#endif + +#define STANDALONE_STRINGIZE_IMPL(x) #x +#define STANDALONE_STRINGIZE(x) STANDALONE_STRINGIZE_IMPL(x) + +#define STANDALONE_CONCATENATE_IMPL(s1, s2) s1##s2 +#define STANDALONE_CONCATENATE(s1, s2) STANDALONE_CONCATENATE_IMPL(s1, s2) + +/** + * STANDALONE_ANONYMOUS_VARIABLE(str) introduces a new identifier which starts + * with str and ends with a unique number. + */ +#ifdef __COUNTER__ +#define STANDALONE_UID __COUNTER__ +#define STANDALONE_ANONYMOUS_VARIABLE(str) \ + STANDALONE_CONCATENATE(str, __COUNTER__) +#else +#define STANDALONE_UID __LINE__ +#define STANDALONE_ANONYMOUS_VARIABLE(str) STANDALONE_CONCATENATE(str, __LINE__) +#endif + +// Private helper macro for workaround MSVC misexpansion of nested macro +// invocations involving __VA_ARGS__. See +// https://stackoverflow.com/questions/5134523/msvc-doesnt-expand-va-args-correctly +#define STANDALONE_EXPAND_MSVC_WORKAROUND(x) x + +/// STANDALONE_NOINLINE - Functions whose declaration is annotated with this +/// will not be inlined. +#ifdef __GNUC__ +#define STANDALONE_NOINLINE __attribute__((noinline)) +#elif _MSC_VER +#define STANDALONE_NOINLINE __declspec(noinline) +#else +#define STANDALONE_NOINLINE +#endif + +#if defined(_MSC_VER) +#define STANDALONE_ALWAYS_INLINE __forceinline +#elif __has_attribute(always_inline) || defined(__GNUC__) +#define STANDALONE_ALWAYS_INLINE __attribute__((__always_inline__)) inline +#else +#define STANDALONE_ALWAYS_INLINE inline +#endif + +// Unlike STANDALONE_ALWAYS_INLINE, STANDALONE_ALWAYS_INLINE_ATTRIBUTE can be +// used on a lambda. +#if defined(_MSC_VER) +// MSVC 14.39 is reasonably recent and doesn't like +// [[msvc::forceinline]] on a lambda, so don't try to use it. 
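+// Usage sketch (comment only, not in the original diff): the *_ATTRIBUTE
+// variant is meant to be attached to a lambda's declarator rather than to a
+// free function, e.g. something along the lines of
+//   auto body = [&](int i) STANDALONE_ALWAYS_INLINE_ATTRIBUTE { /* ... */ };
+// On MSVC it deliberately expands to nothing, per the note above.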
+#define STANDALONE_ALWAYS_INLINE_ATTRIBUTE +#elif __has_attribute(always_inline) || defined(__GNUC__) +#define STANDALONE_ALWAYS_INLINE_ATTRIBUTE __attribute__((__always_inline__)) +#else +#define STANDALONE_ALWAYS_INLINE_ATTRIBUTE +#endif + +#if defined(_MSC_VER) +#define STANDALONE_ATTR_VISIBILITY_HIDDEN +#elif defined(__GNUC__) +#define STANDALONE_ATTR_VISIBILITY_HIDDEN \ + __attribute__((__visibility__("hidden"))) +#else +#define STANDALONE_ATTR_VISIBILITY_HIDDEN +#endif + +#define STANDALONE_ERASE \ + STANDALONE_ALWAYS_INLINE STANDALONE_ATTR_VISIBILITY_HIDDEN + +#include + +#ifdef __HIPCC__ +// Unlike CUDA, HIP requires a HIP header to be included for __host__ to work. +// We do this #include here so that STANDALONE_HOST_DEVICE and friends will Just +// Work. See https://github.com/ROCm/hip/issues/441 +#include +#endif + +#if defined(__CUDACC__) || defined(__HIPCC__) +// Designates functions callable from the host (CPU) and the device (GPU) +#define STANDALONE_HOST_DEVICE __host__ __device__ +#define STANDALONE_DEVICE __device__ +#define STANDALONE_HOST __host__ +// constants from +// (https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications) +// The maximum number of threads per multiprocessor is 1024 for Turing +// architecture (7.5), 1536 for Geforce Ampere (8.6)/Jetson Orin (8.7), and +// 2048 for all other architectures. You'll get warnings if you exceed these +// constants. Hence, the following macros adjust the input values from the user +// to resolve potential warnings. +#if __CUDA_ARCH__ == 750 +constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 1024; +#elif __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870 || __CUDA_ARCH__ == 890 +constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 1536; +#else +constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 2048; +#endif +// CUDA_MAX_THREADS_PER_BLOCK is same for all architectures currently +constexpr uint32_t CUDA_MAX_THREADS_PER_BLOCK = 1024; +// CUDA_THREADS_PER_BLOCK_FALLBACK is the "canonical fallback" choice of block +// size. 256 is a good number for this fallback and should give good occupancy +// and versatility across all architectures. +constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256; +// NOTE: if you are thinking of constexpr-ify the inputs to launch bounds, it +// turns out that although __launch_bounds__ can take constexpr, it +// can't take a constexpr that has anything to do with templates. +// Currently we use launch_bounds that depend on template arguments in +// Loops.cuh, Reduce.cuh and LossCTC.cuh. Hence, +// STANDALONE_MAX_THREADS_PER_BLOCK and STANDALONE_MIN_BLOCKS_PER_SM are +// kept as macros. +// Suppose you were planning to write __launch_bounds__(a, b), based on your +// performance tuning on a modern GPU. Instead, you should write +// __launch_bounds__(STANDALONE_MAX_THREADS_PER_BLOCK(a), +// STANDALONE_MIN_BLOCKS_PER_SM(a, b)), which will also properly respect limits +// on old architectures. +#define STANDALONE_MAX_THREADS_PER_BLOCK(val) \ + (((val) <= CUDA_MAX_THREADS_PER_BLOCK) ? (val) \ + : CUDA_THREADS_PER_BLOCK_FALLBACK) +#define STANDALONE_MIN_BLOCKS_PER_SM(threads_per_block, blocks_per_sm) \ + ((((threads_per_block) * (blocks_per_sm) <= CUDA_MAX_THREADS_PER_SM) \ + ? 
(blocks_per_sm) \ + : ((CUDA_MAX_THREADS_PER_SM + (threads_per_block) - 1) / \ + (threads_per_block)))) +// STANDALONE_LAUNCH_BOUNDS is analogous to __launch_bounds__ +#define STANDALONE_LAUNCH_BOUNDS_0 \ + __launch_bounds__( \ + 256, 4) // default launch bounds that should give good occupancy + // and versatility across all architectures. +#define STANDALONE_LAUNCH_BOUNDS_1(max_threads_per_block) \ + __launch_bounds__((STANDALONE_MAX_THREADS_PER_BLOCK((max_threads_per_block)))) +#define STANDALONE_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm) \ + __launch_bounds__( \ + (STANDALONE_MAX_THREADS_PER_BLOCK((max_threads_per_block))), \ + (STANDALONE_MIN_BLOCKS_PER_SM( \ + (max_threads_per_block), (min_blocks_per_sm)))) +#else +#define STANDALONE_HOST_DEVICE +#define STANDALONE_HOST +#define STANDALONE_DEVICE +#endif + +#define _STANDALONE_PRAGMA__(string) _Pragma(#string) +#define _STANDALONE_PRAGMA_(string) _STANDALONE_PRAGMA__(string) + +#ifdef __clang__ +#define STANDALONE_CLANG_DIAGNOSTIC_PUSH() _Pragma("clang diagnostic push") +#define STANDALONE_CLANG_DIAGNOSTIC_POP() _Pragma("clang diagnostic pop") +#define STANDALONE_CLANG_DIAGNOSTIC_IGNORE(flag) \ + _STANDALONE_PRAGMA_(clang diagnostic ignored flag) +#define STANDALONE_CLANG_HAS_WARNING(flag) __has_warning(flag) +#else +#define STANDALONE_CLANG_DIAGNOSTIC_PUSH() +#define STANDALONE_CLANG_DIAGNOSTIC_POP() +#define STANDALONE_CLANG_DIAGNOSTIC_IGNORE(flag) +#define STANDALONE_CLANG_HAS_WARNING(flag) 0 +#endif diff --git a/backends/aoti/slim/c10/targets.bzl b/backends/aoti/slim/c10/targets.bzl new file mode 100644 index 00000000000..2bef9f5cf96 --- /dev/null +++ b/backends/aoti/slim/c10/targets.bzl @@ -0,0 +1,31 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + """Define c10 library targets for SlimTensor. + + These are portable c10 utilities adapted from torchnative/standalone. + """ + + # c10 utility headers (ArrayRef, Half, BFloat16, complex, etc.) + # Excludes CUDA-specific headers which require CUDA SDK + runtime.cxx_library( + name = "c10", + exported_headers = glob( + ["**/*.h"], + exclude = ["cuda/**/*.h"], + ), + visibility = ["@EXECUTORCH_CLIENTS"], + exported_deps = [], + ) + + # c10 CUDA-specific headers (requires CUDA SDK) + runtime.cxx_library( + name = "c10_cuda", + exported_headers = glob(["cuda/*.h"]), + visibility = ["@EXECUTORCH_CLIENTS"], + exported_preprocessor_flags = ["-DUSE_CUDA"], + exported_deps = [":c10"], + external_deps = [ + ("cuda", None, "cuda-lazy"), + ], + ) diff --git a/backends/aoti/slim/c10/util/Array.h b/backends/aoti/slim/c10/util/Array.h new file mode 100644 index 00000000000..39eabc830d1 --- /dev/null +++ b/backends/aoti/slim/c10/util/Array.h @@ -0,0 +1,18 @@ +#pragma once + +#include +#include + +namespace standalone::c10 { + +// This helper function creates a constexpr std::array +// From a compile time list of values, without requiring you to explicitly +// write out the length. +// +// See also https://stackoverflow.com/a/26351760/23845 +template +inline constexpr auto array_of(T&&... 
t) -> std::array { + return {{std::forward(t)...}}; +} + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/util/ArrayRef.h b/backends/aoti/slim/c10/util/ArrayRef.h new file mode 100644 index 00000000000..4a09f7a9335 --- /dev/null +++ b/backends/aoti/slim/c10/util/ArrayRef.h @@ -0,0 +1,371 @@ +//===--- ArrayRef.h - Array Reference Wrapper -------------------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +// ATen: modified from llvm::ArrayRef. +// removed llvm-specific functionality +// removed some implicit const -> non-const conversions that rely on +// complicated std::enable_if meta-programming +// removed a bunch of slice variants for simplicity... + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace standalone::c10 { +/// ArrayRef - Represent a constant reference to an array (0 or more elements +/// consecutively in memory), i.e. a start pointer and a length. It allows +/// various APIs to take consecutive elements easily and conveniently. +/// +/// This class does not own the underlying data, it is expected to be used in +/// situations where the data resides in some other buffer, whose lifetime +/// extends past that of the ArrayRef. For this reason, it is not in general +/// safe to store an ArrayRef. +/// +/// This is intended to be trivially copyable, so it should be passed by +/// value. +template +class ArrayRef final { + public: + using iterator = const T*; + using const_iterator = const T*; + using size_type = size_t; + using value_type = T; + + using reverse_iterator = std::reverse_iterator; + + private: + /// The start of the array, in an external buffer. + const T* Data; + + /// The number of elements. + size_type Length; + + void debugCheckNullptrInvariant() { + STANDALONE_INTERNAL_ASSERT_DEBUG_ONLY( + Data != nullptr || Length == 0, + "created ArrayRef with nullptr and non-zero length! std::optional " + "relies on this being illegal"); + } + + public: + /// @name Constructors + /// @{ + + /// Construct an empty ArrayRef. + /* implicit */ constexpr ArrayRef() : Data(nullptr), Length(0) {} + + /// Construct an ArrayRef from a single element. + // TODO Make this explicit + constexpr ArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {} + + /// Construct an ArrayRef from a pointer and length. + constexpr ArrayRef(const T* data, size_t length) + : Data(data), Length(length) { + debugCheckNullptrInvariant(); + } + + /// Construct an ArrayRef from a range. + constexpr ArrayRef(const T* begin, const T* end) + : Data(begin), Length(end - begin) { + debugCheckNullptrInvariant(); + } + + template < + typename Container, + typename U = decltype(std::declval().data()), + typename = std::enable_if_t< + (std::is_same_v || std::is_same_v)>> + /* implicit */ ArrayRef(const Container& container) + : Data(container.data()), Length(container.size()) { + debugCheckNullptrInvariant(); + } + + /// Construct an ArrayRef from a std::vector. + // The enable_if stuff here makes sure that this isn't used for + // std::vector, because ArrayRef can't work on a std::vector + // bitfield. 
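+ // Illustrative usage (comment only, not in the original diff):
+ //   std::vector<int64_t> shape = {2, 3, 4};
+ //   IntArrayRef view(shape);   // non-owning view over shape's buffer
+ //   view.size();               // 3
+ //   view[1];                   // 3
+ // `shape` must outlive `view`; ArrayRef never copies or owns the data.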
+ template + /* implicit */ ArrayRef(const std::vector& Vec) + : Data(Vec.data()), Length(Vec.size()) { + static_assert( + !std::is_same_v, + "ArrayRef cannot be constructed from a " + "std::vector bitfield."); + } + + /// Construct an ArrayRef from a std::array + template + /* implicit */ constexpr ArrayRef(const std::array& Arr) + : Data(Arr.data()), Length(N) {} + + /// Construct an ArrayRef from a C array. + template + // NOLINTNEXTLINE(*c-arrays*) + /* implicit */ constexpr ArrayRef(const T (&Arr)[N]) : Data(Arr), Length(N) {} + + /// Construct an ArrayRef from a std::initializer_list. + /* implicit */ constexpr ArrayRef(const std::initializer_list& Vec) + : Data( + std::begin(Vec) == std::end(Vec) ? static_cast(nullptr) + : std::begin(Vec)), + Length(Vec.size()) {} + + /// @} + /// @name Simple Operations + /// @{ + + constexpr iterator begin() const { + return Data; + } + constexpr iterator end() const { + return Data + Length; + } + + // These are actually the same as iterator, since ArrayRef only + // gives you const iterators. + constexpr const_iterator cbegin() const { + return Data; + } + constexpr const_iterator cend() const { + return Data + Length; + } + + constexpr reverse_iterator rbegin() const { + return reverse_iterator(end()); + } + constexpr reverse_iterator rend() const { + return reverse_iterator(begin()); + } + + /// Check if all elements in the array satisfy the given expression + constexpr bool allMatch(const std::function& pred) const { + return std::all_of(cbegin(), cend(), pred); + } + + /// empty - Check if the array is empty. + constexpr bool empty() const { + return Length == 0; + } + + constexpr const T* data() const { + return Data; + } + + /// size - Get the array size. + constexpr size_t size() const { + return Length; + } + + /// front - Get the first element. + constexpr const T& front() const { + STANDALONE_CHECK( + !empty(), "ArrayRef: attempted to access front() of empty list"); + return Data[0]; + } + + /// back - Get the last element. + constexpr const T& back() const { + STANDALONE_CHECK( + !empty(), "ArrayRef: attempted to access back() of empty list"); + return Data[Length - 1]; + } + + /// equals - Check for element-wise equality. + constexpr bool equals(ArrayRef RHS) const { + return Length == RHS.Length && std::equal(begin(), end(), RHS.begin()); + } + + /// slice(n, m) - Take M elements of the array starting at element N + constexpr ArrayRef slice(size_t N, size_t M) const { + STANDALONE_CHECK( + N + M <= size(), + "ArrayRef: invalid slice, N = ", + N, + "; M = ", + M, + "; size = ", + size()); + return ArrayRef(data() + N, M); + } + + /// slice(n) - Chop off the first N elements of the array. + constexpr ArrayRef slice(size_t N) const { + STANDALONE_CHECK( + N <= size(), "ArrayRef: invalid slice, N = ", N, "; size = ", size()); + return slice(N, size() - N); + } + + /// @} + /// @name Operator Overloads + /// @{ + constexpr const T& operator[](size_t Index) const { + return Data[Index]; + } + + /// Vector compatibility + constexpr const T& at(size_t Index) const { + STANDALONE_CHECK( + Index < Length, + "ArrayRef: invalid index Index = ", + Index, + "; Length = ", + Length); + return Data[Index]; + } + + /// Disallow accidental assignment from a temporary. + /// + /// The declaration here is extra complicated so that "arrayRef = {}" + /// continues to select the move assignment operator. 
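+ /// For example (illustrative, not in the original diff):
+ ///   ArrayRef<int> a;
+ ///   a = 7;         // ill-formed: `a` would point at a dead temporary
+ ///   a = {1, 2, 3}; // ill-formed too (initializer_list overload below)
+ ///   a = {};        // still fine: selects the move assignment operator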
+ template + std::enable_if_t, ArrayRef>& operator=( + // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) + U&& Temporary) = delete; + + /// Disallow accidental assignment from a temporary. + /// + /// The declaration here is extra complicated so that "arrayRef = {}" + /// continues to select the move assignment operator. + template + std::enable_if_t, ArrayRef>& operator=( + std::initializer_list) = delete; + + /// @} + /// @name Expensive Operations + /// @{ + std::vector vec() const { + return std::vector(Data, Data + Length); + } + + /// @} +}; + +template +std::ostream& operator<<(std::ostream& out, ArrayRef list) { + int i = 0; + out << "["; + for (const auto& e : list) { + if (i++ > 0) + out << ", "; + out << e; + } + out << "]"; + return out; +} + +/// @name ArrayRef Convenience constructors +/// @{ + +/// Construct an ArrayRef from a single element. +template +ArrayRef makeArrayRef(const T& OneElt) { + return OneElt; +} + +/// Construct an ArrayRef from a pointer and length. +template +ArrayRef makeArrayRef(const T* data, size_t length) { + return ArrayRef(data, length); +} + +/// Construct an ArrayRef from a range. +template +ArrayRef makeArrayRef(const T* begin, const T* end) { + return ArrayRef(begin, end); +} + +/// Construct an ArrayRef from a std::vector. +template +ArrayRef makeArrayRef(const std::vector& Vec) { + return Vec; +} + +/// Construct an ArrayRef from a std::array. +template +ArrayRef makeArrayRef(const std::array& Arr) { + return Arr; +} + +/// Construct an ArrayRef from an ArrayRef (no-op) (const) +template +ArrayRef makeArrayRef(const ArrayRef& Vec) { + return Vec; +} + +/// Construct an ArrayRef from an ArrayRef (no-op) +template +ArrayRef& makeArrayRef(ArrayRef& Vec) { + return Vec; +} + +/// Construct an ArrayRef from a C array. +template +// NOLINTNEXTLINE(*c-arrays*) +ArrayRef makeArrayRef(const T (&Arr)[N]) { + return ArrayRef(Arr); +} + +// WARNING: Template instantiation will NOT be willing to do an implicit +// conversions to get you to an standalone::c10::ArrayRef, which is why we +// need so many overloads. + +template +bool operator==( + standalone::c10::ArrayRef a1, + standalone::c10::ArrayRef a2) { + return a1.equals(a2); +} + +template +bool operator!=( + standalone::c10::ArrayRef a1, + standalone::c10::ArrayRef a2) { + return !a1.equals(a2); +} + +template +bool operator==(const std::vector& a1, standalone::c10::ArrayRef a2) { + return standalone::c10::ArrayRef(a1).equals(a2); +} + +template +bool operator!=(const std::vector& a1, standalone::c10::ArrayRef a2) { + return !standalone::c10::ArrayRef(a1).equals(a2); +} + +template +bool operator==(standalone::c10::ArrayRef a1, const std::vector& a2) { + return a1.equals(standalone::c10::ArrayRef(a2)); +} + +template +bool operator!=(standalone::c10::ArrayRef a1, const std::vector& a2) { + return !a1.equals(standalone::c10::ArrayRef(a2)); +} + +using IntArrayRef = ArrayRef; + +using IntList + [[deprecated("This alias is deprecated because it doesn't make ownership " + "semantics obvious. 
Use IntArrayRef instead!")]] = + ArrayRef; + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/util/BFloat16-inl.h b/backends/aoti/slim/c10/util/BFloat16-inl.h new file mode 100644 index 00000000000..4608d9a6c54 --- /dev/null +++ b/backends/aoti/slim/c10/util/BFloat16-inl.h @@ -0,0 +1,365 @@ +#pragma once + +#include +#include + +#include + +STANDALONE_CLANG_DIAGNOSTIC_PUSH() +#if STANDALONE_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +STANDALONE_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +#if defined(CL_SYCL_LANGUAGE_VERSION) +#include // for SYCL 1.2.1 +#elif defined(SYCL_LANGUAGE_VERSION) +#include // for SYCL 2020 +#endif + +namespace standalone::c10 { + +/// Constructors +inline STANDALONE_HOST_DEVICE BFloat16::BFloat16(float value) + : +#if defined(__CUDACC__) && !defined(USE_ROCM) && defined(__CUDA_ARCH__) && \ + __CUDA_ARCH__ >= 800 + x(__bfloat16_as_ushort(__float2bfloat16(value))) +#elif defined(__SYCL_DEVICE_ONLY__) && \ + defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) + x(standalone::c10::bit_cast(sycl::ext::oneapi::bfloat16(value))) +#else + // RNE by default + x(detail::round_to_nearest_even(value)) +#endif +{ +} + +/// Implicit conversions +inline STANDALONE_HOST_DEVICE BFloat16::operator float() const { +#if defined(__CUDACC__) && !defined(USE_ROCM) + return __bfloat162float(*reinterpret_cast(&x)); +#elif defined(__SYCL_DEVICE_ONLY__) && \ + defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) + return float(*reinterpret_cast(&x)); +#else + return detail::f32_from_bits(x); +#endif +} + +#if defined(__CUDACC__) && !defined(USE_ROCM) +inline STANDALONE_HOST_DEVICE BFloat16::BFloat16(const __nv_bfloat16& value) { + x = *reinterpret_cast(&value); +} +inline STANDALONE_HOST_DEVICE BFloat16::operator __nv_bfloat16() const { + return *reinterpret_cast(&x); +} +#endif + +#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) +inline STANDALONE_HOST_DEVICE BFloat16::BFloat16( + const sycl::ext::oneapi::bfloat16& value) { + x = *reinterpret_cast(&value); +} +inline STANDALONE_HOST_DEVICE BFloat16::operator sycl::ext::oneapi::bfloat16() + const { + return *reinterpret_cast(&x); +} +#endif + +// CUDA intrinsics + +#if defined(__CUDACC__) || defined(__HIPCC__) +inline STANDALONE_DEVICE BFloat16 __ldg(const BFloat16* ptr) { +#if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + return __ldg(reinterpret_cast(ptr)); +#else + return *ptr; +#endif +} +#endif + +/// Arithmetic + +inline STANDALONE_HOST_DEVICE BFloat16 +operator+(const BFloat16& a, const BFloat16& b) { + return static_cast(a) + static_cast(b); +} + +inline STANDALONE_HOST_DEVICE BFloat16 +operator-(const BFloat16& a, const BFloat16& b) { + return static_cast(a) - static_cast(b); +} + +inline STANDALONE_HOST_DEVICE BFloat16 +operator*(const BFloat16& a, const BFloat16& b) { + return static_cast(a) * static_cast(b); +} + +inline STANDALONE_HOST_DEVICE BFloat16 operator/( + const BFloat16& a, + const BFloat16& b) __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE BFloat16 operator-(const BFloat16& a) { + return -static_cast(a); +} + +inline STANDALONE_HOST_DEVICE BFloat16& operator+=( + BFloat16& a, + const BFloat16& b) { + a = a + b; + return a; +} + +inline STANDALONE_HOST_DEVICE BFloat16& operator-=( + BFloat16& a, + const BFloat16& b) { + a = a - b; + return a; +} + +inline STANDALONE_HOST_DEVICE BFloat16& operator*=( + BFloat16& a, + const BFloat16& b) { + a = a * b; + return a; +} + 
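+// Note (illustrative, not in the original diff): the compound operators here
+// all round-trip through float. E.g. for BFloat16 a(1.0f), b(3.0f), `a /= b`
+// below divides in float and converts the result back to bfloat16 with
+// round-to-nearest-even, giving 0.333984375 (the closest bfloat16 to 1/3).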
+inline STANDALONE_HOST_DEVICE BFloat16& operator/=( + BFloat16& a, + const BFloat16& b) { + a = a / b; + return a; +} + +inline STANDALONE_HOST_DEVICE BFloat16& operator|( + BFloat16& a, + const BFloat16& b) { + a.x = a.x | b.x; + return a; +} + +inline STANDALONE_HOST_DEVICE BFloat16& operator^( + BFloat16& a, + const BFloat16& b) { + a.x = a.x ^ b.x; + return a; +} + +inline STANDALONE_HOST_DEVICE BFloat16& operator&( + BFloat16& a, + const BFloat16& b) { + a.x = a.x & b.x; + return a; +} + +/// Arithmetic with floats + +inline STANDALONE_HOST_DEVICE float operator+(BFloat16 a, float b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE float operator-(BFloat16 a, float b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE float operator*(BFloat16 a, float b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE float operator/(BFloat16 a, float b) { + return static_cast(a) / b; +} + +inline STANDALONE_HOST_DEVICE float operator+(float a, BFloat16 b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator-(float a, BFloat16 b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator*(float a, BFloat16 b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator/(float a, BFloat16 b) { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE float& operator+=(float& a, const BFloat16& b) { + return a += static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator-=(float& a, const BFloat16& b) { + return a -= static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator*=(float& a, const BFloat16& b) { + return a *= static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator/=(float& a, const BFloat16& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline STANDALONE_HOST_DEVICE double operator+(BFloat16 a, double b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE double operator-(BFloat16 a, double b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE double operator*(BFloat16 a, double b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE double operator/(BFloat16 a, double b) { + return static_cast(a) / b; +} + +inline STANDALONE_HOST_DEVICE double operator+(double a, BFloat16 b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator-(double a, BFloat16 b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator*(double a, BFloat16 b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator/(double a, BFloat16 b) { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline STANDALONE_HOST_DEVICE BFloat16 operator+(BFloat16 a, int b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE BFloat16 operator-(BFloat16 a, int b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE BFloat16 operator*(BFloat16 a, int b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE BFloat16 operator/(BFloat16 a, int b) { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE BFloat16 operator+(int a, BFloat16 b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE BFloat16 operator-(int a, BFloat16 b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE BFloat16 operator*(int a, BFloat16 b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE BFloat16 operator/(int a, BFloat16 b) { + return static_cast(a) / b; +} + +//// Arithmetic with 
int64_t + +inline STANDALONE_HOST_DEVICE BFloat16 operator+(BFloat16 a, int64_t b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE BFloat16 operator-(BFloat16 a, int64_t b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE BFloat16 operator*(BFloat16 a, int64_t b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE BFloat16 operator/(BFloat16 a, int64_t b) { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE BFloat16 operator+(int64_t a, BFloat16 b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE BFloat16 operator-(int64_t a, BFloat16 b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE BFloat16 operator*(int64_t a, BFloat16 b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE BFloat16 operator/(int64_t a, BFloat16 b) { + return static_cast(a) / b; +} + +// Overloading < and > operators, because std::max and std::min use them. + +inline STANDALONE_HOST_DEVICE bool operator>(BFloat16& lhs, BFloat16& rhs) { + return float(lhs) > float(rhs); +} + +inline STANDALONE_HOST_DEVICE bool operator<(BFloat16& lhs, BFloat16& rhs) { + return float(lhs) < float(rhs); +} + +} // namespace standalone::c10 + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_signed = true; + static constexpr bool is_specialized = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = true; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = true; + static constexpr auto has_denorm = numeric_limits::has_denorm; + static constexpr auto has_denorm_loss = + numeric_limits::has_denorm_loss; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 8; + static constexpr int digits10 = 2; + static constexpr int max_digits10 = 4; + static constexpr int radix = 2; + static constexpr int min_exponent = -125; + static constexpr int min_exponent10 = -37; + static constexpr int max_exponent = 128; + static constexpr int max_exponent10 = 38; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = + numeric_limits::tinyness_before; + + static constexpr standalone::c10::BFloat16 min() { + return standalone::c10::BFloat16( + 0x0080, standalone::c10::BFloat16::from_bits()); + } + static constexpr standalone::c10::BFloat16 lowest() { + return standalone::c10::BFloat16( + 0xFF7F, standalone::c10::BFloat16::from_bits()); + } + static constexpr standalone::c10::BFloat16 max() { + return standalone::c10::BFloat16( + 0x7F7F, standalone::c10::BFloat16::from_bits()); + } + static constexpr standalone::c10::BFloat16 epsilon() { + return standalone::c10::BFloat16( + 0x3C00, standalone::c10::BFloat16::from_bits()); + } + static constexpr standalone::c10::BFloat16 round_error() { + return standalone::c10::BFloat16( + 0x3F00, standalone::c10::BFloat16::from_bits()); + } + static constexpr standalone::c10::BFloat16 infinity() { + return standalone::c10::BFloat16( + 0x7F80, standalone::c10::BFloat16::from_bits()); + } + static constexpr standalone::c10::BFloat16 quiet_NaN() { + return standalone::c10::BFloat16( + 0x7FC0, standalone::c10::BFloat16::from_bits()); + } + static constexpr standalone::c10::BFloat16 signaling_NaN() { + return standalone::c10::BFloat16( + 0x7F80, 
standalone::c10::BFloat16::from_bits()); + } + static constexpr standalone::c10::BFloat16 denorm_min() { + return standalone::c10::BFloat16( + 0x0001, standalone::c10::BFloat16::from_bits()); + } +}; + +} // namespace std + +STANDALONE_CLANG_DIAGNOSTIC_POP() diff --git a/backends/aoti/slim/c10/util/BFloat16-math.h b/backends/aoti/slim/c10/util/BFloat16-math.h new file mode 100644 index 00000000000..f036f309e26 --- /dev/null +++ b/backends/aoti/slim/c10/util/BFloat16-math.h @@ -0,0 +1,332 @@ +#pragma once + +#include +#include + +STANDALONE_CLANG_DIAGNOSTIC_PUSH() +#if STANDALONE_CLANG_HAS_WARNING("-Wimplicit-float-conversion") +STANDALONE_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") +#endif + +namespace standalone::c10 { +template +struct is_reduced_floating_point + : std::integral_constant< + bool, + std::is_same_v || + std::is_same_v> {}; + +template +constexpr bool is_reduced_floating_point_v = + is_reduced_floating_point::value; +} // namespace standalone::c10 + +namespace std { + +#if !defined(FBCODE_CAFFE2) && !defined(STANDALONE_NODEPRECATED) +using standalone::c10::is_reduced_floating_point; +using standalone::c10::is_reduced_floating_point_v; +#endif // !defined(FBCODE_CAFFE2) && !defined(STANDALONE_NODEPRECATED) + +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T acos(T a) { + return std::acos(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T asin(T a) { + return std::asin(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T atan(T a) { + return std::atan(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T atanh(T a) { + return std::atanh(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T erf(T a) { + return std::erf(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T erfc(T a) { + return std::erfc(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T exp(T a) { + return std::exp(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T expm1(T a) { + return std::expm1(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline bool isfinite(T a) { + return std::isfinite(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T log(T a) { + return std::log(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T log10(T a) { + return std::log10(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T log1p(T a) { + return std::log1p(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T log2(T a) { + return std::log2(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T ceil(T a) { + return std::ceil(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T cos(T a) { + return std::cos(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T floor(T a) { + return std::floor(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T nearbyint(T a) { + return std::nearbyint(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T sin(T a) { + return std::sin(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 
0> +inline T tan(T a) { + return std::tan(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T sinh(T a) { + return std::sinh(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T cosh(T a) { + return std::cosh(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T tanh(T a) { + return std::tanh(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T trunc(T a) { + return std::trunc(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T lgamma(T a) { + return std::lgamma(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T sqrt(T a) { + return std::sqrt(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T rsqrt(T a) { + return 1.0 / std::sqrt(float(a)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T abs(T a) { + return std::abs(float(a)); +} +#if defined(_MSC_VER) && defined(__CUDACC__) +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T pow(T a, double b) { + return std::pow(float(a), float(b)); +} +#else +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T pow(T a, double b) { + return std::pow(float(a), b); +} +#endif +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T pow(T a, T b) { + return std::pow(float(a), float(b)); +} +template < + typename T, + typename std:: + enable_if_t, int> = 0> +inline T fmod(T a, T b) { + return std::fmod(float(a), float(b)); +} + +/* + The following function is inspired from the implementation in `musl` + Link to License: https://git.musl-libc.org/cgit/musl/tree/COPYRIGHT + ---------------------------------------------------------------------- + Copyright © 2005-2020 Rich Felker, et al. + + Permission is hereby granted, free of charge, to any person obtaining + a copy of this software and associated documentation files (the + "Software"), to deal in the Software without restriction, including + without limitation the rights to use, copy, modify, merge, publish, + distribute, sublicense, and/or sell copies of the Software, and to + permit persons to whom the Software is furnished to do so, subject to + the following conditions: + + The above copyright notice and this permission notice shall be + included in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + ---------------------------------------------------------------------- + */ +template < + typename T, + typename std:: + enable_if_t, int> = 0> +STANDALONE_HOST_DEVICE inline T nextafter(T from, T to) { + // Reference: + // https://git.musl-libc.org/cgit/musl/tree/src/math/nextafter.c + using int_repr_t = uint16_t; + constexpr uint8_t bits = 16; + union { + T f; + int_repr_t i; + } ufrom = {from}, uto = {to}; + + // get a mask to get the sign bit i.e. 
MSB + int_repr_t sign_mask = int_repr_t{1} << (bits - 1); + + // short-circuit: if either is NaN, return NaN + if (from != from || to != to) { + return from + to; + } + + // short-circuit: if they are exactly the same. + if (ufrom.i == uto.i) { + return from; + } + + // mask the sign-bit to zero i.e. positive + // equivalent to abs(x) + int_repr_t abs_from = ufrom.i & ~sign_mask; + int_repr_t abs_to = uto.i & ~sign_mask; + if (abs_from == 0) { + // if both are zero but with different sign, + // preserve the sign of `to`. + if (abs_to == 0) { + return to; + } + // smallest subnormal with sign of `to`. + ufrom.i = (uto.i & sign_mask) | int_repr_t{1}; + return ufrom.f; + } + + // if abs(from) > abs(to) or sign(from) != sign(to) + if (abs_from > abs_to || ((ufrom.i ^ uto.i) & sign_mask)) { + ufrom.i--; + } else { + ufrom.i++; + } + + return ufrom.f; +} + +} // namespace std + +STANDALONE_CLANG_DIAGNOSTIC_POP() diff --git a/backends/aoti/slim/c10/util/BFloat16.h b/backends/aoti/slim/c10/util/BFloat16.h new file mode 100644 index 00000000000..ed6d07f53d0 --- /dev/null +++ b/backends/aoti/slim/c10/util/BFloat16.h @@ -0,0 +1,123 @@ +#pragma once + +// Defines the bloat16 type (brain floating-point). This representation uses +// 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa. + +#include +#include +#include +#include +#include +#include + +#if defined(__CUDACC__) && !defined(USE_ROCM) +#include +#endif + +#if defined(CL_SYCL_LANGUAGE_VERSION) +#include // for SYCL 1.2.1 +#elif defined(SYCL_LANGUAGE_VERSION) +#include // for SYCL 2020 +#endif + +namespace standalone::c10 { + +namespace detail { +inline STANDALONE_HOST_DEVICE float f32_from_bits(uint16_t src) { + float res = 0; + uint32_t tmp = src; + tmp <<= 16; + +#if defined(USE_ROCM) && defined(__HIPCC__) + float* tempRes; + + // We should be using memcpy in order to respect the strict aliasing rule + // but it fails in the HIP environment. + tempRes = reinterpret_cast(&tmp); + res = *tempRes; +#else + std::memcpy(&res, &tmp, sizeof(tmp)); +#endif + + return res; +} + +inline STANDALONE_HOST_DEVICE uint16_t bits_from_f32(float src) { + uint32_t res = 0; + +#if defined(USE_ROCM) && defined(__HIPCC__) + // We should be using memcpy in order to respect the strict aliasing rule + // but it fails in the HIP environment. 
+ uint32_t* tempRes = reinterpret_cast(&src); + res = *tempRes; +#else + std::memcpy(&res, &src, sizeof(res)); +#endif + + return res >> 16; +} + +inline STANDALONE_HOST_DEVICE uint16_t round_to_nearest_even(float src) { +#if defined(USE_ROCM) && defined(__HIPCC__) + if (src != src) { +#elif defined(_MSC_VER) + if (isnan(src)) { +#else + if (std::isnan(src)) { +#endif + return UINT16_C(0x7FC0); + } else { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) + union { + uint32_t U32; // NOLINT(facebook-hte-BadMemberName) + float F32; // NOLINT(facebook-hte-BadMemberName) + }; + + F32 = src; + uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); + return static_cast((U32 + rounding_bias) >> 16); + } +} +} // namespace detail + +struct alignas(2) BFloat16 { + uint16_t x; + + // HIP wants __host__ __device__ tag, CUDA does not +#if defined(USE_ROCM) && defined(__HIPCC__) + STANDALONE_HOST_DEVICE BFloat16() = default; +#else + BFloat16() = default; +#endif + + struct from_bits_t {}; + static constexpr STANDALONE_HOST_DEVICE from_bits_t from_bits() { + return from_bits_t(); + } + + constexpr STANDALONE_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t) + : x(bits) {} + /* implicit */ inline STANDALONE_HOST_DEVICE BFloat16(float value); + inline STANDALONE_HOST_DEVICE operator float() const; + +#if defined(__CUDACC__) && !defined(USE_ROCM) + inline STANDALONE_HOST_DEVICE BFloat16(const __nv_bfloat16& value); + explicit inline STANDALONE_HOST_DEVICE operator __nv_bfloat16() const; +#endif + +#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) + inline STANDALONE_HOST_DEVICE BFloat16( + const sycl::ext::oneapi::bfloat16& value); + explicit inline STANDALONE_HOST_DEVICE operator sycl::ext::oneapi::bfloat16() + const; +#endif +}; + +inline std::ostream& operator<<(std::ostream& out, const BFloat16& value) { + out << (float)value; + return out; +} + +} // namespace standalone::c10 + +#include // IWYU pragma: keep diff --git a/backends/aoti/slim/c10/util/Exception.h b/backends/aoti/slim/c10/util/Exception.h new file mode 100644 index 00000000000..6ab2bd8aae6 --- /dev/null +++ b/backends/aoti/slim/c10/util/Exception.h @@ -0,0 +1,87 @@ +#pragma once + +#include + +#include +#include + +// In the standalone version, STANDALONE_CHECK throws std::runtime_error +// instead of standalone::c10::Error. +namespace standalone::c10::detail { +template +std::string torchCheckMsgImpl(const char* /*msg*/, const Args&... args) { + // This is similar to the one in c10/util/Exception.h, but does + // not depend on the more complex c10::str() function. + // ostringstream may support less data types than c10::str(), + // but should be sufficient in the standalone world. + std::ostringstream oss; + ((oss << args), ...); + return oss.str(); +} +inline const char* torchCheckMsgImpl(const char* msg) { + return msg; +} +// If there is just 1 user-provided C-string argument, use it. +inline const char* torchCheckMsgImpl(const char* /*msg*/, const char* args) { + return args; +} +} // namespace standalone::c10::detail + +#define STANDALONE_CHECK_MSG(cond, type, ...) \ + (::standalone::c10::detail::torchCheckMsgImpl( \ + "Expected " #cond \ + " to be true, but got false. " \ + "(Could this error message be improved? If so, " \ + "please report an enhancement request to PyTorch.)", \ + ##__VA_ARGS__)) +#define STANDALONE_CHECK(cond, ...) 
\ + if (STANDALONE_UNLIKELY_OR_CONST(!(cond))) { \ + throw std::runtime_error(STANDALONE_CHECK_MSG( \ + cond, \ + "", \ + __func__, \ + ", ", \ + __FILE__, \ + ":", \ + __LINE__, \ + ", ", \ + ##__VA_ARGS__)); \ + } +#define STANDALONE_INTERNAL_ASSERT(cond, ...) \ + if (STANDALONE_UNLIKELY_OR_CONST(!(cond))) { \ + throw std::runtime_error(STANDALONE_CHECK_MSG( \ + cond, \ + "", \ + __func__, \ + ", ", \ + __FILE__, \ + ":", \ + __LINE__, \ + ", ", \ + #cond, \ + " INTERNAL ASSERT FAILED: ", \ + ##__VA_ARGS__)); \ + } + +#define WARNING_MESSAGE_STRING(...) \ + ::standalone::c10::detail::torchCheckMsgImpl(__VA_ARGS__) + +#ifdef DISABLE_WARN +#define _STANDALONE_WARN_WITH(...) ((void)0); +#else +#define _STANDALONE_WARN_WITH(...) \ + std::cerr << __func__ << ", " << __FILE__ << ":" << __LINE__ << ", " \ + << WARNING_MESSAGE_STRING(__VA_ARGS__) << std::endl; +#endif + +#define STANDALONE_WARN(...) _STANDALONE_WARN_WITH(__VA_ARGS__); + +#ifdef NDEBUG +// Optimized version - generates no code. +#define STANDALONE_INTERNAL_ASSERT_DEBUG_ONLY(...) \ + while (false) \ + STANDALONE_EXPAND_MSVC_WORKAROUND(STANDALONE_INTERNAL_ASSERT(__VA_ARGS__)) +#else +#define STANDALONE_INTERNAL_ASSERT_DEBUG_ONLY(...) \ + STANDALONE_EXPAND_MSVC_WORKAROUND(STANDALONE_INTERNAL_ASSERT(__VA_ARGS__)) +#endif diff --git a/backends/aoti/slim/c10/util/Float4_e2m1fn_x2.h b/backends/aoti/slim/c10/util/Float4_e2m1fn_x2.h new file mode 100644 index 00000000000..600e281b583 --- /dev/null +++ b/backends/aoti/slim/c10/util/Float4_e2m1fn_x2.h @@ -0,0 +1,28 @@ +#pragma once +#include + +#include + +/// Defines the Float4_e2m1fn_x2 type (4-bit floating-point, two elements packed +/// into one byte). This is the FP4 dtype from the OCP MX format spec +/// (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf, +/// Section 5.3.3) +/// +/// Given two high precision values val0 and val1, here is the +/// binary configuration of their packed representation, from MSB to LSB: +/// +/// original value | val1 : val0 +/// ======================================== +/// bit index (MSB==7, LSB==0) | 7654 : 3210 +/// sign/exponent/mantissa | seem : seem +/// + +namespace standalone::c10 { + +struct alignas(1) Float4_e2m1fn_x2 { + uint8_t val_; + Float4_e2m1fn_x2() = default; + STANDALONE_HOST_DEVICE explicit Float4_e2m1fn_x2(uint8_t val) : val_(val) {} +}; + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/util/Float8_e4m3fn-inl.h b/backends/aoti/slim/c10/util/Float8_e4m3fn-inl.h new file mode 100644 index 00000000000..cc31b82e699 --- /dev/null +++ b/backends/aoti/slim/c10/util/Float8_e4m3fn-inl.h @@ -0,0 +1,297 @@ +#pragma once + +#include +#include +#include + +STANDALONE_CLANG_DIAGNOSTIC_PUSH() +#if STANDALONE_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +STANDALONE_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +namespace standalone::c10 { + +/// Constructors + +inline STANDALONE_HOST_DEVICE Float8_e4m3fn::Float8_e4m3fn(float value) + : x(detail::fp8e4m3fn_from_fp32_value(value)) {} + +/// Implicit conversions + +inline STANDALONE_HOST_DEVICE Float8_e4m3fn::operator float() const { + return detail::fp8e4m3fn_to_fp32_value(x); +} + +/// Special values helper + +inline STANDALONE_HOST_DEVICE bool Float8_e4m3fn::isnan() const { + return (x & 0b01111111) == 0b01111111; +} + +/// Arithmetic + +inline STANDALONE_HOST_DEVICE Float8_e4m3fn +operator+(const Float8_e4m3fn& a, const Float8_e4m3fn& b) { + return static_cast(a) + static_cast(b); +} + +inline 
STANDALONE_HOST_DEVICE Float8_e4m3fn +operator-(const Float8_e4m3fn& a, const Float8_e4m3fn& b) { + return static_cast(a) - static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fn +operator*(const Float8_e4m3fn& a, const Float8_e4m3fn& b) { + return static_cast(a) * static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fn operator/( + const Float8_e4m3fn& a, + const Float8_e4m3fn& b) __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fn operator-(const Float8_e4m3fn& a) { + return -static_cast(a); +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fn& operator+=( + Float8_e4m3fn& a, + const Float8_e4m3fn& b) { + a = a + b; + return a; +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fn& operator-=( + Float8_e4m3fn& a, + const Float8_e4m3fn& b) { + a = a - b; + return a; +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fn& operator*=( + Float8_e4m3fn& a, + const Float8_e4m3fn& b) { + a = a * b; + return a; +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fn& operator/=( + Float8_e4m3fn& a, + const Float8_e4m3fn& b) { + a = a / b; + return a; +} + +/// Arithmetic with floats + +inline STANDALONE_HOST_DEVICE float operator+(Float8_e4m3fn a, float b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE float operator-(Float8_e4m3fn a, float b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE float operator*(Float8_e4m3fn a, float b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE float operator/(Float8_e4m3fn a, float b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline STANDALONE_HOST_DEVICE float operator+(float a, Float8_e4m3fn b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator-(float a, Float8_e4m3fn b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator*(float a, Float8_e4m3fn b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator/(float a, Float8_e4m3fn b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE float& operator+=( + float& a, + const Float8_e4m3fn& b) { + return a += static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator-=( + float& a, + const Float8_e4m3fn& b) { + return a -= static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator*=( + float& a, + const Float8_e4m3fn& b) { + return a *= static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator/=( + float& a, + const Float8_e4m3fn& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline STANDALONE_HOST_DEVICE double operator+(Float8_e4m3fn a, double b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE double operator-(Float8_e4m3fn a, double b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE double operator*(Float8_e4m3fn a, double b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE double operator/(Float8_e4m3fn a, double b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline STANDALONE_HOST_DEVICE double operator+(double a, Float8_e4m3fn b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator-(double a, Float8_e4m3fn b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator*(double a, Float8_e4m3fn b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator/(double a, Float8_e4m3fn b) + 
    __ubsan_ignore_float_divide_by_zero__ {
+  return a / static_cast<double>(b);
+}
+
+/// Arithmetic with ints
+
+inline STANDALONE_HOST_DEVICE Float8_e4m3fn operator+(Float8_e4m3fn a, int b) {
+  return a + static_cast<Float8_e4m3fn>(b);
+}
+inline STANDALONE_HOST_DEVICE Float8_e4m3fn operator-(Float8_e4m3fn a, int b) {
+  return a - static_cast<Float8_e4m3fn>(b);
+}
+inline STANDALONE_HOST_DEVICE Float8_e4m3fn operator*(Float8_e4m3fn a, int b) {
+  return a * static_cast<Float8_e4m3fn>(b);
+}
+inline STANDALONE_HOST_DEVICE Float8_e4m3fn operator/(Float8_e4m3fn a, int b) {
+  return a / static_cast<Float8_e4m3fn>(b);
+}
+
+inline STANDALONE_HOST_DEVICE Float8_e4m3fn operator+(int a, Float8_e4m3fn b) {
+  return static_cast<Float8_e4m3fn>(a) + b;
+}
+inline STANDALONE_HOST_DEVICE Float8_e4m3fn operator-(int a, Float8_e4m3fn b) {
+  return static_cast<Float8_e4m3fn>(a) - b;
+}
+inline STANDALONE_HOST_DEVICE Float8_e4m3fn operator*(int a, Float8_e4m3fn b) {
+  return static_cast<Float8_e4m3fn>(a) * b;
+}
+inline STANDALONE_HOST_DEVICE Float8_e4m3fn operator/(int a, Float8_e4m3fn b) {
+  return static_cast<Float8_e4m3fn>(a) / b;
+}
+
+//// Arithmetic with int64_t
+
+inline STANDALONE_HOST_DEVICE Float8_e4m3fn
+operator+(Float8_e4m3fn a, int64_t b) {
+  return a + static_cast<Float8_e4m3fn>(b);
+}
+inline STANDALONE_HOST_DEVICE Float8_e4m3fn
+operator-(Float8_e4m3fn a, int64_t b) {
+  return a - static_cast<Float8_e4m3fn>(b);
+}
+inline STANDALONE_HOST_DEVICE Float8_e4m3fn
+operator*(Float8_e4m3fn a, int64_t b) {
+  return a * static_cast<Float8_e4m3fn>(b);
+}
+inline STANDALONE_HOST_DEVICE Float8_e4m3fn
+operator/(Float8_e4m3fn a, int64_t b) {
+  return a / static_cast<Float8_e4m3fn>(b);
+}
+
+inline STANDALONE_HOST_DEVICE Float8_e4m3fn
+operator+(int64_t a, Float8_e4m3fn b) {
+  return static_cast<Float8_e4m3fn>(a) + b;
+}
+inline STANDALONE_HOST_DEVICE Float8_e4m3fn
+operator-(int64_t a, Float8_e4m3fn b) {
+  return static_cast<Float8_e4m3fn>(a) - b;
+}
+inline STANDALONE_HOST_DEVICE Float8_e4m3fn
+operator*(int64_t a, Float8_e4m3fn b) {
+  return static_cast<Float8_e4m3fn>(a) * b;
+}
+inline STANDALONE_HOST_DEVICE Float8_e4m3fn
+operator/(int64_t a, Float8_e4m3fn b) {
+  return static_cast<Float8_e4m3fn>(a) / b;
+}
+
+/// NOTE: we do not define comparisons directly and instead rely on the implicit
+/// conversion from standalone::c10::Float8_e4m3fn to float.
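As the NOTE above says, comparisons on Float8_e4m3fn resolve through the built-in float operators once both operands are implicitly converted. A minimal usage sketch follows; it is an illustration added for this review, not part of the patch, and the include path is an assumption based on the directory layout this change introduces.

#include <iostream>
// Assumed include path, derived from this patch's source tree layout.
#include <executorch/backends/aoti/slim/c10/util/Float8_e4m3fn.h>

int main() {
  using standalone::c10::Float8_e4m3fn;
  Float8_e4m3fn a(0.5f);
  Float8_e4m3fn b(1.5f);
  // No operator< is defined for Float8_e4m3fn; both operands implicitly
  // convert to float, so the built-in float comparison is selected.
  if (a < b) {
    std::cout << static_cast<float>(a) << " < " << static_cast<float>(b)
              << '\n';
  }
  // Arithmetic round-trips through float and re-quantizes to e4m3fn.
  Float8_e4m3fn c = a + b;
  std::cout << "a + b = " << c << '\n'; // operator<< prints the float value
  return 0;
}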
+
+} // namespace standalone::c10
+
+namespace std {
+
+template <>
+class numeric_limits<standalone::c10::Float8_e4m3fn> {
+ public:
+  static constexpr bool is_specialized = true;
+  static constexpr bool is_signed = true;
+  static constexpr bool is_integer = false;
+  static constexpr bool is_exact = false;
+  static constexpr bool has_infinity = false;
+  static constexpr bool has_quiet_NaN = true;
+  static constexpr bool has_signaling_NaN = false;
+  static constexpr auto has_denorm = true;
+  static constexpr auto has_denorm_loss = true;
+  static constexpr auto round_style = numeric_limits<float>::round_style;
+  static constexpr bool is_iec559 = false;
+  static constexpr bool is_bounded = true;
+  static constexpr bool is_modulo = false;
+  static constexpr int digits = 4;
+  static constexpr int digits10 = 0;
+  static constexpr int max_digits10 = 3;
+  static constexpr int radix = 2;
+  static constexpr int min_exponent = -5;
+  static constexpr int min_exponent10 = -1;
+  static constexpr int max_exponent = 8;
+  static constexpr int max_exponent10 = 2;
+  static constexpr auto traps = numeric_limits<float>::traps;
+  static constexpr auto tinyness_before = false;
+
+  static constexpr standalone::c10::Float8_e4m3fn min() {
+    return standalone::c10::Float8_e4m3fn(
+        0x08, standalone::c10::Float8_e4m3fn::from_bits());
+  }
+  static constexpr standalone::c10::Float8_e4m3fn lowest() {
+    return standalone::c10::Float8_e4m3fn(
+        0xFE, standalone::c10::Float8_e4m3fn::from_bits());
+  }
+  static constexpr standalone::c10::Float8_e4m3fn max() {
+    return standalone::c10::Float8_e4m3fn(
+        0x7E, standalone::c10::Float8_e4m3fn::from_bits());
+  }
+  static constexpr standalone::c10::Float8_e4m3fn epsilon() {
+    return standalone::c10::Float8_e4m3fn(
+        0x20, standalone::c10::Float8_e4m3fn::from_bits());
+  }
+  static constexpr standalone::c10::Float8_e4m3fn round_error() {
+    return standalone::c10::Float8_e4m3fn(
+        0x30, standalone::c10::Float8_e4m3fn::from_bits());
+  }
+  static constexpr standalone::c10::Float8_e4m3fn quiet_NaN() {
+    return standalone::c10::Float8_e4m3fn(
+        0x7F, standalone::c10::Float8_e4m3fn::from_bits());
+  }
+  static constexpr standalone::c10::Float8_e4m3fn denorm_min() {
+    return standalone::c10::Float8_e4m3fn(
+        0x01, standalone::c10::Float8_e4m3fn::from_bits());
+  }
+};
+
+} // namespace std
+
+STANDALONE_CLANG_DIAGNOSTIC_POP()
diff --git a/backends/aoti/slim/c10/util/Float8_e4m3fn.h b/backends/aoti/slim/c10/util/Float8_e4m3fn.h
new file mode 100644
index 00000000000..320a677cbbb
--- /dev/null
+++ b/backends/aoti/slim/c10/util/Float8_e4m3fn.h
@@ -0,0 +1,238 @@
+#pragma once
+
+/// Defines the Float8_e4m3fn type (8-bit floating-point) including conversions
+/// to standard C types and basic arithmetic operations. Note that arithmetic
+/// operations are implemented by converting to floating point and
+/// performing the operation in float32.
+/// Binary configuration: +/// s eeee mmm +/// 1 sign bit +/// 4 exponent bits +/// 3 mantissa bits +/// bias = 7 +/// +/// Implementation based on the paper https://arxiv.org/pdf/2209.05433.pdf +/// and inspired by Half implementation from pytorch/standalone/c10/util/Half.h + +#include +#include + +#if defined(__cplusplus) +#include +#include +#elif !defined(__OPENCL_VERSION__) +#include +#include +#endif + +#ifdef _MSC_VER +#include +#endif + +#include +#include + +namespace standalone::c10 { + +namespace detail { + +/* + * Convert a 8-bit floating-point number in fp8 E4M3FN format, in bit + * representation, to a 32-bit floating-point number in IEEE single-precision + * format, in bit representation. + * + * @note The implementation doesn't use any floating-point operations. + */ +inline STANDALONE_HOST_DEVICE float fp8e4m3fn_to_fp32_value(uint8_t input) { + /* + * Extend the fp8 E4M3FN number to 32 bits and shift to the + * upper part of the 32-bit word: + * +---+----+---+-----------------------------+ + * | S |EEEE|MMM|0000 0000 0000 0000 0000 0000| + * +---+----+---+-----------------------------+ + * Bits 31 27-30 24-26 0-23 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 + * - zero bits. + */ + const uint32_t w = (uint32_t)input << 24; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the bits 0-30 + * of the 32-bit word: + * + * +---+----+---+-----------------------------+ + * | S |EEEE|MMM|0000 0000 0000 0000 0000 0000| + * +---+----+---+-----------------------------+ + * Bits 31 27-30 24-26 0-23 + */ + const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF); + /* + * Renorm shift is the number of bits to shift mantissa left to make the + * half-precision number normalized. If the initial number is normalized, some + * of its high 5 bits (sign == 0 and 4-bit exponent) equals one. In this case + * renorm_shift == 0. If the number is denormalize, renorm_shift > 0. Note + * that if we shift denormalized nonsign by renorm_shift, the unit bit of + * mantissa will shift into exponent, turning the biased exponent into 1, and + * making mantissa normalized (i.e. without leading 1). + */ +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + uint32_t renorm_shift = __clz(nonsign); +#elif defined(__SYCL_DEVICE_ONLY__) + // Note: zero is not a supported input into `__builtin_clz` + uint32_t renorm_shift = + nonsign != 0 ? __builtin_clz(nonsign) : sizeof(uint32_t) * CHAR_BIT; +#elif defined(_MSC_VER) && !defined(__clang__) + unsigned long nonsign_bsr; + _BitScanReverse(&nonsign_bsr, (unsigned long)nonsign); + uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31; +#else + // Note: zero is not a supported input into `__builtin_clz` + uint32_t renorm_shift = + nonsign != 0 ? __builtin_clz(nonsign) : sizeof(uint32_t) * CHAR_BIT; +#endif + renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0; + /* + * Iff fp8e4m3fn number has all exponent and mantissa bits set to 1, + * the addition overflows it into bit 31, and the subsequent shift turns the + * high 9 bits into 1. 
Thus inf_nan_mask == 0x7F800000 if the fp8e4m3fn number + * is Nan, 0x00000000 otherwise + */ + const int32_t inf_nan_mask = + ((int32_t)(nonsign + 0x01000000) >> 8) & INT32_C(0x7F800000); + /* + * Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31 + * into 1. Otherwise, bit 31 remains 0. The signed shift right by 31 + * broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask == + * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h) + * 0x00000000 otherwise + */ + const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31; + /* + * 1. Shift nonsign left by renorm_shift to normalize it (if the input + * was denormal) + * 2. Shift nonsign right by 4 so the exponent (4 bits originally) + * becomes an 8-bit field and 3-bit mantissa shifts into the 3 high + * bits of the 23-bit mantissa of IEEE single-precision number. + * 3. Add 0x78 to the exponent (starting at bit 23) to compensate the + * different in exponent bias (0x7F for single-precision number less 0x07 + * for fp8e4m3fn number). + * 4. Subtract renorm_shift from the exponent (starting at bit 23) to + * account for renormalization. As renorm_shift is less than 0x78, this + * can be combined with step 3. + * 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the + * input was NaN or infinity. + * 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent + * into zero if the input was zero. + * 7. Combine with the sign of the input number. + */ + uint32_t result = sign | + ((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) | + inf_nan_mask) & + ~zero_mask); + return fp32_from_bits(result); +} + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a + * 8-bit floating-point number in fp8 E4M3FN format, in bit representation. + */ +inline STANDALONE_HOST_DEVICE uint8_t fp8e4m3fn_from_fp32_value(float f) { + /* + * Binary representation of 480.0f, which is the first value + * not representable in fp8e4m3fn range: + * 0 1111 111 - fp8e4m3fn + * 0 10000111 11100000000000000000000 - fp32 + */ + constexpr uint32_t fp8_max = UINT32_C(1087) << 20; + + /* + * A mask for converting fp32 numbers lower than fp8e4m3fn normal range + * into denorm representation + * magic number: ((127 - 7) + (23 - 3) + 1) + */ + constexpr uint32_t denorm_mask = UINT32_C(141) << 23; + + uint32_t f_bits = fp32_to_bits(f); + + uint8_t result = 0u; + + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = f_bits & UINT32_C(0x80000000); + + /* + * Set sign bit to 0 + */ + f_bits ^= sign; + + if (f_bits >= fp8_max) { + // NaN - all exponent and mantissa bits set to 1 + result = 0x7f; + } else { + if (f_bits < (UINT32_C(121) << 23)) { + // Input number is smaller than 2^(-6), which is the smallest + // fp8e4m3fn normal number + f_bits = + fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask)); + result = static_cast(f_bits - denorm_mask); + } else { + // resulting mantissa is odd + uint8_t mant_odd = (f_bits >> 20) & 1; + + // update exponent, rounding bias part 1 + f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF; + + // rounding bias part 2 + f_bits += mant_odd; + + // take the bits! 
+ result = static_cast(f_bits >> 20); + } + } + + result |= static_cast(sign >> 24); + return result; +} + +} // namespace detail + +struct alignas(1) Float8_e4m3fn { + uint8_t x; + + struct from_bits_t {}; + STANDALONE_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + Float8_e4m3fn() = default; + + constexpr STANDALONE_HOST_DEVICE Float8_e4m3fn(uint8_t bits, from_bits_t) + : x(bits) {} + inline STANDALONE_HOST_DEVICE Float8_e4m3fn(float value); + inline STANDALONE_HOST_DEVICE operator float() const; + inline STANDALONE_HOST_DEVICE bool isnan() const; +}; + +inline std::ostream& operator<<(std::ostream& out, const Float8_e4m3fn& value) { + out << (float)value; + return out; +} + +} // namespace standalone::c10 + +#include // IWYU pragma: keep diff --git a/backends/aoti/slim/c10/util/Float8_e4m3fnuz-inl.h b/backends/aoti/slim/c10/util/Float8_e4m3fnuz-inl.h new file mode 100644 index 00000000000..55a6ce73972 --- /dev/null +++ b/backends/aoti/slim/c10/util/Float8_e4m3fnuz-inl.h @@ -0,0 +1,312 @@ +#pragma once + +#include +#include +#include +#include + +STANDALONE_CLANG_DIAGNOSTIC_PUSH() +#if STANDALONE_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +STANDALONE_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +namespace standalone::c10 { + +/// Constructors + +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz::Float8_e4m3fnuz(float value) + : x(detail::fp8e4m3fnuz_from_fp32_value(value)) {} + +/// Implicit conversions + +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz::operator float() const { + return detail::fp8_fnuz_to_fp32_value<4, 3>(x); +} + +/// Special values helper + +inline STANDALONE_HOST_DEVICE bool Float8_e4m3fnuz::isnan() const { + return x == 0b10000000; +} + +/// Arithmetic + +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator+(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) { + return static_cast(a) + static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator-(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) { + return static_cast(a) - static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator*(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) { + return static_cast(a) * static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz operator/( + const Float8_e4m3fnuz& a, + const Float8_e4m3fnuz& b) __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator-(const Float8_e4m3fnuz& a) { + return -static_cast(a); +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz& operator+=( + Float8_e4m3fnuz& a, + const Float8_e4m3fnuz& b) { + a = a + b; + return a; +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz& operator-=( + Float8_e4m3fnuz& a, + const Float8_e4m3fnuz& b) { + a = a - b; + return a; +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz& operator*=( + Float8_e4m3fnuz& a, + const Float8_e4m3fnuz& b) { + a = a * b; + return a; +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz& operator/=( + Float8_e4m3fnuz& a, + const Float8_e4m3fnuz& b) { + a = a / b; + return a; +} + +/// Arithmetic with floats + +inline STANDALONE_HOST_DEVICE float operator+(Float8_e4m3fnuz a, float b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE float operator-(Float8_e4m3fnuz a, float b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE float operator*(Float8_e4m3fnuz a, float b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE float 
operator/(Float8_e4m3fnuz a, float b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline STANDALONE_HOST_DEVICE float operator+(float a, Float8_e4m3fnuz b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator-(float a, Float8_e4m3fnuz b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator*(float a, Float8_e4m3fnuz b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator/(float a, Float8_e4m3fnuz b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE float& operator+=( + float& a, + const Float8_e4m3fnuz& b) { + return a += static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator-=( + float& a, + const Float8_e4m3fnuz& b) { + return a -= static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator*=( + float& a, + const Float8_e4m3fnuz& b) { + return a *= static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator/=( + float& a, + const Float8_e4m3fnuz& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline STANDALONE_HOST_DEVICE double operator+(Float8_e4m3fnuz a, double b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE double operator-(Float8_e4m3fnuz a, double b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE double operator*(Float8_e4m3fnuz a, double b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE double operator/(Float8_e4m3fnuz a, double b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline STANDALONE_HOST_DEVICE double operator+(double a, Float8_e4m3fnuz b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator-(double a, Float8_e4m3fnuz b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator*(double a, Float8_e4m3fnuz b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator/(double a, Float8_e4m3fnuz b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator+(Float8_e4m3fnuz a, int b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator-(Float8_e4m3fnuz a, int b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator*(Float8_e4m3fnuz a, int b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator/(Float8_e4m3fnuz a, int b) { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator+(int a, Float8_e4m3fnuz b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator-(int a, Float8_e4m3fnuz b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator*(int a, Float8_e4m3fnuz b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator/(int a, Float8_e4m3fnuz b) { + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator+(Float8_e4m3fnuz a, int64_t b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator-(Float8_e4m3fnuz a, int64_t b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator*(Float8_e4m3fnuz a, int64_t b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator/(Float8_e4m3fnuz a, 
int64_t b) { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator+(int64_t a, Float8_e4m3fnuz b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator-(int64_t a, Float8_e4m3fnuz b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator*(int64_t a, Float8_e4m3fnuz b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz +operator/(int64_t a, Float8_e4m3fnuz b) { + return static_cast(a) / b; +} + +/// NOTE: we do not define comparisons directly and instead rely on the implicit +/// conversion from standalone::c10::Float8_e4m3fnuz to float. + +} // namespace standalone::c10 + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = true; + static constexpr auto has_denorm_loss = true; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 4; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 3; + static constexpr int radix = 2; + static constexpr int min_exponent = -6; + static constexpr int min_exponent10 = -1; + static constexpr int max_exponent = 8; + static constexpr int max_exponent10 = 2; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = false; + + static constexpr standalone::c10::Float8_e4m3fnuz min() { + return standalone::c10::Float8_e4m3fnuz( + 0x08, standalone::c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr standalone::c10::Float8_e4m3fnuz lowest() { + return standalone::c10::Float8_e4m3fnuz( + 0xFF, standalone::c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr standalone::c10::Float8_e4m3fnuz max() { + return standalone::c10::Float8_e4m3fnuz( + 0x7F, standalone::c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr standalone::c10::Float8_e4m3fnuz epsilon() { + return standalone::c10::Float8_e4m3fnuz( + 0x28, standalone::c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr standalone::c10::Float8_e4m3fnuz round_error() { + return standalone::c10::Float8_e4m3fnuz( + 0x38, standalone::c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr standalone::c10::Float8_e4m3fnuz infinity() { + // NaN (no infinities) + return standalone::c10::Float8_e4m3fnuz( + 0x80, standalone::c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr standalone::c10::Float8_e4m3fnuz quiet_NaN() { + return standalone::c10::Float8_e4m3fnuz( + 0x80, standalone::c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr standalone::c10::Float8_e4m3fnuz denorm_min() { + return standalone::c10::Float8_e4m3fnuz( + 0x01, standalone::c10::Float8_e4m3fnuz::from_bits()); + } +}; + +} // namespace std + +STANDALONE_CLANG_DIAGNOSTIC_POP() diff --git a/backends/aoti/slim/c10/util/Float8_e4m3fnuz.h b/backends/aoti/slim/c10/util/Float8_e4m3fnuz.h new file mode 100644 index 00000000000..ff3c050f018 --- /dev/null +++ b/backends/aoti/slim/c10/util/Float8_e4m3fnuz.h @@ -0,0 +1,138 @@ +#pragma once + +/// Defines the Float8_e4m3fnuz type (8-bit 
floating-point) including +/// conversions to standard C types and basic arithmetic operations. Note that +/// arithmetic operations are implemented by converting to floating point and +/// performing the operation in float32. +/// Binary configuration remains the same as Float8_e4m3fn: +/// s eeee mmm +/// 1 sign bit +/// 4 exponent bits +/// 3 mantissa bits +/// The key differences versus Float8_e4m3fn are: +/// bias = 8 +/// no infinities or negative zero +/// NaN only when sign bit is 1, rest all 0s +/// +/// Implementation based on the paper https://arxiv.org/pdf/2206.02915.pdf and +/// the existing Float8_e4m3fn implementation. + +#include +#include +#include + +#if defined(__cplusplus) +#include +#elif !defined(__OPENCL_VERSION__) +#include +#include +#endif + +#include +#include + +namespace standalone::c10 { + +namespace detail { + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a + * 8-bit floating-point number in fp8 E4M3FNUZ format, in bit representation. + */ +inline STANDALONE_HOST_DEVICE uint8_t fp8e4m3fnuz_from_fp32_value(float f) { + /* + * Binary representation of 256.0f, which is the first value not representable + * (i.e. the first value which would overflow in to the sign bit, resulting in + * a NaN) in fp8e4m3fnuz range: + * 1 0000 000 - fp8e4m3fnuz + * 0 10000111 00000000000000000000000 - fp32 + */ + constexpr uint32_t fnuz_max = UINT32_C(0x87) << 23; + + /* + * A mask for converting fp32 numbers lower than fp8e4m3fnuz normal range + * into denorm representation + * magic number: ((127 - 8) + (23 - 3) + 1) + */ + constexpr uint32_t denorm_mask = UINT32_C(0x8C) << 23; + + uint32_t f_bits = fp32_to_bits(f); + + uint32_t result = 0u; + + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = f_bits & UINT32_C(0x80000000); + + /* + * Set sign bit to 0 + */ + f_bits ^= sign; + + if (f_bits >= fnuz_max) { + // NaN -- sign bit set to 1, rest 0s. + return 0x80; + } + + if (f_bits < (UINT32_C(0x78) << 23) /* 2^-7 in float32 */) { + // Input exponent is less than -7, the smallest e4m3fnuz exponent, so the + // number will become subnormal. + f_bits = fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask)); + result = static_cast(f_bits - denorm_mask); + if (result == 0) { + // fnuz types don't have negative zero. + return 0; + } + } else { + // resulting mantissa is odd + uint8_t mant_odd = (f_bits >> 20) & 1; + + // update exponent, rounding bias part 1 + f_bits += ((uint32_t)(8 - 127) << 23) + 0x7FFFF; + + // rounding bias part 2 + f_bits += mant_odd; + + // take the bits! 
+ result = static_cast(f_bits >> 20); + } + + result |= sign >> 24; + return result; +} + +} // namespace detail + +struct alignas(1) Float8_e4m3fnuz { + uint8_t x; + + struct from_bits_t {}; + STANDALONE_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + Float8_e4m3fnuz() = default; + + constexpr STANDALONE_HOST_DEVICE Float8_e4m3fnuz(uint8_t bits, from_bits_t) + : x(bits) {} + inline STANDALONE_HOST_DEVICE Float8_e4m3fnuz(float value); + inline STANDALONE_HOST_DEVICE operator float() const; + inline STANDALONE_HOST_DEVICE bool isnan() const; +}; + +inline std::ostream& operator<<( + std::ostream& out, + const Float8_e4m3fnuz& value) { + out << (float)value; + return out; +} + +} // namespace standalone::c10 + +#include // IWYU pragma: keep diff --git a/backends/aoti/slim/c10/util/Float8_e5m2-inl.h b/backends/aoti/slim/c10/util/Float8_e5m2-inl.h new file mode 100644 index 00000000000..c8e90a8aa0d --- /dev/null +++ b/backends/aoti/slim/c10/util/Float8_e5m2-inl.h @@ -0,0 +1,302 @@ +#pragma once + +#include +#include +#include + +STANDALONE_CLANG_DIAGNOSTIC_PUSH() +#if STANDALONE_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +STANDALONE_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +#define EXP_WIDTH_FP8 5 +#define MAN_WIDTH_FP8 2 +#define EXP_BIAS_FP8 15 + +namespace standalone::c10 { + +/// Constructors + +inline STANDALONE_HOST_DEVICE Float8_e5m2::Float8_e5m2(float value) + : x(detail::fp8e5m2_from_fp32_value(value)) {} + +/// Implicit conversions + +inline STANDALONE_HOST_DEVICE Float8_e5m2::operator float() const { + return detail::fp8e5m2_to_fp32_value(x); +} + +/// Special values helpers + +inline STANDALONE_HOST_DEVICE bool Float8_e5m2::isnan() const { + return (x & 0b01111111) > 0b01111100; +} + +inline STANDALONE_HOST_DEVICE bool Float8_e5m2::isinf() const { + return (x & 0b01111111) == 0b01111100; +} + +/// Arithmetic + +inline STANDALONE_HOST_DEVICE Float8_e5m2 +operator+(const Float8_e5m2& a, const Float8_e5m2& b) { + return static_cast(a) + static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2 +operator-(const Float8_e5m2& a, const Float8_e5m2& b) { + return static_cast(a) - static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2 +operator*(const Float8_e5m2& a, const Float8_e5m2& b) { + return static_cast(a) * static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator/( + const Float8_e5m2& a, + const Float8_e5m2& b) __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator-(const Float8_e5m2& a) { + return -static_cast(a); +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2& operator+=( + Float8_e5m2& a, + const Float8_e5m2& b) { + a = a + b; + return a; +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2& operator-=( + Float8_e5m2& a, + const Float8_e5m2& b) { + a = a - b; + return a; +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2& operator*=( + Float8_e5m2& a, + const Float8_e5m2& b) { + a = a * b; + return a; +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2& operator/=( + Float8_e5m2& a, + const Float8_e5m2& b) { + a = a / b; + return a; +} + +/// Arithmetic with floats + +inline STANDALONE_HOST_DEVICE float operator+(Float8_e5m2 a, float b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE float operator-(Float8_e5m2 a, float b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE float operator*(Float8_e5m2 a, float b) { + return static_cast(a) * b; +} +inline 
STANDALONE_HOST_DEVICE float operator/(Float8_e5m2 a, float b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline STANDALONE_HOST_DEVICE float operator+(float a, Float8_e5m2 b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator-(float a, Float8_e5m2 b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator*(float a, Float8_e5m2 b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator/(float a, Float8_e5m2 b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE float& operator+=( + float& a, + const Float8_e5m2& b) { + return a += static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator-=( + float& a, + const Float8_e5m2& b) { + return a -= static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator*=( + float& a, + const Float8_e5m2& b) { + return a *= static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator/=( + float& a, + const Float8_e5m2& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline STANDALONE_HOST_DEVICE double operator+(Float8_e5m2 a, double b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE double operator-(Float8_e5m2 a, double b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE double operator*(Float8_e5m2 a, double b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE double operator/(Float8_e5m2 a, double b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline STANDALONE_HOST_DEVICE double operator+(double a, Float8_e5m2 b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator-(double a, Float8_e5m2 b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator*(double a, Float8_e5m2 b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator/(double a, Float8_e5m2 b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator+(Float8_e5m2 a, int b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator-(Float8_e5m2 a, int b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator*(Float8_e5m2 a, int b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator/(Float8_e5m2 a, int b) { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator+(int a, Float8_e5m2 b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator-(int a, Float8_e5m2 b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator*(int a, Float8_e5m2 b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator/(int a, Float8_e5m2 b) { + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator+(Float8_e5m2 a, int64_t b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator-(Float8_e5m2 a, int64_t b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator*(Float8_e5m2 a, int64_t b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator/(Float8_e5m2 a, int64_t b) { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator+(int64_t a, Float8_e5m2 b) { + return static_cast(a) + 
b; +} +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator-(int64_t a, Float8_e5m2 b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator*(int64_t a, Float8_e5m2 b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE Float8_e5m2 operator/(int64_t a, Float8_e5m2 b) { + return static_cast(a) / b; +} + +/// NOTE: we do not define comparisons directly and instead rely on the implicit +/// conversion from standalone::c10::Float8_e5m2 to float. + +} // namespace standalone::c10 + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_specialized = true; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = true; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = true; + static constexpr auto has_denorm_loss = true; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 3; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 2; + static constexpr int radix = 2; + static constexpr int min_exponent = -13; + static constexpr int min_exponent10 = -4; + static constexpr int max_exponent = 16; + static constexpr int max_exponent10 = 4; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = + numeric_limits::tinyness_before; + + static constexpr standalone::c10::Float8_e5m2 min() { + return standalone::c10::Float8_e5m2( + 0x4, standalone::c10::Float8_e5m2::from_bits()); + } + static constexpr standalone::c10::Float8_e5m2 max() { + return standalone::c10::Float8_e5m2( + 0x7B, standalone::c10::Float8_e5m2::from_bits()); + } + static constexpr standalone::c10::Float8_e5m2 lowest() { + return standalone::c10::Float8_e5m2( + 0xFB, standalone::c10::Float8_e5m2::from_bits()); + } + static constexpr standalone::c10::Float8_e5m2 epsilon() { + return standalone::c10::Float8_e5m2( + 0x34, standalone::c10::Float8_e5m2::from_bits()); + } + static constexpr standalone::c10::Float8_e5m2 round_error() { + return standalone::c10::Float8_e5m2( + 0x38, standalone::c10::Float8_e5m2::from_bits()); + } + static constexpr standalone::c10::Float8_e5m2 infinity() { + return standalone::c10::Float8_e5m2( + 0x7C, standalone::c10::Float8_e5m2::from_bits()); + } + static constexpr standalone::c10::Float8_e5m2 quiet_NaN() { + return standalone::c10::Float8_e5m2( + 0x7F, standalone::c10::Float8_e5m2::from_bits()); + } + static constexpr standalone::c10::Float8_e5m2 denorm_min() { + return standalone::c10::Float8_e5m2( + 0x01, standalone::c10::Float8_e5m2::from_bits()); + } +}; + +} // namespace std + +STANDALONE_CLANG_DIAGNOSTIC_POP() diff --git a/backends/aoti/slim/c10/util/Float8_e5m2.h b/backends/aoti/slim/c10/util/Float8_e5m2.h new file mode 100644 index 00000000000..88d1aab0525 --- /dev/null +++ b/backends/aoti/slim/c10/util/Float8_e5m2.h @@ -0,0 +1,147 @@ +#pragma once + +/// Defines the Float8_e5m2 type (8-bit floating-point) including conversions +/// to standard C types and basic arithmetic operations. Note that arithmetic +/// operations are implemented by converting to floating point and +/// performing the operation in float32. 
+/// Binary configuration: +/// s eeeee mm +/// 1 sign bit +/// 5 exponent bits +/// 2 mantissa bits +/// bias = 15 +/// +/// Implementation based on the paper https://arxiv.org/pdf/2209.05433.pdf +/// and inspired by Half implementation from pytorch/standalone/c10/util/Half.h + +#include + +namespace standalone::c10 { + +namespace detail { + +/* + * Convert a 8-bit floating-point number in fp8 E5M2 format, in bit + * representation, to a 32-bit floating-point number in IEEE single-precision + * format, in bit representation. + * + * @note The implementation doesn't use any floating-point operations. + */ +inline STANDALONE_HOST_DEVICE float fp8e5m2_to_fp32_value(uint8_t input) { + /* + * Extend the fp8 E5M2 number to 32 bits and shift to the + * upper part of the 32-bit word: + * +---+----+---+-----------------------------+ + * | S |EEEEE|MM|0000 0000 0000 0000 0000 0000| + * +---+----+---+-----------------------------+ + * Bits 31 26-30 24-25 0-23 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 + * - zero bits. + */ + uint16_t half_representation = input; + half_representation <<= 8; + return fp16_ieee_to_fp32_value(half_representation); +} + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a + * 8-bit floating-point number in fp8 E5M2 format, in bit representation. + */ +inline STANDALONE_HOST_DEVICE uint8_t fp8e5m2_from_fp32_value(float f) { + /* + * Binary representation of fp32 infinity + * 0 11111111 00000000000000000000000 + */ + constexpr uint32_t fp32_inf = UINT32_C(255) << 23; + + /* + * Binary representation of 65536.0f, which is the first value + * not representable in fp8e5m2 range: + * 0 11111 00 - fp8e5m2 + * 0 10001111 00000000000000000000000 - fp32 + */ + constexpr uint32_t fp8_max = UINT32_C(143) << 23; + + /* + * A mask for converting fp32 numbers lower than fp8e5m2 normal range + * into denorm representation + * magic number: ((127 - 15) + (23 - 2) + 1) + */ + constexpr uint32_t denorm_mask = UINT32_C(134) << 23; + + uint32_t f_bits = fp32_to_bits(f); + uint8_t result = 0u; + + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = f_bits & UINT32_C(0x80000000); + + /* + * Set sign bit to 0 + */ + f_bits ^= sign; + + if (f_bits >= fp8_max) { + // NaN - all exponent and mantissa bits set to 1 + result = f_bits > fp32_inf ? UINT8_C(0x7F) : UINT8_C(0x7C); + } else { + if (f_bits < (UINT32_C(113) << 23)) { + // Input number is smaller than 2^(-14), which is the smallest + // fp8e5m2 normal number + f_bits = + fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask)); + result = static_cast(f_bits - denorm_mask); + } else { + // resulting mantissa is odd + uint32_t mant_odd = (f_bits >> 21) & 1; + + // update exponent, rounding bias part 1 + f_bits += ((uint32_t)(15 - 127) << 23) + 0xFFFFF; + + // rounding bias part 2 + f_bits += mant_odd; + + // take the bits! 
+ result = static_cast(f_bits >> 21); + } + } + + result |= static_cast(sign >> 24); + return result; +} + +} // namespace detail + +struct alignas(1) Float8_e5m2 { + uint8_t x; + + struct from_bits_t {}; + STANDALONE_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + Float8_e5m2() = default; + + constexpr STANDALONE_HOST_DEVICE Float8_e5m2(uint8_t bits, from_bits_t) + : x(bits) {} + inline STANDALONE_HOST_DEVICE Float8_e5m2(float value); + inline STANDALONE_HOST_DEVICE operator float() const; + inline STANDALONE_HOST_DEVICE bool isnan() const; + inline STANDALONE_HOST_DEVICE bool isinf() const; +}; + +inline std::ostream& operator<<(std::ostream& out, const Float8_e5m2& value) { + out << (float)value; + return out; +} + +} // namespace standalone::c10 + +#include // IWYU pragma: keep diff --git a/backends/aoti/slim/c10/util/Float8_e5m2fnuz-inl.h b/backends/aoti/slim/c10/util/Float8_e5m2fnuz-inl.h new file mode 100644 index 00000000000..d2ccac329af --- /dev/null +++ b/backends/aoti/slim/c10/util/Float8_e5m2fnuz-inl.h @@ -0,0 +1,318 @@ +#pragma once + +#include +#include +#include +#include + +STANDALONE_CLANG_DIAGNOSTIC_PUSH() +#if STANDALONE_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +STANDALONE_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +namespace standalone::c10 { + +/// Constructors + +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz::Float8_e5m2fnuz(float value) + : x(detail::fp8e5m2fnuz_from_fp32_value(value)) {} + +/// Implicit conversions + +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz::operator float() const { + return detail::fp8_fnuz_to_fp32_value<5, 2>(x); +} + +/// Special values helpers + +inline STANDALONE_HOST_DEVICE bool Float8_e5m2fnuz::isnan() const { + return x == 0b10000000; +} + +inline STANDALONE_HOST_DEVICE bool Float8_e5m2fnuz::isinf() const { + return false; +} + +/// Arithmetic + +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator+(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) { + return static_cast(a) + static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator-(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) { + return static_cast(a) - static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator*(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) { + return static_cast(a) * static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz operator/( + const Float8_e5m2fnuz& a, + const Float8_e5m2fnuz& b) __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator-(const Float8_e5m2fnuz& a) { + return -static_cast(a); +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz& operator+=( + Float8_e5m2fnuz& a, + const Float8_e5m2fnuz& b) { + a = a + b; + return a; +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz& operator-=( + Float8_e5m2fnuz& a, + const Float8_e5m2fnuz& b) { + a = a - b; + return a; +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz& operator*=( + Float8_e5m2fnuz& a, + const Float8_e5m2fnuz& b) { + a = a * b; + return a; +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz& operator/=( + Float8_e5m2fnuz& a, + const Float8_e5m2fnuz& b) { + a = a / b; + return a; +} + +/// Arithmetic with floats + +inline STANDALONE_HOST_DEVICE float operator+(Float8_e5m2fnuz a, float b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE float operator-(Float8_e5m2fnuz a, float b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE 
float operator*(Float8_e5m2fnuz a, float b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE float operator/(Float8_e5m2fnuz a, float b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline STANDALONE_HOST_DEVICE float operator+(float a, Float8_e5m2fnuz b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator-(float a, Float8_e5m2fnuz b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator*(float a, Float8_e5m2fnuz b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator/(float a, Float8_e5m2fnuz b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE float& operator+=( + float& a, + const Float8_e5m2fnuz& b) { + return a += static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator-=( + float& a, + const Float8_e5m2fnuz& b) { + return a -= static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator*=( + float& a, + const Float8_e5m2fnuz& b) { + return a *= static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator/=( + float& a, + const Float8_e5m2fnuz& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline STANDALONE_HOST_DEVICE double operator+(Float8_e5m2fnuz a, double b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE double operator-(Float8_e5m2fnuz a, double b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE double operator*(Float8_e5m2fnuz a, double b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE double operator/(Float8_e5m2fnuz a, double b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline STANDALONE_HOST_DEVICE double operator+(double a, Float8_e5m2fnuz b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator-(double a, Float8_e5m2fnuz b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator*(double a, Float8_e5m2fnuz b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator/(double a, Float8_e5m2fnuz b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator+(Float8_e5m2fnuz a, int b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator-(Float8_e5m2fnuz a, int b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator*(Float8_e5m2fnuz a, int b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator/(Float8_e5m2fnuz a, int b) { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator+(int a, Float8_e5m2fnuz b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator-(int a, Float8_e5m2fnuz b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator*(int a, Float8_e5m2fnuz b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator/(int a, Float8_e5m2fnuz b) { + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator+(Float8_e5m2fnuz a, int64_t b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator-(Float8_e5m2fnuz a, int64_t b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator*(Float8_e5m2fnuz a, int64_t b) { 
+ return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator/(Float8_e5m2fnuz a, int64_t b) { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator+(int64_t a, Float8_e5m2fnuz b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator-(int64_t a, Float8_e5m2fnuz b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator*(int64_t a, Float8_e5m2fnuz b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz +operator/(int64_t a, Float8_e5m2fnuz b) { + return static_cast(a) / b; +} + +/// NOTE: we do not define comparisons directly and instead rely on the implicit +/// conversion from standalone::c10::Float8_e5m2fnuz to float. + +} // namespace standalone::c10 + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_specialized = true; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = true; + static constexpr auto has_denorm_loss = true; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 3; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 2; + static constexpr int radix = 2; + static constexpr int min_exponent = -14; + static constexpr int min_exponent10 = -4; + static constexpr int max_exponent = 16; + static constexpr int max_exponent10 = 4; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = + numeric_limits::tinyness_before; + + static constexpr standalone::c10::Float8_e5m2fnuz min() { + return standalone::c10::Float8_e5m2fnuz( + 0x04, standalone::c10::Float8_e5m2fnuz::from_bits()); + } + static constexpr standalone::c10::Float8_e5m2fnuz max() { + return standalone::c10::Float8_e5m2fnuz( + 0x7F, standalone::c10::Float8_e5m2fnuz::from_bits()); + } + static constexpr standalone::c10::Float8_e5m2fnuz lowest() { + return standalone::c10::Float8_e5m2fnuz( + 0xFF, standalone::c10::Float8_e5m2fnuz::from_bits()); + } + static constexpr standalone::c10::Float8_e5m2fnuz epsilon() { + return standalone::c10::Float8_e5m2fnuz( + 0x34, standalone::c10::Float8_e5m2fnuz::from_bits()); + } + static constexpr standalone::c10::Float8_e5m2fnuz round_error() { + return standalone::c10::Float8_e5m2fnuz( + 0x38, standalone::c10::Float8_e5m2fnuz::from_bits()); + } + static constexpr standalone::c10::Float8_e5m2fnuz infinity() { + return standalone::c10::Float8_e5m2fnuz( + 0x80, standalone::c10::Float8_e5m2fnuz::from_bits()); + } + // TODO(future): we are mapping neg_zero to both inf and NaN, this is + // surprising and we should figure out what to do about it. 
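The TODO above flags that infinity() and quiet_NaN() both return 0x80: the fnuz encoding has no infinities and no negative zero, so the pattern that would otherwise mean -0 serves as the single NaN. A self-contained decoder, independent of this patch, that reproduces the values used in this specialization:

// Reference decoder for e5m2fnuz (bias 16, no infinities, no -0; the only NaN is 0x80).
#include <cmath>
#include <cstdint>
#include <cstdio>

float e5m2fnuz_decode(uint8_t b) {
  if (b == 0x80) {
    return NAN;  // the single NaN encoding: sign bit set, everything else zero
  }
  const int s = b >> 7;
  const int e = (b >> 2) & 0x1F;
  const int m = b & 0x03;
  const float v = (e == 0)
      ? std::ldexp(static_cast<float>(m) / 4.0f, -15)              // subnormal
      : std::ldexp(1.0f + static_cast<float>(m) / 4.0f, e - 16);   // normal, bias 16
  return s ? -v : v;
}

int main() {
  std::printf("0x04 -> %g\n", e5m2fnuz_decode(0x04)); // 3.05176e-05 (2^-15, min())
  std::printf("0x7F -> %g\n", e5m2fnuz_decode(0x7F)); // 57344 (max(); all-ones exponent is finite here)
  std::printf("0x01 -> %g\n", e5m2fnuz_decode(0x01)); // 7.62939e-06 (2^-17, denorm_min())
  std::printf("0x80 -> %g\n", e5m2fnuz_decode(0x80)); // nan (what infinity()/quiet_NaN() return)
}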
+ static constexpr standalone::c10::Float8_e5m2fnuz quiet_NaN() { + return standalone::c10::Float8_e5m2fnuz( + 0x80, standalone::c10::Float8_e5m2fnuz::from_bits()); + } + static constexpr standalone::c10::Float8_e5m2fnuz denorm_min() { + return standalone::c10::Float8_e5m2fnuz( + 0x01, standalone::c10::Float8_e5m2fnuz::from_bits()); + } +}; + +} // namespace std + +STANDALONE_CLANG_DIAGNOSTIC_POP() diff --git a/backends/aoti/slim/c10/util/Float8_e5m2fnuz.h b/backends/aoti/slim/c10/util/Float8_e5m2fnuz.h new file mode 100644 index 00000000000..c16e5613202 --- /dev/null +++ b/backends/aoti/slim/c10/util/Float8_e5m2fnuz.h @@ -0,0 +1,138 @@ +#pragma once + +/// Defines the Float8_e5m2fnuz type (8-bit floating-point) including +/// conversions to standard C types and basic arithmetic operations. Note that +/// arithmetic operations are implemented by converting to floating point and +/// performing the operation in float32. +/// Binary configuration remains the same as e5m2: +/// s eeeee mm +/// 1 sign bit +/// 5 exponent bits +/// 2 mantissa bits +/// The key differences that e5m2fnuz brings are: +/// bias = 16 +/// no infinities or negative zero +/// NaN only when sign bit is 1, rest all 0s +/// +/// Implementation based on the paper https://arxiv.org/pdf/2206.02915.pdf and +/// the existing Float8_e4m3fn implementation. + +#include +#include +#include + +#if defined(__cplusplus) +#include +#elif !defined(__OPENCL_VERSION__) +#include +#include +#endif + +#include +#include + +namespace standalone::c10 { + +namespace detail { + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a + * 8-bit floating-point number in fp8 E5M2 format, in bit representation. + */ +inline STANDALONE_HOST_DEVICE uint8_t fp8e5m2fnuz_from_fp32_value(float f) { + /* + * Binary representation of 65536.0f, which is the first value not + * representable (i.e. the first value which would overflow in to the sign + * bit, resulting in a NaN) in fp8e4m3fnuz range: + * 1 00000 00 - fp8e5m2fnuz + * 0 10001111 00000000000000000000000 - fp32 + */ + constexpr uint32_t fnuz_max = UINT32_C(0x8F) << 23; + + /* + * A mask for converting fp32 numbers lower than fp8e5m2fnuz normal range + * into denormalized representation. + * magic number: ((127 - 16) + (23 - 2) + 1) + */ + constexpr uint32_t denorm_mask = UINT32_C(0x85) << 23; + + uint32_t f_bits = fp32_to_bits(f); + uint32_t result = 0u; + + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = f_bits & UINT32_C(0x80000000); + + /* + * Set sign bit to 0 + */ + f_bits ^= sign; + + if (f_bits >= fnuz_max) { + // NaN -- sign bit set to 1, rest 0s + return 0x80; + } + + if (f_bits < (UINT32_C(0x70) << 23) /* 2^-15 in float32 */) { + // Input exponent is less than -15, the smallest e5m2fnuz exponent, so the + // number will become subnormal. + f_bits = fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask)); + result = static_cast(f_bits - denorm_mask); + if (result == 0) { + // fnuz types don't have negative zero. + return 0; + } + } else { + // resulting mantissa is odd + uint8_t mant_odd = (f_bits >> 21) & 1; + + // update exponent, rounding bias part 1 + f_bits += ((uint32_t)(16 - 127) << 23) + 0xFFFFF; + + // rounding bias part 2 + f_bits += mant_odd; + + // take the bits! 
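The two "rounding bias" additions above (0xFFFFF, i.e. 2^20 - 1, followed by the retained mantissa's odd bit before the shift by 21) are the branch-free add-half-minus-one-plus-LSB form of round-to-nearest-even. A self-contained sketch, not tied to this patch, showing that the trick agrees with an explicit guard/round/sticky implementation:

#include <cassert>
#include <cstdint>
#include <cstdio>

// Explicit guard/round/sticky round-to-nearest-even when dropping the low k bits.
uint32_t rne_grs(uint32_t v, unsigned k) {
  const uint32_t kept = v >> k;
  const uint32_t guard = (v >> (k - 1)) & 1u;          // most significant dropped bit
  const uint32_t rest = v & ((1u << (k - 1)) - 1u);    // remaining dropped bits (round|sticky)
  const bool round_up = guard && (rest != 0 || (kept & 1u));
  return kept + (round_up ? 1u : 0u);
}

// The bias trick used by the fp8/fp16 encoders: add (2^(k-1) - 1) plus the kept LSB, then shift.
uint32_t rne_bias(uint32_t v, unsigned k) {
  const uint32_t half_minus_one = (1u << (k - 1)) - 1u;
  const uint32_t lsb = (v >> k) & 1u;
  return (v + half_minus_one + lsb) >> k;
}

int main() {
  for (uint32_t v = 0; v < (1u << 16); ++v) {
    assert(rne_grs(v, 5) == rne_bias(v, 5));
  }
  std::printf("44 >> 3 with RNE: %u (5.5 rounds to the even neighbor 6)\n", rne_bias(44, 3));
}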
+ result = static_cast(f_bits >> 21); + } + + result |= sign >> 24; + return result; +} + +} // namespace detail + +struct alignas(1) Float8_e5m2fnuz { + uint8_t x; + + struct from_bits_t {}; + STANDALONE_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + Float8_e5m2fnuz() = default; + + constexpr STANDALONE_HOST_DEVICE Float8_e5m2fnuz(uint8_t bits, from_bits_t) + : x(bits) {} + inline STANDALONE_HOST_DEVICE Float8_e5m2fnuz(float value); + inline STANDALONE_HOST_DEVICE operator float() const; + inline STANDALONE_HOST_DEVICE bool isnan() const; + inline STANDALONE_HOST_DEVICE bool isinf() const; +}; + +inline std::ostream& operator<<( + std::ostream& out, + const Float8_e5m2fnuz& value) { + out << (float)value; + return out; +} + +} // namespace standalone::c10 + +#include // IWYU pragma: keep diff --git a/backends/aoti/slim/c10/util/Float8_e8m0fnu-inl.h b/backends/aoti/slim/c10/util/Float8_e8m0fnu-inl.h new file mode 100644 index 00000000000..f510ca551b8 --- /dev/null +++ b/backends/aoti/slim/c10/util/Float8_e8m0fnu-inl.h @@ -0,0 +1,118 @@ +#pragma once + +#include +#include +#include +#include + +// TODO(#146647): Can we remove the below warning? +STANDALONE_CLANG_DIAGNOSTIC_PUSH() +#if STANDALONE_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +STANDALONE_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +namespace standalone::c10 { + +/// Constructors + +inline STANDALONE_HOST_DEVICE Float8_e8m0fnu::Float8_e8m0fnu(float value) + : x(detail::fp8e8m0fnu_from_fp32_value(value)) {} + +/// Implicit conversions + +inline STANDALONE_HOST_DEVICE Float8_e8m0fnu::operator float() const { + // TODO(#146647): maybe rewrite without control flow + + // if exponent is zero, need to special case to return 2^-127 instead of zero + if (x == 0) { + return standalone::c10::detail::fp32_from_bits(0x00400000); + } + + // if exponent is NaN, need to special case to return properly encoded NaN + if (isnan()) { + return standalone::c10::detail::fp32_from_bits(0x7f800001); + } + + // leave sign at 0, set the exponent bits, leave stored mantissa at 0 + uint32_t res = x << 23; + + return standalone::c10::detail::fp32_from_bits(res); +} + +/// Special values helper + +inline STANDALONE_HOST_DEVICE bool Float8_e8m0fnu::isnan() const { + return x == 0b11111111; +} + +/// NOTE: we do not define comparisons directly and instead rely on the implicit +/// conversion from standalone::c10::Float8_e8m0fnu to float. + +} // namespace standalone::c10 + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_specialized = true; + static constexpr bool is_signed = false; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = false; + static constexpr auto has_denorm_loss = false; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 1; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 1; // just a 2! 
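e8m0fnu stores nothing but a biased exponent: the value is 2^(x - 127), x == 0x00 means 2^-127 rather than zero, and x == 0xFF is the only NaN, which is exactly what operator float() above special-cases. A self-contained decoder, independent of this patch, that reproduces the limits defined in this specialization:

#include <cmath>
#include <cstdint>
#include <cstdio>

float e8m0_decode(uint8_t x) {
  if (x == 0xFF) {
    return NAN;                                          // the single NaN encoding
  }
  return std::ldexp(1.0f, static_cast<int>(x) - 127);    // 2^(x - 127); x == 0 gives 2^-127
}

int main() {
  std::printf("0x00 -> %g\n", e8m0_decode(0x00)); // 5.87747e-39 (2^-127, min() and lowest())
  std::printf("0x7F -> %g\n", e8m0_decode(0x7F)); // 1 (2^0; also epsilon(), since the next value is 2)
  std::printf("0xFE -> %g\n", e8m0_decode(0xFE)); // 1.70141e+38 (2^127, max())
  std::printf("0xFF -> %g\n", e8m0_decode(0xFF)); // nan
}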
+ static constexpr int radix = 2; + static constexpr int min_exponent = -126; + static constexpr int min_exponent10 = -38; + static constexpr int max_exponent = 128; + static constexpr int max_exponent10 = 38; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = false; + + static constexpr standalone::c10::Float8_e8m0fnu min() { + // 2^-127 + return standalone::c10::Float8_e8m0fnu( + 0b00000000, standalone::c10::Float8_e8m0fnu::from_bits()); + } + static constexpr standalone::c10::Float8_e8m0fnu lowest() { + // 2^-127 + return standalone::c10::Float8_e8m0fnu( + 0b00000000, standalone::c10::Float8_e8m0fnu::from_bits()); + } + static constexpr standalone::c10::Float8_e8m0fnu max() { + // 254 biased, which is 127 unbiased, so 2^127 + return standalone::c10::Float8_e8m0fnu( + 0b11111110, standalone::c10::Float8_e8m0fnu::from_bits()); + } + static constexpr standalone::c10::Float8_e8m0fnu epsilon() { + // according to https://en.cppreference.com/w/cpp/types/numeric_limits, this + // is "the difference between 1.0 and the next representable value of the + // given floating-point type". The next representable value is 2.0, so the + // difference is 1.0 which is 2^0. 0 unbiased is 127 biased. + return standalone::c10::Float8_e8m0fnu( + 0b01111111, standalone::c10::Float8_e8m0fnu::from_bits()); + } + static constexpr standalone::c10::Float8_e8m0fnu round_error() { + // 0.5 in float, which is 2^-1, and -1 + 127 = 126 + return standalone::c10::Float8_e8m0fnu( + 0b01111110, standalone::c10::Float8_e8m0fnu::from_bits()); + } + static constexpr standalone::c10::Float8_e8m0fnu quiet_NaN() { + return standalone::c10::Float8_e8m0fnu( + 0b11111111, standalone::c10::Float8_e8m0fnu::from_bits()); + } +}; + +} // namespace std + +STANDALONE_CLANG_DIAGNOSTIC_POP() diff --git a/backends/aoti/slim/c10/util/Float8_e8m0fnu.h b/backends/aoti/slim/c10/util/Float8_e8m0fnu.h new file mode 100644 index 00000000000..2e2e46d627a --- /dev/null +++ b/backends/aoti/slim/c10/util/Float8_e8m0fnu.h @@ -0,0 +1,119 @@ +#pragma once + +/// Defines the Float8_e8m0fnu type (8-bit floating-point) including +/// conversions to standard C types +/// Binary configuration : +/// eeeeeeee +/// no sign bits +/// 8 exponent bits +/// no mantissa bits +/// +/// This is the E8M0 dtype from the OCP MX format spec +/// (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf, +/// Section 5.4.1) + +#include +#include +#include + +// TODO(#146647): do we need to special case OPENCL? +#if defined(__cplusplus) +#include +#elif !defined(__OPENCL_VERSION__) +#include +#include +#endif + +#include +#include + +namespace standalone::c10 { + +namespace detail { + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a + * 8-bit floating-point number in fp8 e8m0fnu format, in bit representation. 
+ */ +inline STANDALONE_HOST_DEVICE uint8_t fp8e8m0fnu_from_fp32_value(float f) { + // TODO(#146647): maybe rewrite without control flow + + uint32_t f_bits = standalone::c10::detail::fp32_to_bits(f); + + // extract the exponent + uint32_t exponent = (f_bits >> 23) & 0b11111111; + + // special case float32 NaN and +-inf to map to e8m0 nan + if (exponent == 0b11111111) { + return exponent; + } + + // next, we use guard, round, sticky bits and the LSB to implement round to + // nearest, with ties to even + + // guard bit - bit 23, or 22 zero-indexed + uint8_t g = (f_bits & 0x400000) > 0; + // round bit - bit 22, or 21 zero-indexed + uint8_t r = (f_bits & 0x200000) > 0; + // sticky bit - bits 21 to 1, or 20 to 0 zero-indexed + uint8_t s = (f_bits & 0x1FFFFF) > 0; + // in casting to e8m0, LSB is the implied mantissa bit. It equals to 0 if the + // original float32 is denormal, and to 1 if the original float32 is normal. + uint8_t lsb = exponent > 0; + + // implement the RNE logic + bool round_up = false; + + // if g == 0, round down (no-op) + if (g == 1) { + if ((r == 1) || (s == 1)) { + // round up + round_up = true; + } else { + if (lsb == 1) { + // round up + round_up = true; + } + // if lsb == 0, round down (no-op) + } + } + + if (round_up) { + // adjust exponent + // note that if exponent was 255 we would have already returned earlier, so + // we know we can add one safely without running out of bounds + exponent++; + } + + return exponent; +} + +} // namespace detail + +struct alignas(1) Float8_e8m0fnu { + uint8_t x; + + struct from_bits_t {}; + STANDALONE_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + Float8_e8m0fnu() = default; + + constexpr STANDALONE_HOST_DEVICE Float8_e8m0fnu(uint8_t bits, from_bits_t) + : x(bits) {} + inline STANDALONE_HOST_DEVICE Float8_e8m0fnu(float value); + inline STANDALONE_HOST_DEVICE operator float() const; + inline STANDALONE_HOST_DEVICE bool isnan() const; +}; + +inline std::ostream& operator<<( + std::ostream& out, + const Float8_e8m0fnu& value) { + out << (float)value; + return out; +} + +} // namespace standalone::c10 + +#include // IWYU pragma: keep diff --git a/backends/aoti/slim/c10/util/Float8_fnuz_cvt.h b/backends/aoti/slim/c10/util/Float8_fnuz_cvt.h new file mode 100644 index 00000000000..00bfa8cd8fc --- /dev/null +++ b/backends/aoti/slim/c10/util/Float8_fnuz_cvt.h @@ -0,0 +1,64 @@ +#pragma once + +#include + +#include + +#if defined(SYCL_LANGUAGE_VERSION) +#include +#endif + +namespace standalone::c10::detail { + +/* + * Convert a 8-bit floating-point number in either f8 E4M3FNUZ or bf8 E5M2FNUZ + * format, in bit representation, to a 32-bit floating-point number. 
+ */ +template +inline STANDALONE_HOST_DEVICE float fp8_fnuz_to_fp32_value(uint8_t x) { + static_assert((we == 4 && wm == 3) || (we == 5 && wm == 2)); + constexpr uint32_t weo = 8; + constexpr uint32_t wmo = 23; + + if (x == 0) { + return 0; + } + + if (x == 0x80) { + constexpr uint32_t ifNaN = 0x7F800001; + return fp32_from_bits(ifNaN); + } + + uint32_t mantissa = x & ((1 << wm) - 1); + uint32_t exponent = (x & 0x7F) >> wm; + + // subnormal input + if (exponent == 0) { + // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + uint32_t renorm_shift = __clz(mantissa); +#elif defined(__SYCL_DEVICE_ONLY__) + uint32_t renorm_shift = sycl::clz(mantissa); +#elif defined(_MSC_VER) + unsigned long nonsign_bsr; + _BitScanReverse(&nonsign_bsr, (unsigned long)mantissa); + uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31; +#else + uint32_t renorm_shift = __builtin_clz(mantissa); +#endif + uint32_t sh = 1 + renorm_shift - (32 - wm); + mantissa <<= sh; + exponent += 1 - sh; + mantissa &= ((1 << wm) - 1); + } + + const uint32_t exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)); + exponent += exp_low_cutoff - 1; + mantissa <<= wmo - wm; + + uint32_t sign = x >> 7; + uint32_t retval = (sign << 31) | (exponent << 23) | mantissa; + return fp32_from_bits(retval); +} + +} // namespace standalone::c10::detail diff --git a/backends/aoti/slim/c10/util/Half-inl.h b/backends/aoti/slim/c10/util/Half-inl.h new file mode 100644 index 00000000000..05fa6349f81 --- /dev/null +++ b/backends/aoti/slim/c10/util/Half-inl.h @@ -0,0 +1,351 @@ +#pragma once + +#include +#include + +#include +#include + +#ifdef __CUDACC__ +#include +#endif + +#ifdef __HIPCC__ +#include +#endif + +#if defined(CL_SYCL_LANGUAGE_VERSION) +#include // for SYCL 1.2.1 +#elif defined(SYCL_LANGUAGE_VERSION) +#include // for SYCL 2020 +#endif + +// TODO: add contents in ATen/cpu/vec/vec_half.h +// #if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \ +// !defined(__APPLE__) +// #include +// #endif + +STANDALONE_CLANG_DIAGNOSTIC_PUSH() +#if STANDALONE_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +STANDALONE_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +namespace standalone::c10 { + +#if defined(__aarch64__) && !defined(__CUDACC__) +/// Constructors +inline Half::Half(float16_t value) : x(detail::fp16_to_bits(value)) {} +inline Half::operator float16_t() const { + return detail::fp16_from_bits(x); +} +#else + +inline STANDALONE_HOST_DEVICE Half::Half(float value) + : +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + x(__half_as_short(__float2half(value))) +#elif defined(__SYCL_DEVICE_ONLY__) + x(standalone::c10::bit_cast(sycl::half(value))) +#elif (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \ + !defined(__APPLE__) + x(at::vec::float2half_scalar(value)) +#else + x(detail::fp16_ieee_from_fp32_value(value)) +#endif +{ +} + +/// Implicit conversions + +inline STANDALONE_HOST_DEVICE Half::operator float() const { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + return __half2float(*reinterpret_cast(&x)); +#elif defined(__SYCL_DEVICE_ONLY__) + return float(standalone::c10::bit_cast(x)); +#elif (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \ + !defined(__APPLE__) + return at::vec::half2float_scalar(x); +#elif defined(__aarch64__) && !defined(__CUDACC__) + return detail::native_fp16_to_fp32_value(x); +#else + return detail::fp16_ieee_to_fp32_value(x); 
+#endif +} + +#endif /* !defined(__aarch64__) || defined(__CUDACC__) \ + */ + +#if defined(__CUDACC__) || defined(__HIPCC__) +inline STANDALONE_HOST_DEVICE Half::Half(const __half& value) { + x = *reinterpret_cast(&value); +} +inline STANDALONE_HOST_DEVICE Half::operator __half() const { + return *reinterpret_cast(&x); +} +#endif + +#ifdef SYCL_LANGUAGE_VERSION +inline STANDALONE_HOST_DEVICE Half::Half(const sycl::half& value) { + x = *reinterpret_cast(&value); +} +inline STANDALONE_HOST_DEVICE Half::operator sycl::half() const { + return *reinterpret_cast(&x); +} +#endif + +// CUDA intrinsics + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 350)) || \ + (defined(__clang__) && defined(__CUDA__)) +inline __device__ Half __ldg(const Half* ptr) { + return __ldg(reinterpret_cast(ptr)); +} +#endif + +/// Arithmetic + +inline STANDALONE_HOST_DEVICE Half operator+(const Half& a, const Half& b) { + return static_cast(a) + static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Half operator-(const Half& a, const Half& b) { + return static_cast(a) - static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Half operator*(const Half& a, const Half& b) { + return static_cast(a) * static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Half operator/(const Half& a, const Half& b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Half operator-(const Half& a) { +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ + defined(__HIP_DEVICE_COMPILE__) + return __hneg(a); +#elif defined(__SYCL_DEVICE_ONLY__) + return -standalone::c10::bit_cast(a); +#else + return -static_cast(a); +#endif +} + +inline STANDALONE_HOST_DEVICE Half& operator+=(Half& a, const Half& b) { + a = a + b; + return a; +} + +inline STANDALONE_HOST_DEVICE Half& operator-=(Half& a, const Half& b) { + a = a - b; + return a; +} + +inline STANDALONE_HOST_DEVICE Half& operator*=(Half& a, const Half& b) { + a = a * b; + return a; +} + +inline STANDALONE_HOST_DEVICE Half& operator/=(Half& a, const Half& b) { + a = a / b; + return a; +} + +/// Arithmetic with floats + +inline STANDALONE_HOST_DEVICE float operator+(Half a, float b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE float operator-(Half a, float b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE float operator*(Half a, float b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE float operator/(Half a, float b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline STANDALONE_HOST_DEVICE float operator+(float a, Half b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator-(float a, Half b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator*(float a, Half b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE float operator/(float a, Half b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE float& operator+=(float& a, const Half& b) { + return a += static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator-=(float& a, const Half& b) { + return a -= static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator*=(float& a, const Half& b) { + return a *= static_cast(b); +} +inline STANDALONE_HOST_DEVICE float& operator/=(float& a, const Half& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline STANDALONE_HOST_DEVICE double operator+(Half a, double b) { + return static_cast(a) + b; +} 
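A usage sketch of the overload set above, written as a hypothetical test file. The include path is an assumption based on this patch's directory layout rather than something taken from the diff, and the sketch presumes the Half headers build as intended.

#include <type_traits>

#include <executorch/backends/aoti/slim/c10/util/Half.h>  // assumed install location

using standalone::c10::Half;

// The overload set mirrors the usual arithmetic promotions:
static_assert(std::is_same_v<decltype(Half{} + Half{}), Half>);  // computed in float, stored as Half
static_assert(std::is_same_v<decltype(Half{} + 1.0f), float>);   // Half with float yields float
static_assert(std::is_same_v<decltype(Half{} + 1.0), double>);   // Half with double yields double
static_assert(std::is_same_v<decltype(Half{} + 1), Half>);       // Half with int stays Half

int main() {
  Half h(0.5f);
  float f = h * 2;  // the int operand is converted, and the Half result widens to float on assignment
  return f == 1.0f ? 0 : 1;
}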
+inline STANDALONE_HOST_DEVICE double operator-(Half a, double b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE double operator*(Half a, double b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE double operator/(Half a, double b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline STANDALONE_HOST_DEVICE double operator+(double a, Half b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator-(double a, Half b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator*(double a, Half b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE double operator/(double a, Half b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline STANDALONE_HOST_DEVICE Half operator+(Half a, int b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE Half operator-(Half a, int b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE Half operator*(Half a, int b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE Half operator/(Half a, int b) { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Half operator+(int a, Half b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE Half operator-(int a, Half b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE Half operator*(int a, Half b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE Half operator/(int a, Half b) { + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline STANDALONE_HOST_DEVICE Half operator+(Half a, int64_t b) { + return a + static_cast(b); +} +inline STANDALONE_HOST_DEVICE Half operator-(Half a, int64_t b) { + return a - static_cast(b); +} +inline STANDALONE_HOST_DEVICE Half operator*(Half a, int64_t b) { + return a * static_cast(b); +} +inline STANDALONE_HOST_DEVICE Half operator/(Half a, int64_t b) { + return a / static_cast(b); +} + +inline STANDALONE_HOST_DEVICE Half operator+(int64_t a, Half b) { + return static_cast(a) + b; +} +inline STANDALONE_HOST_DEVICE Half operator-(int64_t a, Half b) { + return static_cast(a) - b; +} +inline STANDALONE_HOST_DEVICE Half operator*(int64_t a, Half b) { + return static_cast(a) * b; +} +inline STANDALONE_HOST_DEVICE Half operator/(int64_t a, Half b) { + return static_cast(a) / b; +} + +/// NOTE: we do not define comparisons directly and instead rely on the implicit +/// conversion from standalone::c10::Half to float. 
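The note above is the whole comparison story: because Half (and the Float8 types) convert implicitly to float, the built-in float comparisons apply and no comparison operators need to be written. A minimal self-contained illustration of that idiom, not taken from this patch:

#include <cstdio>

struct TinyFloat {
  float v;
  constexpr operator float() const { return v; }  // the only conversion provided
};

int main() {
  TinyFloat a{1.5f}, b{2.0f};
  // ==, <, <= and friends all resolve to the built-in float comparisons.
  std::printf("%d %d %d\n", a < b, a == a, b <= a);  // 1 1 0
  // The same mechanism gives Half/Float8 the usual IEEE quirks for free,
  // e.g. a NaN-valued element compares unequal to itself.
}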
+ +} // namespace standalone::c10 + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = true; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = true; + static constexpr auto has_denorm = numeric_limits::has_denorm; + static constexpr auto has_denorm_loss = + numeric_limits::has_denorm_loss; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = true; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 11; + static constexpr int digits10 = 3; + static constexpr int max_digits10 = 5; + static constexpr int radix = 2; + static constexpr int min_exponent = -13; + static constexpr int min_exponent10 = -4; + static constexpr int max_exponent = 16; + static constexpr int max_exponent10 = 4; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = + numeric_limits::tinyness_before; + static constexpr standalone::c10::Half min() { + return standalone::c10::Half(0x0400, standalone::c10::Half::from_bits()); + } + static constexpr standalone::c10::Half lowest() { + return standalone::c10::Half(0xFBFF, standalone::c10::Half::from_bits()); + } + static constexpr standalone::c10::Half max() { + return standalone::c10::Half(0x7BFF, standalone::c10::Half::from_bits()); + } + static constexpr standalone::c10::Half epsilon() { + return standalone::c10::Half(0x1400, standalone::c10::Half::from_bits()); + } + static constexpr standalone::c10::Half round_error() { + return standalone::c10::Half(0x3800, standalone::c10::Half::from_bits()); + } + static constexpr standalone::c10::Half infinity() { + return standalone::c10::Half(0x7C00, standalone::c10::Half::from_bits()); + } + static constexpr standalone::c10::Half quiet_NaN() { + return standalone::c10::Half(0x7E00, standalone::c10::Half::from_bits()); + } + static constexpr standalone::c10::Half signaling_NaN() { + return standalone::c10::Half(0x7D00, standalone::c10::Half::from_bits()); + } + static constexpr standalone::c10::Half denorm_min() { + return standalone::c10::Half(0x0001, standalone::c10::Half::from_bits()); + } +}; + +} // namespace std + +STANDALONE_CLANG_DIAGNOSTIC_POP() diff --git a/backends/aoti/slim/c10/util/Half.h b/backends/aoti/slim/c10/util/Half.h new file mode 100644 index 00000000000..86f8d8683e0 --- /dev/null +++ b/backends/aoti/slim/c10/util/Half.h @@ -0,0 +1,424 @@ +#pragma once + +/// Defines the Half type (half-precision floating-point) including conversions +/// to standard C types and basic arithmetic operations. Note that arithmetic +/// operations are implemented by converting to floating point and +/// performing the operation in float32, instead of using CUDA half intrinsics. +/// Most uses of this type within ATen are memory bound, including the +/// element-wise kernels, and the half intrinsics aren't efficient on all GPUs. +/// If you are writing a compute bound kernel, you can use the CUDA half +/// intrinsics directly on the Half type from device code. 
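The from_bits() constants in the numeric_limits specialization above can be cross-checked against the IEEE binary16 layout (1 sign, 5 exponent, 10 mantissa bits, bias 15). A self-contained reference decoder, independent of this patch:

#include <cmath>
#include <cstdint>
#include <cstdio>

double fp16_decode(uint16_t h) {
  const int s = h >> 15;
  const int e = (h >> 10) & 0x1F;
  const int m = h & 0x3FF;
  double v;
  if (e == 0x1F) {
    v = (m == 0) ? INFINITY : NAN;
  } else if (e == 0) {
    v = std::ldexp(m / 1024.0, -14);           // subnormal
  } else {
    v = std::ldexp(1.0 + m / 1024.0, e - 15);  // normal
  }
  return s ? -v : v;
}

int main() {
  std::printf("min        0x0400 -> %g\n", fp16_decode(0x0400)); // 6.10352e-05 (2^-14)
  std::printf("max        0x7BFF -> %g\n", fp16_decode(0x7BFF)); // 65504
  std::printf("lowest     0xFBFF -> %g\n", fp16_decode(0xFBFF)); // -65504
  std::printf("epsilon    0x1400 -> %g\n", fp16_decode(0x1400)); // 0.000976562 (2^-10)
  std::printf("denorm_min 0x0001 -> %g\n", fp16_decode(0x0001)); // 5.96046e-08 (2^-24)
  std::printf("infinity   0x7C00 -> %g\n", fp16_decode(0x7C00)); // inf
}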
+ +#include +#include +#include +#include + +#if defined(__cplusplus) +#include +#elif !defined(__OPENCL_VERSION__) +#include +#endif + +#ifdef _MSC_VER +#include +#endif + +#include +#include +#include +#include +#include + +#ifdef __CUDACC__ +#include +#endif + +#ifdef __HIPCC__ +#include +#endif + +#if defined(CL_SYCL_LANGUAGE_VERSION) +#include // for SYCL 1.2.1 +#elif defined(SYCL_LANGUAGE_VERSION) +#include // for SYCL 2020 +#endif + +#if defined(__aarch64__) && !defined(__CUDACC__) +#include +#endif + +#if defined(__GNUC__) || defined(__clang__) +#if defined(__x86_64__) || defined(_M_X64) || defined(__i386) || \ + defined(_M_IX86) +#if defined(__F16C__) && \ + !(defined(__CUDA_ARCH__) || defined(__CUDACC__) || \ + defined(__HIP_DEVICE_COMPILE__)) +#define STANDALONE_X86_F16 1 +#include // import conversion ops from f16cintrin.h +#endif // defined(__F16C__) && !(defined(__CUDA_ARCH__) || defined(__CUDACC__) + // || defined(__HIP_DEVICE_COMPILE__)) +#endif // __x86_64__ || _M_X64 || __i386 || _M_IX86 +#endif // __GNUC__ || __clang__ + +namespace standalone::c10 { + +namespace detail { + +/* + * Convert a 16-bit floating-point number in IEEE half-precision format, in bit + * representation, to a 32-bit floating-point number in IEEE single-precision + * format, in bit representation. + * + * @note The implementation doesn't use any floating-point operations. + */ +inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) { + /* + * Extend the half-precision floating-point number to 32 bits and shift to the + * upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 + * - zero bits. + */ + const uint32_t w = (uint32_t)h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the bits 0-30 + * of the 32-bit word: + * + * +---+-----+------------+-------------------+ + * | 0 |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 30 27-31 17-26 0-16 + */ + const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF); + /* + * Renorm shift is the number of bits to shift mantissa left to make the + * half-precision number normalized. If the initial number is normalized, some + * of its high 6 bits (sign == 0 and 5-bit exponent) equals one. In this case + * renorm_shift == 0. If the number is denormalize, renorm_shift > 0. Note + * that if we shift denormalized nonsign by renorm_shift, the unit bit of + * mantissa will shift into exponent, turning the biased exponent into 1, and + * making mantissa normalized (i.e. without leading 1). + */ +#ifdef _MSC_VER + unsigned long nonsign_bsr; + _BitScanReverse(&nonsign_bsr, (unsigned long)nonsign); + uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31; +#else + uint32_t renorm_shift = __builtin_clz(nonsign); +#endif + renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0; + /* + * Iff half-precision number has exponent of 15, the addition overflows + * it into bit 31, and the subsequent shift turns the high 9 bits + * into 1. 
Thus inf_nan_mask == 0x7F800000 if the half-precision number + * had exponent of 15 (i.e. was NaN or infinity) 0x00000000 otherwise + */ + const int32_t inf_nan_mask = + ((int32_t)(nonsign + 0x04000000) >> 8) & INT32_C(0x7F800000); + /* + * Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31 + * into 1. Otherwise, bit 31 remains 0. The signed shift right by 31 + * broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask == + * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h) + * 0x00000000 otherwise + */ + const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31; + /* + * 1. Shift nonsign left by renorm_shift to normalize it (if the input + * was denormal) + * 2. Shift nonsign right by 3 so the exponent (5 bits originally) + * becomes an 8-bit field and 10-bit mantissa shifts into the 10 high + * bits of the 23-bit mantissa of IEEE single-precision number. + * 3. Add 0x70 to the exponent (starting at bit 23) to compensate the + * different in exponent bias (0x7F for single-precision number less 0xF + * for half-precision number). + * 4. Subtract renorm_shift from the exponent (starting at bit 23) to + * account for renormalization. As renorm_shift is less than 0x70, this + * can be combined with step 3. + * 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the + * input was NaN or infinity. + * 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent + * into zero if the input was zero. + * 7. Combine with the sign of the input number. + */ + return sign | + ((((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) | + inf_nan_mask) & + ~zero_mask); +} + +/* + * Convert a 16-bit floating-point number in IEEE half-precision format, in bit + * representation, to a 32-bit floating-point number in IEEE single-precision + * format. + * + * @note The implementation relies on IEEE-like (no assumption about rounding + * mode and no operations on denormals) floating-point operations and bitcasts + * between integer and floating-point variables. + */ +STANDALONE_HOST_DEVICE inline float fp16_ieee_to_fp32_value(uint16_t h) { +#ifdef STANDALONE_X86_F16 + return _cvtsh_ss(h); +#else + /* + * Extend the half-precision floating-point number to 32 bits and shift to the + * upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 + * - zero bits. 
+ */ + const uint32_t w = (uint32_t)h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the high bits + * of the 32-bit word: + * + * +-----+------------+---------------------+ + * |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000| + * +-----+------------+---------------------+ + * Bits 27-31 17-26 0-16 + */ + const uint32_t two_w = w + w; + + /* + * Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become + * mantissa and exponent of a single-precision floating-point number: + * + * S|Exponent | Mantissa + * +-+---+-----+------------+----------------+ + * |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000| + * +-+---+-----+------------+----------------+ + * Bits | 23-31 | 0-22 + * + * Next, there are some adjustments to the exponent: + * - The exponent needs to be corrected by the difference in exponent bias + * between single-precision and half-precision formats (0x7F - 0xF = 0x70) + * - Inf and NaN values in the inputs should become Inf and NaN values after + * conversion to the single-precision number. Therefore, if the biased + * exponent of the half-precision input was 0x1F (max possible value), the + * biased exponent of the single-precision output must be 0xFF (max possible + * value). We do this correction in two steps: + * - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset + * below) rather than by 0x70 suggested by the difference in the exponent bias + * (see above). + * - Then we multiply the single-precision result of exponent adjustment by + * 2**(-112) to reverse the effect of exponent adjustment by 0xE0 less the + * necessary exponent adjustment by 0x70 due to difference in exponent bias. + * The floating-point multiplication hardware would ensure than Inf and + * NaN would retain their value on at least partially IEEE754-compliant + * implementations. + * + * Note that the above operations do not handle denormal inputs (where biased + * exponent == 0). However, they also do not operate on denormal inputs, and + * do not produce denormal results. + */ + constexpr uint32_t exp_offset = UINT32_C(0xE0) << 23; + // const float exp_scale = 0x1.0p-112f; + constexpr uint32_t scale_bits = (uint32_t)15 << 23; + float exp_scale_val = 0; +#if defined(_MSC_VER) && defined(__clang__) + __builtin_memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val)); +#else + std::memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val)); +#endif + + const float exp_scale = exp_scale_val; + const float normalized_value = + fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; + + /* + * Convert denormalized half-precision inputs into single-precision results + * (always normalized). Zero inputs are also handled here. + * + * In a denormalized number the biased exponent is zero, and mantissa has + * on-zero bits. First, we shift mantissa into bits 0-9 of the 32-bit word. + * + * zeros | mantissa + * +---------------------------+------------+ + * |0000 0000 0000 0000 0000 00|MM MMMM MMMM| + * +---------------------------+------------+ + * Bits 10-31 0-9 + * + * Now, remember that denormalized half-precision numbers are represented as: + * FP16 = mantissa * 2**(-24). 
+ * The trick is to construct a normalized single-precision number with the + * same mantissa and thehalf-precision input and with an exponent which would + * scale the corresponding mantissa bits to 2**(-24). A normalized + * single-precision floating-point number is represented as: FP32 = (1 + + * mantissa * 2**(-23)) * 2**(exponent - 127) Therefore, when the biased + * exponent is 126, a unit change in the mantissa of the input denormalized + * half-precision number causes a change of the constructed single-precision + * number by 2**(-24), i.e. the same amount. + * + * The last step is to adjust the bias of the constructed single-precision + * number. When the input half-precision number is zero, the constructed + * single-precision number has the value of FP32 = 1 * 2**(126 - 127) = + * 2**(-1) = 0.5 Therefore, we need to subtract 0.5 from the constructed + * single-precision number to get the numerical equivalent of the input + * half-precision number. + */ + constexpr uint32_t magic_mask = UINT32_C(126) << 23; + constexpr float magic_bias = 0.5f; + const float denormalized_value = + fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + + /* + * - Choose either results of conversion of input as a normalized number, or + * as a denormalized number, depending on the input exponent. The variable + * two_w contains input exponent in bits 27-31, therefore if its smaller than + * 2**27, the input is either a denormal number, or zero. + * - Combine the result of conversion of exponent and mantissa with the sign + * of the input number. + */ + constexpr uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = sign | + (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) + : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); +#endif // STANDALONE_X86_F16 +} + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a + * 16-bit floating-point number in IEEE half-precision format, in bit + * representation. + * + * @note The implementation relies on IEEE-like (no assumption about rounding + * mode and no operations on denormals) floating-point operations and bitcasts + * between integer and floating-point variables. + */ +inline uint16_t fp16_ieee_from_fp32_value(float f) { +#ifdef STANDALONE_X86_F16 + return _cvtss_sh(f, _MM_FROUND_TO_NEAREST_INT); +#else + // const float scale_to_inf = 0x1.0p+112f; + // const float scale_to_zero = 0x1.0p-110f; + constexpr uint32_t scale_to_inf_bits = (uint32_t)239 << 23; + constexpr uint32_t scale_to_zero_bits = (uint32_t)17 << 23; + float scale_to_inf_val = 0, scale_to_zero_val = 0; + std::memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val)); + std::memcpy( + &scale_to_zero_val, &scale_to_zero_bits, sizeof(scale_to_zero_val)); + const float scale_to_inf = scale_to_inf_val; + const float scale_to_zero = scale_to_zero_val; + +#if defined(_MSC_VER) && _MSC_VER == 1916 + float base = ((signbit(f) != 0 ? 
-f : f) * scale_to_inf) * scale_to_zero; +#else + float base = (fabsf(f) * scale_to_inf) * scale_to_zero; +#endif + + const uint32_t w = fp32_to_bits(f); + const uint32_t shl1_w = w + w; + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); + if (bias < UINT32_C(0x71000000)) { + bias = UINT32_C(0x71000000); + } + + base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; + const uint32_t bits = fp32_to_bits(base); + const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); + const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); + const uint32_t nonsign = exp_bits + mantissa_bits; + return static_cast( + (sign >> 16) | + (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign)); +#endif // STANDALONE_X86_F16 +} + +#ifdef STANDALONE_X86_F16 +#undef STANDALONE_X86_F16 +#endif // STANDALONE_X86_F16 + +#if defined(__aarch64__) && !defined(__CUDACC__) +inline float16_t fp16_from_bits(uint16_t h) { + return standalone::c10::bit_cast(h); +} + +inline uint16_t fp16_to_bits(float16_t f) { + return standalone::c10::bit_cast(f); +} + +// According to https://godbolt.org/z/frExdbsWG it would translate to single +// fcvt s0, h0 +inline float native_fp16_to_fp32_value(uint16_t h) { + return static_cast(fp16_from_bits(h)); +} + +inline uint16_t native_fp16_from_fp32_value(float f) { + return fp16_to_bits(static_cast(f)); +} +#endif + +} // namespace detail + +struct alignas(2) Half { + unsigned short x; + + struct from_bits_t {}; + STANDALONE_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + // HIP wants __host__ __device__ tag, CUDA does not +#if defined(USE_ROCM) + STANDALONE_HOST_DEVICE Half() = default; +#else + Half() = default; +#endif + + constexpr STANDALONE_HOST_DEVICE Half(unsigned short bits, from_bits_t) + : x(bits) {} +#if defined(__aarch64__) && !defined(__CUDACC__) + inline Half(float16_t value); + inline operator float16_t() const; +#else + inline STANDALONE_HOST_DEVICE Half(float value); + inline STANDALONE_HOST_DEVICE operator float() const; +#endif + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline STANDALONE_HOST_DEVICE Half(const __half& value); + inline STANDALONE_HOST_DEVICE operator __half() const; +#endif +#ifdef SYCL_LANGUAGE_VERSION + inline STANDALONE_HOST_DEVICE Half(const sycl::half& value); + inline STANDALONE_HOST_DEVICE operator sycl::half() const; +#endif +}; + +inline std::ostream& operator<<(std::ostream& out, const Half& value) { + out << (float)value; + return out; +} + +} // namespace standalone::c10 + +#include // IWYU pragma: keep diff --git a/backends/aoti/slim/c10/util/StringUtil.h b/backends/aoti/slim/c10/util/StringUtil.h new file mode 100644 index 00000000000..ff7c591e734 --- /dev/null +++ b/backends/aoti/slim/c10/util/StringUtil.h @@ -0,0 +1,16 @@ +#pragma once + +#include +#include + +namespace standalone::c10 { +template +inline std::string Join(const std::string& delimiter, const Container& v) { + std::stringstream s; + int cnt = static_cast(v.size()) - 1; + for (auto i = v.begin(); i != v.end(); ++i, --cnt) { + s << (*i) << (cnt ? 
delimiter : ""); + } + return std::move(s).str(); +} +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/util/TypeCast.h b/backends/aoti/slim/c10/util/TypeCast.h new file mode 100644 index 00000000000..cfaaaebec95 --- /dev/null +++ b/backends/aoti/slim/c10/util/TypeCast.h @@ -0,0 +1,236 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +STANDALONE_CLANG_DIAGNOSTIC_PUSH() +#if STANDALONE_CLANG_HAS_WARNING("-Wimplicit-float-conversion") +STANDALONE_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") +#endif +#if STANDALONE_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +STANDALONE_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +namespace standalone::c10 { + +template +struct needs_real { + constexpr static bool value = + (is_complex::value && !is_complex::value); +}; + +template +struct maybe_real { + STANDALONE_HOST_DEVICE static inline src_t apply(src_t src) { + return src; + } +}; + +template +struct maybe_real { + STANDALONE_HOST_DEVICE static inline decltype(auto) apply(src_t src) { + return src.real(); + } +}; + +template +struct maybe_bool { + STANDALONE_HOST_DEVICE static inline src_t apply(src_t src) { + return src; + } +}; + +template +struct maybe_bool { + STANDALONE_HOST_DEVICE static inline decltype(auto) apply(src_t src) { + // Don't use bool operator so as to to also compile for ComplexHalf. + return src.real() || src.imag(); + } +}; + +// Note: deliberately ignores undefined behavior, consistent with NumPy. +// PyTorch's type conversions can cause a variety of undefined behavior, +// including float to integral overflow and signed to unsigned integer overflow. +// Some of this undefined behavior is addressed below. +template +struct static_cast_with_inter_type { + STANDALONE_HOST_DEVICE __ubsan_ignore_undefined__ static inline dest_t apply( + src_t src) { + constexpr bool real = needs_real::value; + auto r = maybe_real::apply(src); + return static_cast(r); + } +}; + +// Partial template specialization for casting to bool. +// Need to handle complex types separately, as we don't +// simply want to cast the real part to bool. +template +struct static_cast_with_inter_type { + STANDALONE_HOST_DEVICE static inline bool apply(src_t src) { + constexpr bool complex = needs_real::value; + return static_cast(maybe_bool::apply(src)); + } +}; + +// Partial template instantiation for casting to uint8. +// Note: Converting from negative float values to unsigned integer types is +// undefined behavior in C++, and current CPU and GPU compilers exhibit +// divergent behavior. Casting from negative float values to signed +// integer types and then to unsigned integer types is not undefined, +// however, so this cast improves the consistency of type conversions +// to uint8 across compilers. +// Further note: Type conversions across compilers still have other undefined +// and divergent behavior. 
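The note above is the motivation for the uint8_t specialization that follows: a direct negative-float-to-unsigned cast is undefined, while going through a signed integer first is defined, and the subsequent signed-to-unsigned conversion wraps modulo 2^8. A minimal self-contained sketch of that two-step pattern; the helper name is illustrative, not from the patch:

#include <cstdint>
#include <cstdio>

static uint8_t to_uint8_via_signed(float f) {
  // float -> int64_t is well defined for in-range values; int64_t -> uint8_t is
  // defined as reduction modulo 2^8, so the result is consistent across compilers.
  return static_cast<uint8_t>(static_cast<int64_t>(f));
}

int main() {
  std::printf("%d\n", to_uint8_via_signed(-1.0f));  // 255
  std::printf("%d\n", to_uint8_via_signed(5.7f));   // 5 (truncation toward zero)
  // static_cast<uint8_t>(-1.0f) done directly would be undefined behavior, which
  // is the divergence the specialization below works around.
}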
+template +struct static_cast_with_inter_type { + STANDALONE_HOST_DEVICE __ubsan_ignore_undefined__ static inline uint8_t apply( + src_t src) { + constexpr bool real = needs_real::value; + return static_cast( + static_cast(maybe_real::apply(src))); + } +}; + +template <> +struct static_cast_with_inter_type< + standalone::c10::complex, + standalone::c10::BFloat16> { + STANDALONE_HOST_DEVICE + __ubsan_ignore_undefined__ static inline standalone::c10::complex< + standalone::c10::Half> + apply(standalone::c10::BFloat16 src) { + return static_cast>( + standalone::c10::complex{src}); + } +}; + +template <> +struct static_cast_with_inter_type< + standalone::c10::complex, + standalone::c10::Float8_e5m2> { + STANDALONE_HOST_DEVICE + __ubsan_ignore_undefined__ static inline standalone::c10::complex< + standalone::c10::Half> + apply(standalone::c10::Float8_e5m2 src) { + return static_cast>( + standalone::c10::complex{src}); + } +}; + +template <> +struct static_cast_with_inter_type< + standalone::c10::complex, + standalone::c10::Float8_e5m2fnuz> { + STANDALONE_HOST_DEVICE + __ubsan_ignore_undefined__ static inline standalone::c10::complex< + standalone::c10::Half> + apply(standalone::c10::Float8_e5m2fnuz src) { + return static_cast>( + standalone::c10::complex{src}); + } +}; + +template <> +struct static_cast_with_inter_type< + standalone::c10::complex, + standalone::c10::Float8_e4m3fn> { + STANDALONE_HOST_DEVICE + __ubsan_ignore_undefined__ static inline standalone::c10::complex< + standalone::c10::Half> + apply(standalone::c10::Float8_e4m3fn src) { + return static_cast>( + standalone::c10::complex{src}); + } +}; + +template <> +struct static_cast_with_inter_type< + standalone::c10::complex, + standalone::c10::Float8_e4m3fnuz> { + STANDALONE_HOST_DEVICE + __ubsan_ignore_undefined__ static inline standalone::c10::complex< + standalone::c10::Half> + apply(standalone::c10::Float8_e4m3fnuz src) { + return static_cast>( + standalone::c10::complex{src}); + } +}; + +// TODO(#146647): Can we make all these template specialization happen +// based off our apply macros? 
+template <> +struct static_cast_with_inter_type< + standalone::c10::complex, + standalone::c10::Float8_e8m0fnu> { + STANDALONE_HOST_DEVICE + __ubsan_ignore_undefined__ static inline standalone::c10::complex< + standalone::c10::Half> + apply(standalone::c10::Float8_e8m0fnu src) { + return static_cast>( + standalone::c10::complex{src}); + } +}; + +template <> +struct static_cast_with_inter_type< + standalone::c10::complex, + standalone::c10::Half> { + STANDALONE_HOST_DEVICE + __ubsan_ignore_undefined__ static inline standalone::c10::complex< + standalone::c10::Half> + apply(standalone::c10::Half src) { + return static_cast>( + standalone::c10::complex{src}); + } +}; + +template <> +struct static_cast_with_inter_type< + standalone::c10::complex, + standalone::c10::complex> { + STANDALONE_HOST_DEVICE + __ubsan_ignore_undefined__ static inline standalone::c10::complex< + standalone::c10::Half> + apply(standalone::c10::complex src) { + return static_cast>( + static_cast>(src)); + } +}; + +template +STANDALONE_HOST_DEVICE To convert(From f) { + return static_cast_with_inter_type::apply(f); +} + +// Define separately to avoid being inlined and prevent code-size bloat +[[noreturn]] inline void report_overflow(const char* name) { + std::ostringstream oss; + oss << "value cannot be converted to type " << name << " without overflow"; + throw std::runtime_error(oss.str()); // rather than domain_error (issue 33562) +} + +template +To checked_convert(From f, const char* name) { + // Converting to bool can't overflow so we exclude this case from checking. + if (!std::is_same_v && + overflows(f, /* strict_unsigned */ !std::is_signed_v)) { + report_overflow(name); + } + return convert(f); +} + +} // namespace standalone::c10 + +STANDALONE_CLANG_DIAGNOSTIC_POP() + +// Trigger tests for D25440771. TODO: Remove this line any time you want. diff --git a/backends/aoti/slim/c10/util/TypeSafeSignMath.h b/backends/aoti/slim/c10/util/TypeSafeSignMath.h new file mode 100644 index 00000000000..276b1cee7d0 --- /dev/null +++ b/backends/aoti/slim/c10/util/TypeSafeSignMath.h @@ -0,0 +1,141 @@ +#pragma once + +#include + +#include +#include + +STANDALONE_CLANG_DIAGNOSTIC_PUSH() +#if STANDALONE_CLANG_HAS_WARNING("-Wstring-conversion") +STANDALONE_CLANG_DIAGNOSTIC_IGNORE("-Wstring-conversion") +#endif +#if STANDALONE_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +STANDALONE_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +namespace standalone::c10 { + +/// Returns false since we cannot have x < 0 if x is unsigned. +template +inline constexpr bool is_negative( + const T& /*x*/, + std::true_type /*is_unsigned*/) { + return false; +} + +/// Returns true if a signed variable x < 0 +template +inline constexpr bool is_negative(const T& x, std::false_type /*is_unsigned*/) { + return x < T(0); +} + +/// Returns true if x < 0 +/// NOTE: Will fail on an unsigned custom type +/// For the most part it's possible to fix this if +/// the custom type has a constexpr constructor. 
+/// However, notably, standalone::c10::Half does not :-(
+template <typename T>
+inline constexpr bool is_negative(const T& x) {
+  return is_negative(x, std::is_unsigned<T>());
+}
+
+/// Returns the sign of an unsigned variable x as 0, 1
+template <typename T>
+inline constexpr int signum(const T& x, std::true_type /*is_unsigned*/) {
+  return T(0) < x;
+}
+
+/// Returns the sign of a signed variable x as -1, 0, 1
+template <typename T>
+inline constexpr int signum(const T& x, std::false_type /*is_unsigned*/) {
+  return (T(0) < x) - (x < T(0));
+}
+
+/// Returns the sign of x as -1, 0, 1
+/// NOTE: Will fail on an unsigned custom type
+/// For the most part it's possible to fix this if
+/// the custom type has a constexpr constructor.
+/// However, notably, standalone::c10::Half does not :-(
+template <typename T>
+inline constexpr int signum(const T& x) {
+  return signum(x, std::is_unsigned<T>());
+}
+
+/// Returns true if a and b have different signs, i.e. exactly one of them is
+/// negative
+template <typename T, typename U>
+inline constexpr bool signs_differ(const T& a, const U& b) {
+  return is_negative(a) != is_negative(b);
+}
+
+// Suppress sign compare warning when compiling with GCC,
+// as the latter does not account for the short-circuit rule before
+// raising the warning, see https://godbolt.org/z/Tr3Msnz99
+#ifdef __GNUC__
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wsign-compare"
+#endif
+
+/// Returns true if x is greater than the greatest value of the type Limit
+template <typename Limit, typename T>
+inline constexpr bool greater_than_max(const T& x) {
+  constexpr bool can_overflow =
+      std::numeric_limits<T>::digits > std::numeric_limits<Limit>::digits;
+  return can_overflow && x > std::numeric_limits<Limit>::max();
+}
+
+#ifdef __GNUC__
+#pragma GCC diagnostic pop
+#endif
+
+/// Returns true if x < lowest(Limit). Standard comparison
+template <typename Limit, typename T>
+inline constexpr bool less_than_lowest(
+    const T& x,
+    std::false_type /*limit_is_unsigned*/,
+    std::false_type /*x_is_unsigned*/) {
+  return x < std::numeric_limits<Limit>::lowest();
+}
+
+/// Returns false since the Limit type is signed and therefore includes
+/// negative values, but x cannot be negative because it is unsigned
+template <typename Limit, typename T>
+inline constexpr bool less_than_lowest(
+    const T& /*x*/,
+    std::false_type /*limit_is_unsigned*/,
+    std::true_type /*x_is_unsigned*/) {
+  return false;
+}
+
+/// Returns true if x < 0, where 0 is constructed from T.
+/// Limit is unsigned, so its lowest value is zero
+template <typename Limit, typename T>
+inline constexpr bool less_than_lowest(
+    const T& x,
+    std::true_type /*limit_is_unsigned*/,
+    std::false_type /*x_is_unsigned*/) {
+  return x < T(0);
+}
+
+/// Returns false since both types are unsigned
+template <typename Limit, typename T>
+inline constexpr bool less_than_lowest(
+    const T& /*x*/,
+    std::true_type /*limit_is_unsigned*/,
+    std::true_type /*x_is_unsigned*/) {
+  return false;
+}
+
+/// Returns true if x is less than the lowest value of the type Limit
+/// NOTE: Will fail on an unsigned custom type
+/// For the most part it's possible to fix this if
+/// the custom type has a constexpr constructor.
+/// However, notably, standalone::c10::Half does not :-(
+template <typename Limit, typename T>
+inline constexpr bool less_than_lowest(const T& x) {
+  return less_than_lowest<Limit>(
+      x, std::is_unsigned<Limit>(), std::is_unsigned<T>());
+}
+
+} // namespace standalone::c10
+
+STANDALONE_CLANG_DIAGNOSTIC_POP()
diff --git a/backends/aoti/slim/c10/util/accumulate.h b/backends/aoti/slim/c10/util/accumulate.h
new file mode 100644
index 00000000000..4972dd9826a
--- /dev/null
+++ b/backends/aoti/slim/c10/util/accumulate.h
@@ -0,0 +1,125 @@
+// Copyright 2004-present Facebook. All Rights Reserved.
+ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +namespace standalone::c10 { + +/// Sum of a list of integers; accumulates into the int64_t datatype +template < + typename C, + std::enable_if_t, int> = 0> +inline int64_t sum_integers(const C& container) { + // std::accumulate infers return type from `init` type, so if the `init` type + // is not large enough to hold the result, computation can overflow. We use + // `int64_t` here to avoid this. + return std::accumulate( + container.begin(), container.end(), static_cast(0)); +} + +/// Sum of integer elements referred to by iterators; accumulates into the +/// int64_t datatype +template < + typename Iter, + std::enable_if_t< + std::is_integral_v::value_type>, + int> = 0> +inline int64_t sum_integers(Iter begin, Iter end) { + // std::accumulate infers return type from `init` type, so if the `init` type + // is not large enough to hold the result, computation can overflow. We use + // `int64_t` here to avoid this. + return std::accumulate(begin, end, static_cast(0)); +} + +/// Product of a list of integers; accumulates into the int64_t datatype +template < + typename C, + std::enable_if_t, int> = 0> +inline int64_t multiply_integers(const C& container) { + // std::accumulate infers return type from `init` type, so if the `init` type + // is not large enough to hold the result, computation can overflow. We use + // `int64_t` here to avoid this. + return std::accumulate( + container.begin(), + container.end(), + static_cast(1), + std::multiplies<>()); +} + +/// Product of integer elements referred to by iterators; accumulates into the +/// int64_t datatype +template < + typename Iter, + std::enable_if_t< + std::is_integral_v::value_type>, + int> = 0> +inline int64_t multiply_integers(Iter begin, Iter end) { + // std::accumulate infers return type from `init` type, so if the `init` type + // is not large enough to hold the result, computation can overflow. We use + // `int64_t` here to avoid this. 
+ return std::accumulate( + begin, end, static_cast(1), std::multiplies<>()); +} + +/// Return product of all dimensions starting from k +/// Returns 1 if k>=dims.size() +template < + typename C, + std::enable_if_t, int> = 0> +inline int64_t numelements_from_dim(const int k, const C& dims) { + STANDALONE_INTERNAL_ASSERT_DEBUG_ONLY(k >= 0); + + if (k > static_cast(dims.size())) { + return 1; + } else { + auto cbegin = dims.cbegin(); + std::advance(cbegin, k); + return multiply_integers(cbegin, dims.cend()); + } +} + +/// Product of all dims up to k (not including dims[k]) +/// Throws an error if k>dims.size() +template < + typename C, + std::enable_if_t, int> = 0> +inline int64_t numelements_to_dim(const int k, const C& dims) { + STANDALONE_INTERNAL_ASSERT(0 <= k); + STANDALONE_INTERNAL_ASSERT((unsigned)k <= dims.size()); + + auto cend = dims.cbegin(); + std::advance(cend, k); + return multiply_integers(dims.cbegin(), cend); +} + +/// Product of all dims between k and l (including dims[k] and excluding +/// dims[l]) k and l may be supplied in either order +template < + typename C, + std::enable_if_t, int> = 0> +inline int64_t numelements_between_dim(int k, int l, const C& dims) { + STANDALONE_INTERNAL_ASSERT(0 <= k); + STANDALONE_INTERNAL_ASSERT(0 <= l); + + if (k > l) { + std::swap(k, l); + } + + STANDALONE_INTERNAL_ASSERT((unsigned)l < dims.size()); + + auto cbegin = dims.cbegin(); + auto cend = dims.cbegin(); + std::advance(cbegin, k); + std::advance(cend, l); + return multiply_integers(cbegin, cend); +} + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/util/bit_cast.h b/backends/aoti/slim/c10/util/bit_cast.h new file mode 100644 index 00000000000..765ec641486 --- /dev/null +++ b/backends/aoti/slim/c10/util/bit_cast.h @@ -0,0 +1,44 @@ +#pragma once + +#include +#include + +#if __has_include() && (defined(__cpp_lib_bit_cast) && __cpp_lib_bit_cast >= 201806L) +#include +#define STANDALONE_HAVE_STD_BIT_CAST 1 +#else +#define STANDALONE_HAVE_STD_BIT_CAST 0 +#endif // __has_include() && (__cplusplus >= 202002L || + // (defined(__cpp_lib_bit_cast) && __cpp_lib_bit_cast >= 201806L)) + +namespace standalone::c10 { + +#if STANDALONE_HAVE_STD_BIT_CAST +using std::bit_cast; +#else +// Implementations of std::bit_cast() from C++ 20. +// +// This is a less sketchy version of reinterpret_cast. +// +// See https://en.cppreference.com/w/cpp/numeric/bit_cast for more +// information as well as the source of our implementations. +template +std::enable_if_t< + sizeof(To) == sizeof(From) && std::is_trivially_copyable_v && + std::is_trivially_copyable_v, + To> +// constexpr support needs compiler magic +bit_cast(const From& src) noexcept { + static_assert( + std::is_trivially_constructible_v, + "This implementation additionally requires " + "destination type to be trivially constructible"); + + To dst; + std::memcpy(&dst, &src, sizeof(To)); + return dst; +} +#endif // STANDALONE_HAVE_STD_BIT_CAST +#undef STANDALONE_HAVE_STD_BIT_CAST + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/util/bits.h b/backends/aoti/slim/c10/util/bits.h new file mode 100644 index 00000000000..2d365463a01 --- /dev/null +++ b/backends/aoti/slim/c10/util/bits.h @@ -0,0 +1,61 @@ +#pragma once +#include + +#include + +namespace standalone::c10 { + +/** + * bits1x8 is an uninterpreted dtype of a tensor with 1 bit (packed to byte + * boundary), without any semantics defined. 
+ */ +struct alignas(1) bits1x8 { + using underlying = uint8_t; + uint8_t val_; + bits1x8() = default; + STANDALONE_HOST_DEVICE explicit bits1x8(uint8_t val) : val_(val) {} +}; + +/** + * bits2x4 is an uninterpreted dtype of a tensor with 2 bits (packed to byte + * boundary), without any semantics defined. + */ +struct alignas(1) bits2x4 { + using underlying = uint8_t; + uint8_t val_; + bits2x4() = default; + STANDALONE_HOST_DEVICE explicit bits2x4(uint8_t val) : val_(val) {} +}; + +/** + * bits4x2 is an uninterpreted dtype of a tensor with 4 bits (packed to byte + * boundary), without any semantics defined. + */ +struct alignas(1) bits4x2 { + using underlying = uint8_t; + uint8_t val_; + bits4x2() = default; + STANDALONE_HOST_DEVICE explicit bits4x2(uint8_t val) : val_(val) {} +}; + +/** + * bits8 is an uninterpreted dtype of a tensor with 8 bits, without any + * semantics defined. + */ +struct alignas(1) bits8 { + uint8_t val_; + bits8() = default; + STANDALONE_HOST_DEVICE explicit bits8(uint8_t val) : val_(val) {} +}; + +/** + * bits16 is an uninterpreted dtype of a tensor with 16 bits, without any + * semantics defined. + */ +struct alignas(2) bits16 { + uint16_t val_; + bits16() = default; + STANDALONE_HOST_DEVICE explicit bits16(uint16_t val) : val_(val) {} +}; + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/util/complex.h b/backends/aoti/slim/c10/util/complex.h new file mode 100644 index 00000000000..988e446b3e4 --- /dev/null +++ b/backends/aoti/slim/c10/util/complex.h @@ -0,0 +1,690 @@ +#pragma once + +#include + +#include +#include + +#if defined(__CUDACC__) || defined(__HIPCC__) +#include +#endif + +STANDALONE_CLANG_DIAGNOSTIC_PUSH() +#if STANDALONE_CLANG_HAS_WARNING("-Wimplicit-float-conversion") +STANDALONE_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") +#endif +#if STANDALONE_CLANG_HAS_WARNING("-Wfloat-conversion") +STANDALONE_CLANG_DIAGNOSTIC_IGNORE("-Wfloat-conversion") +#endif + +namespace standalone::c10 { + +// standalone::c10::complex is an implementation of complex numbers that aims +// to work on all devices supported by PyTorch +// +// Most of the APIs duplicates std::complex +// Reference: https://en.cppreference.com/w/cpp/numeric/complex +// +// [NOTE: Complex Operator Unification] +// Operators currently use a mix of std::complex, thrust::complex, and +// standalone::c10::complex internally. The end state is that all operators +// will use standalone::c10::complex internally. Until then, there may be +// some hacks to support all variants. +// +// +// [Note on Constructors] +// +// The APIs of constructors are mostly copied from C++ standard: +// https://en.cppreference.com/w/cpp/numeric/complex/complex +// +// Since C++14, all constructors are constexpr in std::complex +// +// There are three types of constructors: +// - initializing from real and imag: +// `constexpr complex( const T& re = T(), const T& im = T() );` +// - implicitly-declared copy constructor +// - converting constructors +// +// Converting constructors: +// - std::complex defines converting constructor between float/double/long +// double, +// while we define converting constructor between float/double. +// - For these converting constructors, upcasting is implicit, downcasting is +// explicit. +// - We also define explicit casting from std::complex/thrust::complex +// - Note that the conversion from thrust is not constexpr, because +// thrust does not define them as constexpr ???? 
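+//
+// A minimal usage sketch of the constructors described above (illustrative
+// only, not part of this header):
+//
+//   standalone::c10::complex<float> a{1.0f, 2.0f}; // from real and imag
+//   standalone::c10::complex<double> b = a;        // upcasting is implicit
+//   auto c = standalone::c10::complex<float>(b);   // downcasting is explicit
+//   std::complex<float> s{3.0f, 4.0f};
+//   auto d = standalone::c10::complex<float>(s);   // explicit from std::complex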
+//
+//
+// [Operator =]
+//
+// The APIs of operator = are mostly copied from the C++ standard:
+// https://en.cppreference.com/w/cpp/numeric/complex/operator%3D
+//
+// Since C++20, all operator= are constexpr. Although we are not building with
+// C++20, we also obey this behavior.
+//
+// There are three types of assign operator:
+// - Assign a real value from the same scalar type
+//   - In std, this is templated as complex& operator=(const T& x)
+//     with specialization `complex& operator=(T x)` for float/double/long
+//     double. Since we only support float and double, we will use `complex&
+//     operator=(T x)`
+// - Copy assignment operator and converting assignment operator
+//   - There is no specialization of converting assignment operators, so which
+//     type is convertible depends solely on whether the scalar type is
+//     convertible
+//
+// In addition to the standard assignment, we also provide assignment operators
+// that take std::complex and thrust::complex
+//
+//
+// [Casting operators]
+//
+// std::complex does not have casting operators. We define casting operators
+// that cast to std::complex and thrust::complex
+//
+//
+// [Operator ""]
+//
+// std::complex has custom literals `i`, `if` and `il` defined in namespace
+// `std::literals::complex_literals`. We define our own custom literals in the
+// namespace `standalone::c10::complex_literals`. Our custom literals do not
+// follow the same behavior as in std::complex; instead, we define _if and _id
+// to construct float/double complex literals.
+//
+//
+// [real() and imag()]
+//
+// In C++20, there are two overloads of these functions: one returns the
+// real/imag value, the other sets it, and both are constexpr. We follow
+// this design.
+//
+//
+// [Operator +=,-=,*=,/=]
+//
+// Since C++20, these operators become constexpr. In our implementation, they
+// are also constexpr.
+//
+// There are two types of such operators: operating with a real number, or
+// operating with another complex number. For those operating with a real
+// number, the generic template form has argument type `const T &`, while the
+// overload for float/double/long double has `T`. We will follow the same type
+// as float/double/long double in std.
+//
+// [Unary operator +-]
+//
+// Since C++20, they are constexpr.
We also make them expr +// +// [Binary operators +-*/] +// +// Each operator has three versions (taking + as example): +// - complex + complex +// - complex + real +// - real + complex +// +// [Operator ==, !=] +// +// Each operator has three versions (taking == as example): +// - complex == complex +// - complex == real +// - real == complex +// +// Some of them are removed on C++20, but we decide to keep them +// +// [Operator <<, >>] +// +// These are implemented by casting to std::complex +// +// +// +// TODO(@zasdfgbnm): standalone::c10::complex is not +// currently supported, because: +// - lots of members and functions of standalone::c10::Half are not constexpr +// - thrust::complex only support float and double + +template +struct alignas(sizeof(T) * 2) complex { + using value_type = T; + + T real_ = T(0); + T imag_ = T(0); + + constexpr complex() = default; + STANDALONE_HOST_DEVICE constexpr complex(const T& re, const T& im = T()) + : real_(re), imag_(im) {} + template + explicit constexpr complex(const std::complex& other) + : complex(other.real(), other.imag()) {} +#if defined(__CUDACC__) || defined(__HIPCC__) + template + explicit STANDALONE_HOST_DEVICE complex(const thrust::complex& other) + : real_(other.real()), imag_(other.imag()) {} +// NOTE can not be implemented as follow due to ROCm bug: +// explicit STANDALONE_HOST_DEVICE complex(const thrust::complex &other): +// complex(other.real(), other.imag()) {} +#endif + + // Use SFINAE to specialize casting constructor for + // standalone::c10::complex and standalone::c10::complex + template + STANDALONE_HOST_DEVICE explicit constexpr complex( + const std::enable_if_t, complex>& other) + : real_(other.real_), imag_(other.imag_) {} + template + STANDALONE_HOST_DEVICE constexpr complex( + const std::enable_if_t, complex>& other) + : real_(other.real_), imag_(other.imag_) {} + + constexpr complex& operator=(T re) { + real_ = re; + imag_ = 0; + return *this; + } + + constexpr complex& operator+=(T re) { + real_ += re; + return *this; + } + + constexpr complex& operator-=(T re) { + real_ -= re; + return *this; + } + + constexpr complex& operator*=(T re) { + real_ *= re; + imag_ *= re; + return *this; + } + + constexpr complex& operator/=(T re) { + real_ /= re; + imag_ /= re; + return *this; + } + + template + constexpr complex& operator=(const complex& rhs) { + real_ = rhs.real(); + imag_ = rhs.imag(); + return *this; + } + + template + constexpr complex& operator+=(const complex& rhs) { + real_ += rhs.real(); + imag_ += rhs.imag(); + return *this; + } + + template + constexpr complex& operator-=(const complex& rhs) { + real_ -= rhs.real(); + imag_ -= rhs.imag(); + return *this; + } + + template + constexpr complex& operator*=(const complex& rhs) { + // (a + bi) * (c + di) = (a*c - b*d) + (a * d + b * c) i + T a = real_; + T b = imag_; + U c = rhs.real(); + U d = rhs.imag(); + real_ = a * c - b * d; + imag_ = a * d + b * c; + return *this; + } + +#ifdef __APPLE__ +#define FORCE_INLINE_APPLE __attribute__((always_inline)) +#else +#define FORCE_INLINE_APPLE +#endif + template + constexpr FORCE_INLINE_APPLE complex& operator/=(const complex& rhs) + __ubsan_ignore_float_divide_by_zero__ { + // (a + bi) / (c + di) = (ac + bd)/(c^2 + d^2) + (bc - ad)/(c^2 + d^2) i + // the calculation below follows numpy's complex division + T a = real_; + T b = imag_; + U c = rhs.real(); + U d = rhs.imag(); + +#if defined(__GNUC__) && !defined(__clang__) + // std::abs is already constexpr by gcc + auto abs_c = std::abs(c); + auto abs_d = std::abs(d); 
+#else + auto abs_c = c < 0 ? -c : c; + auto abs_d = d < 0 ? -d : d; +#endif + + if (abs_c >= abs_d) { + if (abs_c == U(0) && abs_d == U(0)) { + /* divide by zeros should yield a complex inf or nan */ + real_ = a / abs_c; + imag_ = b / abs_d; + } else { + auto rat = d / c; + auto scl = U(1.0) / (c + d * rat); + real_ = (a + b * rat) * scl; + imag_ = (b - a * rat) * scl; + } + } else { + auto rat = c / d; + auto scl = U(1.0) / (d + c * rat); + real_ = (a * rat + b) * scl; + imag_ = (b * rat - a) * scl; + } + return *this; + } +#undef FORCE_INLINE_APPLE + + template + constexpr complex& operator=(const std::complex& rhs) { + real_ = rhs.real(); + imag_ = rhs.imag(); + return *this; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + template + STANDALONE_HOST_DEVICE complex& operator=(const thrust::complex& rhs) { + real_ = rhs.real(); + imag_ = rhs.imag(); + return *this; + } +#endif + + template + explicit constexpr operator std::complex() const { + return std::complex(std::complex(real(), imag())); + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + template + STANDALONE_HOST_DEVICE explicit operator thrust::complex() const { + return static_cast>(thrust::complex(real(), imag())); + } +#endif + + // consistent with NumPy behavior + explicit constexpr operator bool() const { + return real() || imag(); + } + + STANDALONE_HOST_DEVICE constexpr T real() const { + return real_; + } + constexpr void real(T value) { + real_ = value; + } + STANDALONE_HOST_DEVICE constexpr T imag() const { + return imag_; + } + constexpr void imag(T value) { + imag_ = value; + } +}; + +namespace complex_literals { + +constexpr complex operator""_if(long double imag) { + return complex(0.0f, static_cast(imag)); +} + +constexpr complex operator""_id(long double imag) { + return complex(0.0, static_cast(imag)); +} + +constexpr complex operator""_if(unsigned long long imag) { + return complex(0.0f, static_cast(imag)); +} + +constexpr complex operator""_id(unsigned long long imag) { + return complex(0.0, static_cast(imag)); +} + +} // namespace complex_literals + +template +constexpr complex operator+(const complex& val) { + return val; +} + +template +constexpr complex operator-(const complex& val) { + return complex(-val.real(), -val.imag()); +} + +template +constexpr complex operator+(const complex& lhs, const complex& rhs) { + complex result = lhs; + return result += rhs; +} + +template +constexpr complex operator+(const complex& lhs, const T& rhs) { + complex result = lhs; + return result += rhs; +} + +template +constexpr complex operator+(const T& lhs, const complex& rhs) { + return complex(lhs + rhs.real(), rhs.imag()); +} + +template +constexpr complex operator-(const complex& lhs, const complex& rhs) { + complex result = lhs; + return result -= rhs; +} + +template +constexpr complex operator-(const complex& lhs, const T& rhs) { + complex result = lhs; + return result -= rhs; +} + +template +constexpr complex operator-(const T& lhs, const complex& rhs) { + complex result = -rhs; + return result += lhs; +} + +template +constexpr complex operator*(const complex& lhs, const complex& rhs) { + complex result = lhs; + return result *= rhs; +} + +template +constexpr complex operator*(const complex& lhs, const T& rhs) { + complex result = lhs; + return result *= rhs; +} + +template +constexpr complex operator*(const T& lhs, const complex& rhs) { + complex result = rhs; + return result *= lhs; +} + +template +constexpr complex operator/(const complex& lhs, const complex& rhs) { + complex result = lhs; + return 
result /= rhs; +} + +template +constexpr complex operator/(const complex& lhs, const T& rhs) { + complex result = lhs; + return result /= rhs; +} + +template +constexpr complex operator/(const T& lhs, const complex& rhs) { + complex result(lhs, T()); + return result /= rhs; +} + +// Define operators between integral scalars and standalone::c10::complex. +// std::complex does not support this when T is a floating-point number. This is +// useful because it saves a lot of "static_cast" when operate a complex and an +// integer. This makes the code both less verbose and potentially more +// efficient. +#define COMPLEX_INTEGER_OP_TEMPLATE_CONDITION \ + typename std::enable_if_t< \ + std::is_floating_point_v && std::is_integral_v, \ + int> = 0 + +template +constexpr standalone::c10::complex operator+( + const standalone::c10::complex& a, + const iT& b) { + return a + static_cast(b); +} + +template +constexpr standalone::c10::complex operator+( + const iT& a, + const standalone::c10::complex& b) { + return static_cast(a) + b; +} + +template +constexpr standalone::c10::complex operator-( + const standalone::c10::complex& a, + const iT& b) { + return a - static_cast(b); +} + +template +constexpr standalone::c10::complex operator-( + const iT& a, + const standalone::c10::complex& b) { + return static_cast(a) - b; +} + +template +constexpr standalone::c10::complex operator*( + const standalone::c10::complex& a, + const iT& b) { + return a * static_cast(b); +} + +template +constexpr standalone::c10::complex operator*( + const iT& a, + const standalone::c10::complex& b) { + return static_cast(a) * b; +} + +template +constexpr standalone::c10::complex operator/( + const standalone::c10::complex& a, + const iT& b) { + return a / static_cast(b); +} + +template +constexpr standalone::c10::complex operator/( + const iT& a, + const standalone::c10::complex& b) { + return static_cast(a) / b; +} + +#undef COMPLEX_INTEGER_OP_TEMPLATE_CONDITION + +template +constexpr bool operator==(const complex& lhs, const complex& rhs) { + return (lhs.real() == rhs.real()) && (lhs.imag() == rhs.imag()); +} + +template +constexpr bool operator==(const complex& lhs, const T& rhs) { + return (lhs.real() == rhs) && (lhs.imag() == T()); +} + +template +constexpr bool operator==(const T& lhs, const complex& rhs) { + return (lhs == rhs.real()) && (T() == rhs.imag()); +} + +template +constexpr bool operator!=(const complex& lhs, const complex& rhs) { + return !(lhs == rhs); +} + +template +constexpr bool operator!=(const complex& lhs, const T& rhs) { + return !(lhs == rhs); +} + +template +constexpr bool operator!=(const T& lhs, const complex& rhs) { + return !(lhs == rhs); +} + +template +std::basic_ostream& operator<<( + std::basic_ostream& os, + const complex& x) { + return (os << static_cast>(x)); +} + +template +std::basic_istream& operator>>( + std::basic_istream& is, + complex& x) { + std::complex tmp; + is >> tmp; + x = tmp; + return is; +} + +} // namespace standalone::c10 + +// std functions +// +// The implementation of these functions also follow the design of C++20 + +namespace std { + +template +constexpr T real(const standalone::c10::complex& z) { + return z.real(); +} + +template +constexpr T imag(const standalone::c10::complex& z) { + return z.imag(); +} + +template +STANDALONE_HOST_DEVICE T abs(const standalone::c10::complex& z) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return thrust::abs(static_cast>(z)); +#else + return std::abs(static_cast>(z)); +#endif +} + +#if defined(USE_ROCM) +#define ROCm_Bug(x) 
+#else +#define ROCm_Bug(x) x +#endif + +template +STANDALONE_HOST_DEVICE T arg(const standalone::c10::complex& z) { + return ROCm_Bug(std)::atan2(std::imag(z), std::real(z)); +} + +#undef ROCm_Bug + +template +constexpr T norm(const standalone::c10::complex& z) { + return z.real() * z.real() + z.imag() * z.imag(); +} + +// For std::conj, there are other versions of it: +// constexpr std::complex conj( float z ); +// template< class DoubleOrInteger > +// constexpr std::complex conj( DoubleOrInteger z ); +// constexpr std::complex conj( long double z ); +// These are not implemented +// TODO(@zasdfgbnm): implement them as standalone::c10::conj +template +constexpr standalone::c10::complex conj( + const standalone::c10::complex& z) { + return standalone::c10::complex(z.real(), -z.imag()); +} + +// Thrust does not have complex --> complex version of thrust::proj, +// so this function is not implemented at standalone right now. +// TODO(@zasdfgbnm): implement it by ourselves + +// There is no standalone version of std::polar, because std::polar always +// returns std::complex. Use standalone::c10::polar instead; + +} // namespace std + +namespace standalone::c10 { + +template +STANDALONE_HOST_DEVICE complex polar(const T& r, const T& theta = T()) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>(thrust::polar(r, theta)); +#else + // std::polar() requires r >= 0, so spell out the explicit implementation to + // avoid a branch. + return complex(r * std::cos(theta), r * std::sin(theta)); +#endif +} + +template <> +struct alignas(4) complex { + Half real_; + Half imag_; + + // Constructors + complex() = default; + // Half constructor is not constexpr so the following constructor can't + // be constexpr + STANDALONE_HOST_DEVICE explicit inline complex( + const Half& real, + const Half& imag) + : real_(real), imag_(imag) {} + STANDALONE_HOST_DEVICE inline complex( + const standalone::c10::complex& value) + : real_(value.real()), imag_(value.imag()) {} + + // Conversion operator + inline STANDALONE_HOST_DEVICE operator standalone::c10::complex() + const { + return {real_, imag_}; + } + + constexpr STANDALONE_HOST_DEVICE Half real() const { + return real_; + } + constexpr STANDALONE_HOST_DEVICE Half imag() const { + return imag_; + } + + STANDALONE_HOST_DEVICE complex& operator+=(const complex& other) { + real_ = static_cast(real_) + static_cast(other.real_); + imag_ = static_cast(imag_) + static_cast(other.imag_); + return *this; + } + + STANDALONE_HOST_DEVICE complex& operator-=(const complex& other) { + real_ = static_cast(real_) - static_cast(other.real_); + imag_ = static_cast(imag_) - static_cast(other.imag_); + return *this; + } + + STANDALONE_HOST_DEVICE complex& operator*=(const complex& other) { + auto a = static_cast(real_); + auto b = static_cast(imag_); + auto c = static_cast(other.real()); + auto d = static_cast(other.imag()); + real_ = a * c - b * d; + imag_ = a * d + b * c; + return *this; + } +}; + +} // namespace standalone::c10 + +STANDALONE_CLANG_DIAGNOSTIC_POP() + +#define STANDALONE_INTERNAL_INCLUDE_COMPLEX_REMAINING_H +// math functions are included in a separate file +#include // IWYU pragma: keep +// utilities for complex types +#include // IWYU pragma: keep +#undef STANDALONE_INTERNAL_INCLUDE_COMPLEX_REMAINING_H diff --git a/backends/aoti/slim/c10/util/complex_math.h b/backends/aoti/slim/c10/util/complex_math.h new file mode 100644 index 00000000000..56fc84fe90b --- /dev/null +++ b/backends/aoti/slim/c10/util/complex_math.h @@ -0,0 +1,500 @@ +#if 
!defined(STANDALONE_INTERNAL_INCLUDE_COMPLEX_REMAINING_H) +#error \ + "standalone/c10/util/complex_math.h is not meant to be individually included. Include standalone/c10/util/complex.h instead." +#endif + +#include + +namespace standalone::c10::complex_math { + +// Exponential functions + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex exp( + const standalone::c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::exp(static_cast>(x))); +#else + return static_cast>( + std::exp(static_cast>(x))); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex log( + const standalone::c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::log(static_cast>(x))); +#else + return static_cast>( + std::log(static_cast>(x))); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex log10( + const standalone::c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::log10(static_cast>(x))); +#else + return static_cast>( + std::log10(static_cast>(x))); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex log2( + const standalone::c10::complex& x) { + const standalone::c10::complex log2 = + standalone::c10::complex(::log(2.0), 0.0); + return standalone::c10::complex_math::log(x) / log2; +} + +// Power functions +// +#if defined(_LIBCPP_VERSION) || \ + (defined(__GLIBCXX__) && !defined(_GLIBCXX11_USE_C99_COMPLEX)) +namespace _detail { +template +standalone::c10::complex compute_csqrt( + const standalone::c10::complex& z) { + constexpr auto half = T(.5); + + // Trust standard library to correctly handle infs and NaNs + if (std::isinf(z.real()) || std::isinf(z.imag()) || std::isnan(z.real()) || + std::isnan(z.imag())) { + return static_cast>( + std::sqrt(static_cast>(z))); + } + + // Special case for square root of pure imaginary values + if (z.real() == T(0)) { + if (z.imag() == T(0)) { + return standalone::c10::complex(T(0), z.imag()); + } + auto v = std::sqrt(half * std::abs(z.imag())); + return standalone::c10::complex(v, std::copysign(v, z.imag())); + } + + // At this point, z is non-zero and finite + if (z.real() >= 0.0) { + auto t = std::sqrt((z.real() + std::abs(z)) * half); + return standalone::c10::complex(t, half * (z.imag() / t)); + } + + auto t = std::sqrt((-z.real() + std::abs(z)) * half); + return standalone::c10::complex( + half * std::abs(z.imag() / t), std::copysign(t, z.imag())); +} + +// Compute complex arccosine using formula from W. 
Kahan +// "Branch Cuts for Complex Elementary Functions" 1986 paper: +// cacos(z).re = 2*atan2(sqrt(1-z).re(), sqrt(1+z).re()) +// cacos(z).im = asinh((sqrt(conj(1+z))*sqrt(1-z)).im()) +template +standalone::c10::complex compute_cacos( + const standalone::c10::complex& z) { + auto constexpr one = T(1); + // Trust standard library to correctly handle infs and NaNs + if (std::isinf(z.real()) || std::isinf(z.imag()) || std::isnan(z.real()) || + std::isnan(z.imag())) { + return static_cast>( + std::acos(static_cast>(z))); + } + auto a = + compute_csqrt(standalone::c10::complex(one - z.real(), -z.imag())); + auto b = compute_csqrt(standalone::c10::complex(one + z.real(), z.imag())); + auto c = + compute_csqrt(standalone::c10::complex(one + z.real(), -z.imag())); + auto r = T(2) * std::atan2(a.real(), b.real()); + // Explicitly unroll (a*c).imag() + auto i = std::asinh(a.real() * c.imag() + a.imag() * c.real()); + return standalone::c10::complex(r, i); +} + +inline standalone::c10::complex sqrt( + const standalone::c10::complex& in) { + return compute_csqrt(in); +} + +inline standalone::c10::complex sqrt( + const standalone::c10::complex& in) { + return compute_csqrt(in); +} + +inline standalone::c10::complex acos( + const standalone::c10::complex& in) { + return compute_cacos(in); +} + +inline standalone::c10::complex acos( + const standalone::c10::complex& in) { + return compute_cacos(in); +} +} // namespace _detail +#endif + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex sqrt( + const standalone::c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::sqrt(static_cast>(x))); +#elif !( \ + defined(_LIBCPP_VERSION) || \ + (defined(__GLIBCXX__) && !defined(_GLIBCXX11_USE_C99_COMPLEX))) + return static_cast>( + std::sqrt(static_cast>(x))); +#else + return _detail::sqrt(x); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex pow( + const standalone::c10::complex& x, + const standalone::c10::complex& y) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>(thrust::pow( + static_cast>(x), static_cast>(y))); +#else + return static_cast>(std::pow( + static_cast>(x), static_cast>(y))); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex pow( + const standalone::c10::complex& x, + const T& y) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::pow(static_cast>(x), y)); +#else + return static_cast>( + std::pow(static_cast>(x), y)); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex pow( + const T& x, + const standalone::c10::complex& y) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::pow(x, static_cast>(y))); +#else + return static_cast>( + std::pow(x, static_cast>(y))); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex pow( + const standalone::c10::complex& x, + const standalone::c10::complex& y) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>(thrust::pow( + static_cast>(x), static_cast>(y))); +#else + return static_cast>(std::pow( + static_cast>(x), static_cast>(y))); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex pow( + const standalone::c10::complex& x, + const U& y) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::pow(static_cast>(x), y)); +#else + return static_cast>( + std::pow(static_cast>(x), y)); +#endif +} + +template +STANDALONE_HOST_DEVICE inline 
standalone::c10::complex pow( + const T& x, + const standalone::c10::complex& y) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::pow(x, static_cast>(y))); +#else + return static_cast>( + std::pow(x, static_cast>(y))); +#endif +} + +// Trigonometric functions + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex sin( + const standalone::c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::sin(static_cast>(x))); +#else + return static_cast>( + std::sin(static_cast>(x))); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex cos( + const standalone::c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::cos(static_cast>(x))); +#else + return static_cast>( + std::cos(static_cast>(x))); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex tan( + const standalone::c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::tan(static_cast>(x))); +#else + return static_cast>( + std::tan(static_cast>(x))); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex asin( + const standalone::c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::asin(static_cast>(x))); +#else + return static_cast>( + std::asin(static_cast>(x))); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex acos( + const standalone::c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::acos(static_cast>(x))); +#elif !defined(_LIBCPP_VERSION) + return static_cast>( + std::acos(static_cast>(x))); +#else + return _detail::acos(x); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex atan( + const standalone::c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::atan(static_cast>(x))); +#else + return static_cast>( + std::atan(static_cast>(x))); +#endif +} + +// Hyperbolic functions + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex sinh( + const standalone::c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::sinh(static_cast>(x))); +#else + return static_cast>( + std::sinh(static_cast>(x))); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex cosh( + const standalone::c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::cosh(static_cast>(x))); +#else + return static_cast>( + std::cosh(static_cast>(x))); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex tanh( + const standalone::c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::tanh(static_cast>(x))); +#else + return static_cast>( + std::tanh(static_cast>(x))); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex asinh( + const standalone::c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::asinh(static_cast>(x))); +#else + return static_cast>( + std::asinh(static_cast>(x))); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex acosh( + const standalone::c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::acosh(static_cast>(x))); +#else + return static_cast>( + std::acosh(static_cast>(x))); +#endif 
+} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex atanh( + const standalone::c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::atanh(static_cast>(x))); +#else + return static_cast>( + std::atanh(static_cast>(x))); +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex log1p( + const standalone::c10::complex& z) { +#if defined(__APPLE__) || defined(__MACOSX) || defined(__CUDACC__) || \ + defined(__HIPCC__) + // For Mac, the new implementation yielded a high relative error. Falling back + // to the old version for now. + // See https://github.com/numpy/numpy/pull/22611#issuecomment-1667945354 + // For CUDA we also use this one, as thrust::log(thrust::complex) takes + // *forever* to compile + + // log1p(z) = log(1 + z) + // Let's define 1 + z = r * e ^ (i * a), then we have + // log(r * e ^ (i * a)) = log(r) + i * a + // With z = x + iy, the term r can be written as + // r = ((1 + x) ^ 2 + y ^ 2) ^ 0.5 + // = (1 + x ^ 2 + 2 * x + y ^ 2) ^ 0.5 + // So, log(r) is + // log(r) = 0.5 * log(1 + x ^ 2 + 2 * x + y ^ 2) + // = 0.5 * log1p(x * (x + 2) + y ^ 2) + // we need to use the expression only on certain condition to avoid overflow + // and underflow from `(x * (x + 2) + y ^ 2)` + T x = z.real(); + T y = z.imag(); + T zabs = std::abs(z); + T theta = std::atan2(y, x + T(1)); + if (zabs < 0.5) { + T r = x * (T(2) + x) + y * y; + if (r == 0) { // handle underflow + return {x, theta}; + } + return {T(0.5) * std::log1p(r), theta}; + } else { + T z0 = std::hypot(x + 1, y); + return {std::log(z0), theta}; + } +#else + // CPU path + // Based on https://github.com/numpy/numpy/pull/22611#issuecomment-1667945354 + standalone::c10::complex u = z + T(1); + if (u == T(1)) { + return z; + } else { + auto log_u = log(u); + if (u - T(1) == z) { + return log_u; + } + return log_u * (z / (u - T(1))); + } +#endif +} + +template +STANDALONE_HOST_DEVICE inline standalone::c10::complex expm1( + const standalone::c10::complex& z) { + // expm1(z) = exp(z) - 1 + // Define z = x + i * y + // f = e ^ (x + i * y) - 1 + // = e ^ x * e ^ (i * y) - 1 + // = (e ^ x * cos(y) - 1) + i * (e ^ x * sin(y)) + // = (e ^ x - 1) * cos(y) - (1 - cos(y)) + i * e ^ x * sin(y) + // = expm1(x) * cos(y) - 2 * sin(y / 2) ^ 2 + i * e ^ x * sin(y) + T x = z.real(); + T y = z.imag(); + T a = std::sin(y / 2); + T er = std::expm1(x) * std::cos(y) - T(2) * a * a; + T ei = std::exp(x) * std::sin(y); + return {er, ei}; +} + +} // namespace standalone::c10::complex_math + +using standalone::c10::complex_math::acos; +using standalone::c10::complex_math::acosh; +using standalone::c10::complex_math::asin; +using standalone::c10::complex_math::asinh; +using standalone::c10::complex_math::atan; +using standalone::c10::complex_math::atanh; +using standalone::c10::complex_math::cos; +using standalone::c10::complex_math::cosh; +using standalone::c10::complex_math::exp; +using standalone::c10::complex_math::expm1; +using standalone::c10::complex_math::log; +using standalone::c10::complex_math::log10; +using standalone::c10::complex_math::log1p; +using standalone::c10::complex_math::log2; +using standalone::c10::complex_math::pow; +using standalone::c10::complex_math::sin; +using standalone::c10::complex_math::sinh; +using standalone::c10::complex_math::sqrt; +using standalone::c10::complex_math::tan; +using standalone::c10::complex_math::tanh; + +namespace std { + +using standalone::c10::complex_math::acos; +using standalone::c10::complex_math::acosh; +using 
standalone::c10::complex_math::asin; +using standalone::c10::complex_math::asinh; +using standalone::c10::complex_math::atan; +using standalone::c10::complex_math::atanh; +using standalone::c10::complex_math::cos; +using standalone::c10::complex_math::cosh; +using standalone::c10::complex_math::exp; +using standalone::c10::complex_math::expm1; +using standalone::c10::complex_math::log; +using standalone::c10::complex_math::log10; +using standalone::c10::complex_math::log1p; +using standalone::c10::complex_math::log2; +using standalone::c10::complex_math::pow; +using standalone::c10::complex_math::sin; +using standalone::c10::complex_math::sinh; +using standalone::c10::complex_math::sqrt; +using standalone::c10::complex_math::tan; +using standalone::c10::complex_math::tanh; + +} // namespace std diff --git a/backends/aoti/slim/c10/util/complex_utils.h b/backends/aoti/slim/c10/util/complex_utils.h new file mode 100644 index 00000000000..5b29406a186 --- /dev/null +++ b/backends/aoti/slim/c10/util/complex_utils.h @@ -0,0 +1,46 @@ +#if !defined(STANDALONE_INTERNAL_INCLUDE_COMPLEX_REMAINING_H) +#error \ + "standalone/c10/util/complex_utils.h is not meant to be individually included. Include standalone/c10/util/complex.h instead." +#endif + +#include + +namespace standalone::c10 { + +template +struct is_complex : public std::false_type {}; + +template +struct is_complex> : public std::true_type {}; + +template +struct is_complex> : public std::true_type {}; + +// Extract double from std::complex; is identity otherwise +// TODO: Write in more idiomatic C++17 +template +struct scalar_value_type { + using type = T; +}; +template +struct scalar_value_type> { + using type = T; +}; +template +struct scalar_value_type> { + using type = T; +}; + +} // namespace standalone::c10 + +namespace std { + +template +class numeric_limits> : public numeric_limits {}; + +template +bool isnan(const standalone::c10::complex& v) { + return std::isnan(v.real()) || std::isnan(v.imag()); +} + +} // namespace std diff --git a/backends/aoti/slim/c10/util/copysign.h b/backends/aoti/slim/c10/util/copysign.h new file mode 100644 index 00000000000..1012934049c --- /dev/null +++ b/backends/aoti/slim/c10/util/copysign.h @@ -0,0 +1,26 @@ +#pragma once + +#include +#include + +namespace standalone::c10 { + +// Note: Explicit implementation of copysign for Half and BFloat16 +// is needed to workaround g++-7/8 crash on aarch64, but also makes +// copysign faster for the half-precision types +template +inline auto copysign(const T& a, const U& b) { + return std::copysign(a, b); +} + +// Implement copysign for half precision floats using bit ops +// Sign is the most significant bit for both half and bfloat16 types +inline Half copysign(Half a, Half b) { + return Half((a.x & 0x7fff) | (b.x & 0x8000), Half::from_bits()); +} + +inline BFloat16 copysign(BFloat16 a, BFloat16 b) { + return BFloat16((a.x & 0x7fff) | (b.x & 0x8000), BFloat16::from_bits()); +} + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/util/floating_point_utils.h b/backends/aoti/slim/c10/util/floating_point_utils.h new file mode 100644 index 00000000000..259cb93b0a5 --- /dev/null +++ b/backends/aoti/slim/c10/util/floating_point_utils.h @@ -0,0 +1,33 @@ +#pragma once + +#include +#include +#include + +namespace standalone::c10::detail { + +STANDALONE_HOST_DEVICE inline float fp32_from_bits(uint32_t w) { +#if defined(__OPENCL_VERSION__) + return as_float(w); +#elif defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + return 
__uint_as_float((unsigned int)w); +#elif defined(__INTEL_COMPILER) + return _castu32_f32(w); +#else + return standalone::c10::bit_cast(w); +#endif +} + +STANDALONE_HOST_DEVICE inline uint32_t fp32_to_bits(float f) { +#if defined(__OPENCL_VERSION__) + return as_uint(f); +#elif defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + return (uint32_t)__float_as_uint(f); +#elif defined(__INTEL_COMPILER) + return _castf32_u32(f); +#else + return standalone::c10::bit_cast(f); +#endif +} + +} // namespace standalone::c10::detail diff --git a/backends/aoti/slim/c10/util/generic_math.h b/backends/aoti/slim/c10/util/generic_math.h new file mode 100644 index 00000000000..00bb4265d9d --- /dev/null +++ b/backends/aoti/slim/c10/util/generic_math.h @@ -0,0 +1,105 @@ +#pragma once + +#include +#include +#include + +#if defined(__CUDA_ARCH__) +#include +#define STANDALONE_COMPAT_COPYSIGN standalone::c10::cuda::compat::copysign +// TODO: rocm is not supported yet +// #elif defined(__HIPCC__) +// #include +// #define STANDALONE_COMPAT_COPYSIGN standalone::c10::hip::compat::copysign +#else +#include +#define STANDALONE_COMPAT_COPYSIGN standalone::c10::copysign +#endif + +// The functions in this file should be header-only as it is used under +// ABI-compatibility mode. + +namespace standalone::c10 { + +// NOTE: [Floor Division in Python] +// Python's __floordiv__ operator is more complicated than just floor(a / b). +// It aims to maintain the property: a == (a // b) * b + remainder(a, b) +// which can otherwise fail due to rounding errors in the remainder. +// So, instead it is calculated as: a // b = (a - remainder(a, b)) / b +// With some additional fix-ups added to the result. +// +// For reference, see CPython's implementation: +// https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636 + +template +inline STANDALONE_HOST_DEVICE scalar_t div_floor_floating( + scalar_t a, + scalar_t b) __ubsan_ignore_float_divide_by_zero__ { + if (STANDALONE_UNLIKELY(b == 0)) { + // Divide by zero: return standard IEEE result + return a / b; + } + + auto mod = std::fmod(a, b); + auto div = (a - mod) / b; + if ((mod != 0) && (b < 0) != (mod < 0)) { + div -= scalar_t(1); + } + + scalar_t floordiv; + if (div != 0) { + floordiv = std::floor(div); + if (div - floordiv > scalar_t(0.5)) { + floordiv += scalar_t(1.0); + } + } else { + floordiv = STANDALONE_COMPAT_COPYSIGN(scalar_t(0), a / b); + } + return floordiv; +} + +template +inline STANDALONE_HOST_DEVICE scalar_t +div_floor_integer(scalar_t a, scalar_t b) { + if (standalone::c10::signs_differ(a, b)) { + // Subtracts one from the results of truncation division if the + // divisor and dividend have different sign(bit)s and the remainder of + // the division is nonzero + const auto quot = a / b; + const auto rem = a % b; + return rem ? 
quot - 1 : quot; + } + return a / b; +} + +template < + typename scalar_t, + std::enable_if_t, int> = 0> +inline STANDALONE_HOST_DEVICE scalar_t div_mod(scalar_t a, scalar_t b) + __ubsan_ignore_float_divide_by_zero__ { + if (STANDALONE_UNLIKELY(b == 0)) { + // Divide by zero: return standard IEEE result + return std::fmod(a, b); + } + + auto mod = std::fmod(a, b); + if (mod == 0) { + mod = STANDALONE_COMPAT_COPYSIGN(scalar_t(0), b); + } else if ((b < 0) != (mod < 0)) { + mod += b; + } + return mod; +} + +template < + typename scalar_t, + std::enable_if_t, int> = 0> +inline STANDALONE_HOST_DEVICE scalar_t div_mod(scalar_t a, scalar_t b) { + auto mod = a % b; + if (mod != 0 && (b > 0) != (mod > 0)) { + mod += b; + } + return mod; +} + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/util/irange.h b/backends/aoti/slim/c10/util/irange.h new file mode 100644 index 00000000000..0d10f373a04 --- /dev/null +++ b/backends/aoti/slim/c10/util/irange.h @@ -0,0 +1,123 @@ +// Copyright 2004-present Facebook. All Rights Reserved. + +#pragma once + +#include + +#include +#include +#include +#include + +namespace standalone::c10 { + +namespace detail { + +template < + typename I, + bool one_sided = false, + std::enable_if_t, int> = 0> +struct integer_iterator { + using iterator_category = std::input_iterator_tag; + using value_type = I; + using difference_type = std::ptrdiff_t; + using pointer = I*; + using reference = I&; + + explicit constexpr integer_iterator(I value) : value(value) {} + + constexpr I operator*() const { + return value; + } + + constexpr I const* operator->() const { + return &value; + } + + constexpr integer_iterator& operator++() { + ++value; + return *this; + } + + constexpr integer_iterator operator++(int) { + const auto copy = *this; + ++*this; + return copy; + } + + constexpr bool operator==(const integer_iterator& other) const { + if constexpr (one_sided) { + // Range-for loops' end test is `begin != end`, not `begin < + // end`. To handle `standalone::c10::irange(n)` where n < 0 (which + // should be empty), we just make `begin != end` fail whenever `end` is + // negative. + return is_negative(other.value) || value == other.value; + } else { + return value == other.value; + } + // Suppress "warning: missing return statement at end of non-void function" + // which Nvidia's Robert Crovella confirms is an NVCC compiler error + // here https://stackoverflow.com/a/64561686/752843 on 2020-10-27 + // `__builtin_unreachable();` would be best here, but it's not + // available with all compilers. So we instead return an arbitrary + // value trusting that this line will, in fact, never be reached. + return false; // Horrible hack + } + + constexpr bool operator!=(const integer_iterator& other) const { + return !(*this == other); + } + + protected: + I value; +}; + +} // namespace detail + +template < + typename I, + bool one_sided = false, + std::enable_if_t, bool> = true> +struct integer_range { + public: + constexpr integer_range(I begin, I end) : begin_(begin), end_(end) {} + using iterator = detail::integer_iterator; + constexpr iterator begin() const { + return begin_; + } + constexpr iterator end() const { + return end_; + } + + private: + iterator begin_; + iterator end_; +}; + +/// Creates an integer range for the half-open interval [begin, end) +/// If end<=begin, then the range is empty. +/// The range has the type of the `end` integer; `begin` integer is +/// cast to this type. 
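+///
+/// Illustrative usage (not part of the original header):
+///
+///   int64_t n = 10;
+///   for (const auto i : standalone::c10::irange(0, n)) {
+///     // i takes the values 0, 1, ..., n - 1 and has the type of `end`
+///   }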
+template < + typename Integer1, + typename Integer2, + std::enable_if_t, bool> = true, + std::enable_if_t, bool> = true> +constexpr integer_range irange(Integer1 begin, Integer2 end) { + // If end<=begin then the range is empty; we can achieve this effect by + // choosing the larger of {begin, end} as the loop terminator + return { + static_cast(begin), + std::max(static_cast(begin), end)}; +} + +/// Creates an integer range for the half-open interval [0, end) +/// If end<=begin, then the range is empty +template < + typename Integer, + std::enable_if_t, bool> = true> +constexpr integer_range irange(Integer end) { + return {Integer(), end}; +} + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/util/llvmMathExtras.h b/backends/aoti/slim/c10/util/llvmMathExtras.h new file mode 100644 index 00000000000..0b4f92c44c6 --- /dev/null +++ b/backends/aoti/slim/c10/util/llvmMathExtras.h @@ -0,0 +1,899 @@ +//===-- llvm/Support/MathExtras.h - Useful math functions -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains some functions that are useful for math stuff. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef __ANDROID_NDK__ +#include +#endif + +#ifndef __has_builtin +#define __has_builtin(x) 0 +#endif + +#ifndef LLVM_GNUC_PREREQ +#if defined(__GNUC__) && defined(__GNUC_MINOR__) && defined(__GNUC_PATCHLEVEL__) +#define LLVM_GNUC_PREREQ(maj, min, patch) \ + ((__GNUC__ << 20) + (__GNUC_MINOR__ << 10) + __GNUC_PATCHLEVEL__ >= \ + ((maj) << 20) + ((min) << 10) + (patch)) +#elif defined(__GNUC__) && defined(__GNUC_MINOR__) +#define LLVM_GNUC_PREREQ(maj, min, patch) \ + ((__GNUC__ << 20) + (__GNUC_MINOR__ << 10) >= ((maj) << 20) + ((min) << 10)) +#else +#define LLVM_GNUC_PREREQ(maj, min, patch) 0 +#endif +#endif + +#ifdef _MSC_VER +// Declare these intrinsics manually rather including intrin.h. It's very +// expensive, and MathExtras.h is popular. +// #include +extern "C" { +unsigned char _BitScanForward(unsigned long* _Index, unsigned long _Mask); +unsigned char _BitScanForward64(unsigned long* _Index, unsigned __int64 _Mask); +unsigned char _BitScanReverse(unsigned long* _Index, unsigned long _Mask); +unsigned char _BitScanReverse64(unsigned long* _Index, unsigned __int64 _Mask); +} +#endif + +namespace standalone::c10::llvm { +/// The behavior an operation has on an input of 0. +enum ZeroBehavior { + /// The returned value is undefined. + ZB_Undefined, + /// The returned value is numeric_limits::max() + ZB_Max, + /// The returned value is numeric_limits::digits + ZB_Width +}; + +namespace detail { +template +struct TrailingZerosCounter { + static std::size_t count(T Val, ZeroBehavior) { + if (!Val) + return std::numeric_limits::digits; + if (Val & 0x1) + return 0; + + // Bisection method. 
+ std::size_t ZeroBits = 0; + T Shift = std::numeric_limits::digits >> 1; + T Mask = std::numeric_limits::max() >> Shift; + while (Shift) { + if ((Val & Mask) == 0) { + Val >>= Shift; + ZeroBits |= Shift; + } + Shift >>= 1; + Mask >>= Shift; + } + return ZeroBits; + } +}; + +#if (defined(__GNUC__) && __GNUC__ >= 4) || defined(_MSC_VER) +template +struct TrailingZerosCounter { + static std::size_t count(T Val, ZeroBehavior ZB) { + if (ZB != ZB_Undefined && Val == 0) + return 32; + +#if __has_builtin(__builtin_ctz) || LLVM_GNUC_PREREQ(4, 0, 0) + return __builtin_ctz(Val); +#elif defined(_MSC_VER) + unsigned long Index; + _BitScanForward(&Index, Val); + return Index; +#endif + } +}; + +#if !defined(_MSC_VER) || defined(_M_X64) +template +struct TrailingZerosCounter { + static std::size_t count(T Val, ZeroBehavior ZB) { + if (ZB != ZB_Undefined && Val == 0) + return 64; + +#if __has_builtin(__builtin_ctzll) || LLVM_GNUC_PREREQ(4, 0, 0) + return __builtin_ctzll(Val); +#elif defined(_MSC_VER) + unsigned long Index; + _BitScanForward64(&Index, Val); + return Index; +#endif + } +}; +#endif +#endif +} // namespace detail + +/// Count number of 0's from the least significant bit to the most +/// stopping at the first 1. +/// +/// Only unsigned integral types are allowed. +/// +/// \param ZB the behavior on an input of 0. Only ZB_Width and ZB_Undefined are +/// valid arguments. +template +std::size_t countTrailingZeros(T Val, ZeroBehavior ZB = ZB_Width) { + static_assert( + std::numeric_limits::is_integer && !std::numeric_limits::is_signed, + "Only unsigned integral types are allowed."); + return llvm::detail::TrailingZerosCounter::count(Val, ZB); +} + +namespace detail { +template +struct LeadingZerosCounter { + static std::size_t count(T Val, ZeroBehavior) { + if (!Val) + return std::numeric_limits::digits; + + // Bisection method. + std::size_t ZeroBits = 0; + for (T Shift = std::numeric_limits::digits >> 1; Shift; Shift >>= 1) { + T Tmp = Val >> Shift; + if (Tmp) + Val = Tmp; + else + ZeroBits |= Shift; + } + return ZeroBits; + } +}; + +#if (defined(__GNUC__) && __GNUC__ >= 4) || defined(_MSC_VER) +template +struct LeadingZerosCounter { + static std::size_t count(T Val, ZeroBehavior ZB) { + if (ZB != ZB_Undefined && Val == 0) + return 32; + +#if __has_builtin(__builtin_clz) || LLVM_GNUC_PREREQ(4, 0, 0) + return __builtin_clz(Val); +#elif defined(_MSC_VER) + unsigned long Index; + _BitScanReverse(&Index, Val); + return Index ^ 31; +#endif + } +}; + +#if !defined(_MSC_VER) || defined(_M_X64) +template +struct LeadingZerosCounter { + static std::size_t count(T Val, ZeroBehavior ZB) { + if (ZB != ZB_Undefined && Val == 0) + return 64; + +#if __has_builtin(__builtin_clzll) || LLVM_GNUC_PREREQ(4, 0, 0) + return __builtin_clzll(Val); +#elif defined(_MSC_VER) + unsigned long Index; + _BitScanReverse64(&Index, Val); + return Index ^ 63; +#endif + } +}; +#endif +#endif +} // namespace detail + +/// Count number of 0's from the most significant bit to the least +/// stopping at the first 1. +/// +/// Only unsigned integral types are allowed. +/// +/// \param ZB the behavior on an input of 0. Only ZB_Width and ZB_Undefined are +/// valid arguments. +template +std::size_t countLeadingZeros(T Val, ZeroBehavior ZB = ZB_Width) { + static_assert( + std::numeric_limits::is_integer && !std::numeric_limits::is_signed, + "Only unsigned integral types are allowed."); + return llvm::detail::LeadingZerosCounter::count(Val, ZB); +} + +/// Get the index of the first set bit starting from the least +/// significant bit. 
+/// +/// Only unsigned integral types are allowed. +/// +/// \param ZB the behavior on an input of 0. Only ZB_Max and ZB_Undefined are +/// valid arguments. +template +T findFirstSet(T Val, ZeroBehavior ZB = ZB_Max) { + if (ZB == ZB_Max && Val == 0) + return std::numeric_limits::max(); + + return countTrailingZeros(Val, ZB_Undefined); +} + +/// Create a bitmask with the N right-most bits set to 1, and all other +/// bits set to 0. Only unsigned types are allowed. +template +T maskTrailingOnes(unsigned N) { + static_assert(std::is_unsigned_v, "Invalid type!"); + const unsigned Bits = CHAR_BIT * sizeof(T); + assert(N <= Bits && "Invalid bit index"); + return N == 0 ? 0 : (T(-1) >> (Bits - N)); +} + +/// Create a bitmask with the N left-most bits set to 1, and all other +/// bits set to 0. Only unsigned types are allowed. +template +T maskLeadingOnes(unsigned N) { + return ~maskTrailingOnes(CHAR_BIT * sizeof(T) - N); +} + +/// Create a bitmask with the N right-most bits set to 0, and all other +/// bits set to 1. Only unsigned types are allowed. +template +T maskTrailingZeros(unsigned N) { + return maskLeadingOnes(CHAR_BIT * sizeof(T) - N); +} + +/// Create a bitmask with the N left-most bits set to 0, and all other +/// bits set to 1. Only unsigned types are allowed. +template +T maskLeadingZeros(unsigned N) { + return maskTrailingOnes(CHAR_BIT * sizeof(T) - N); +} + +/// Get the index of the last set bit starting from the least +/// significant bit. +/// +/// Only unsigned integral types are allowed. +/// +/// \param ZB the behavior on an input of 0. Only ZB_Max and ZB_Undefined are +/// valid arguments. +template +T findLastSet(T Val, ZeroBehavior ZB = ZB_Max) { + if (ZB == ZB_Max && Val == 0) + return std::numeric_limits::max(); + + // Use ^ instead of - because both gcc and llvm can remove the associated ^ + // in the __builtin_clz intrinsic on x86. + return countLeadingZeros(Val, ZB_Undefined) ^ + (std::numeric_limits::digits - 1); +} + +/// Macro compressed bit reversal table for 256 bits. +/// +/// http://graphics.stanford.edu/~seander/bithacks.html#BitReverseTable +/// NOLINTNEXTLINE(*c-arrays*) +static constexpr unsigned char BitReverseTable256[256] = { +#define R2(n) n, n + 2 * 64, n + 1 * 64, n + 3 * 64 +#define R4(n) R2(n), R2(n + 2 * 16), R2(n + 1 * 16), R2(n + 3 * 16) +#define R6(n) R4(n), R4(n + 2 * 4), R4(n + 1 * 4), R4(n + 3 * 4) + R6(0), + R6(2), + R6(1), + R6(3) +#undef R2 +#undef R4 +#undef R6 +}; + +/// Reverse the bits in \p Val. +template +T reverseBits(T Val) { + // NOLINTNEXTLINE(*c-arrays*) + unsigned char in[sizeof(Val)]; + // NOLINTNEXTLINE(*c-arrays*) + unsigned char out[sizeof(Val)]; + std::memcpy(in, &Val, sizeof(Val)); + for (unsigned i = 0; i < sizeof(Val); ++i) + out[(sizeof(Val) - i) - 1] = BitReverseTable256[in[i]]; + std::memcpy(&Val, out, sizeof(Val)); + return Val; +} + +// NOTE: The following support functions use the _32/_64 extensions instead of +// type overloading so that signed and unsigned integers can be used without +// ambiguity. + +/// Return the high 32 bits of a 64 bit value. +constexpr inline uint32_t Hi_32(uint64_t Value) { + return static_cast(Value >> 32); +} + +/// Return the low 32 bits of a 64 bit value. +constexpr inline uint32_t Lo_32(uint64_t Value) { + return static_cast(Value); +} + +/// Make a 64-bit integer from a high / low pair of 32-bit integers. +constexpr inline uint64_t Make_64(uint32_t High, uint32_t Low) { + return ((uint64_t)High << 32) | (uint64_t)Low; +} + +/// Checks if an integer fits into the given bit width. 
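+///
+/// For example:
+/// \code
+///   isInt<8>(127);   // true
+///   isInt<8>(-128);  // true
+///   isInt<8>(128);   // false
+/// \endcode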
+template +constexpr inline bool isInt(int64_t x) { + return N >= 64 || + (-(INT64_C(1) << (N - 1)) <= x && x < (INT64_C(1) << (N - 1))); +} +// Template specializations to get better code for common cases. +template <> +constexpr inline bool isInt<8>(int64_t x) { + return static_cast(x) == x; +} +template <> +constexpr inline bool isInt<16>(int64_t x) { + return static_cast(x) == x; +} +template <> +constexpr inline bool isInt<32>(int64_t x) { + return static_cast(x) == x; +} + +/// Checks if a signed integer is an N bit number shifted left by S. +template +constexpr inline bool isShiftedInt(int64_t x) { + static_assert( + N > 0, "isShiftedInt<0> doesn't make sense (refers to a 0-bit number."); + static_assert(N + S <= 64, "isShiftedInt with N + S > 64 is too wide."); + return isInt(x) && (x % (UINT64_C(1) << S) == 0); +} + +/// Checks if an unsigned integer fits into the given bit width. +/// +/// This is written as two functions rather than as simply +/// +/// return N >= 64 || X < (UINT64_C(1) << N); +/// +/// to keep MSVC from (incorrectly) warning on isUInt<64> that we're shifting +/// left too many places. +template +constexpr inline std::enable_if_t<(N < 64), bool> isUInt(uint64_t X) { + static_assert(N > 0, "isUInt<0> doesn't make sense"); + return X < (UINT64_C(1) << (N)); +} +template +constexpr inline std::enable_if_t= 64, bool> isUInt(uint64_t /*X*/) { + return true; +} + +// Template specializations to get better code for common cases. +template <> +constexpr inline bool isUInt<8>(uint64_t x) { + return static_cast(x) == x; +} +template <> +constexpr inline bool isUInt<16>(uint64_t x) { + return static_cast(x) == x; +} +template <> +constexpr inline bool isUInt<32>(uint64_t x) { + return static_cast(x) == x; +} + +/// Checks if a unsigned integer is an N bit number shifted left by S. +template +constexpr inline bool isShiftedUInt(uint64_t x) { + static_assert( + N > 0, "isShiftedUInt<0> doesn't make sense (refers to a 0-bit number)"); + static_assert( + N + S <= 64, "isShiftedUInt with N + S > 64 is too wide."); + // Per the two static_asserts above, S must be strictly less than 64. So + // 1 << S is not undefined behavior. + return isUInt(x) && (x % (UINT64_C(1) << S) == 0); +} + +/// Gets the maximum value for a N-bit unsigned integer. +inline uint64_t maxUIntN(uint64_t N) { + assert(N > 0 && N <= 64 && "integer width out of range"); + + // uint64_t(1) << 64 is undefined behavior, so we can't do + // (uint64_t(1) << N) - 1 + // without checking first that N != 64. But this works and doesn't have a + // branch. + return UINT64_MAX >> (64 - N); +} + +// Ignore the false warning "Arithmetic overflow" for MSVC +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4146) +#endif + +/// Gets the minimum value for a N-bit signed integer. +inline int64_t minIntN(int64_t N) { + assert(N > 0 && N <= 64 && "integer width out of range"); + // NOLINTNEXTLINE(*-narrowing-conversions) + return -(UINT64_C(1) << (N - 1)); +} + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +/// Gets the maximum value for a N-bit signed integer. +inline int64_t maxIntN(int64_t N) { + assert(N > 0 && N <= 64 && "integer width out of range"); + + // This relies on two's complement wraparound when N == 64, so we convert to + // int64_t only at the very end to avoid UB. + // NOLINTNEXTLINE(*-narrowing-conversions) + return (UINT64_C(1) << (N - 1)) - 1; +} + +/// Checks if an unsigned integer fits into the given (dynamic) bit width. 
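+/// For example, isUIntN(8, 255) is true while isUIntN(8, 256) is false.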
+inline bool isUIntN(unsigned N, uint64_t x) { + return N >= 64 || x <= maxUIntN(N); +} + +/// Checks if an signed integer fits into the given (dynamic) bit width. +inline bool isIntN(unsigned N, int64_t x) { + return N >= 64 || (minIntN(N) <= x && x <= maxIntN(N)); +} + +/// Return true if the argument is a non-empty sequence of ones starting at the +/// least significant bit with the remainder zero (32 bit version). +/// Ex. isMask_32(0x0000FFFFU) == true. +constexpr inline bool isMask_32(uint32_t Value) { + return Value && ((Value + 1) & Value) == 0; +} + +/// Return true if the argument is a non-empty sequence of ones starting at the +/// least significant bit with the remainder zero (64 bit version). +constexpr inline bool isMask_64(uint64_t Value) { + return Value && ((Value + 1) & Value) == 0; +} + +/// Return true if the argument contains a non-empty sequence of ones with the +/// remainder zero (32 bit version.) Ex. isShiftedMask_32(0x0000FF00U) == true. +constexpr inline bool isShiftedMask_32(uint32_t Value) { + return Value && isMask_32((Value - 1) | Value); +} + +/// Return true if the argument contains a non-empty sequence of ones with the +/// remainder zero (64 bit version.) +constexpr inline bool isShiftedMask_64(uint64_t Value) { + return Value && isMask_64((Value - 1) | Value); +} + +/// Return true if the argument is a power of two > 0. +/// Ex. isPowerOf2_32(0x00100000U) == true (32 bit edition.) +constexpr inline bool isPowerOf2_32(uint32_t Value) { + return Value && !(Value & (Value - 1)); +} + +/// Return true if the argument is a power of two > 0 (64 bit edition.) +constexpr inline bool isPowerOf2_64(uint64_t Value) { + return Value && !(Value & (Value - 1)); +} + +/// Count the number of ones from the most significant bit to the first +/// zero bit. +/// +/// Ex. countLeadingOnes(0xFF0FFF00) == 8. +/// Only unsigned integral types are allowed. +/// +/// \param ZB the behavior on an input of all ones. Only ZB_Width and +/// ZB_Undefined are valid arguments. +template +std::size_t countLeadingOnes(T Value, ZeroBehavior ZB = ZB_Width) { + static_assert( + std::numeric_limits::is_integer && !std::numeric_limits::is_signed, + "Only unsigned integral types are allowed."); + return countLeadingZeros(~Value, ZB); +} + +/// Count the number of ones from the least significant bit to the first +/// zero bit. +/// +/// Ex. countTrailingOnes(0x00FF00FF) == 8. +/// Only unsigned integral types are allowed. +/// +/// \param ZB the behavior on an input of all ones. Only ZB_Width and +/// ZB_Undefined are valid arguments. +template +std::size_t countTrailingOnes(T Value, ZeroBehavior ZB = ZB_Width) { + static_assert( + std::numeric_limits::is_integer && !std::numeric_limits::is_signed, + "Only unsigned integral types are allowed."); + return countTrailingZeros(~Value, ZB); +} + +namespace detail { +template +struct PopulationCounter { + static unsigned count(T Value) { + // Generic version, forward to 32 bits. 
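+    // With GCC/Clang this is __builtin_popcount; otherwise the SWAR
+    // (SIMD-within-a-register) fallback below counts bits in parallel.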
+ static_assert(SizeOfT <= 4, "Not implemented!"); +#if defined(__GNUC__) && __GNUC__ >= 4 + return __builtin_popcount(Value); +#else + uint32_t v = Value; + v = v - ((v >> 1) & 0x55555555); + v = (v & 0x33333333) + ((v >> 2) & 0x33333333); + return ((v + (v >> 4) & 0xF0F0F0F) * 0x1010101) >> 24; +#endif + } +}; + +template +struct PopulationCounter { + static unsigned count(T Value) { +#if defined(__GNUC__) && __GNUC__ >= 4 + return __builtin_popcountll(Value); +#else + uint64_t v = Value; + v = v - ((v >> 1) & 0x5555555555555555ULL); + v = (v & 0x3333333333333333ULL) + ((v >> 2) & 0x3333333333333333ULL); + v = (v + (v >> 4)) & 0x0F0F0F0F0F0F0F0FULL; + return unsigned((uint64_t)(v * 0x0101010101010101ULL) >> 56); +#endif + } +}; +} // namespace detail + +/// Count the number of set bits in a value. +/// Ex. countPopulation(0xF000F000) = 8 +/// Returns 0 if the word is zero. +template +inline unsigned countPopulation(T Value) { + static_assert( + std::numeric_limits::is_integer && !std::numeric_limits::is_signed, + "Only unsigned integral types are allowed."); + return detail::PopulationCounter::count(Value); +} + +/// Return the log base 2 of the specified value. +inline double Log2(double Value) { +#if defined(__ANDROID_API__) && __ANDROID_API__ < 18 + return __builtin_log(Value) / __builtin_log(2.0); +#else + return log2(Value); +#endif +} + +/// Return the floor log base 2 of the specified value, -1 if the value is zero. +/// (32 bit edition.) +/// Ex. Log2_32(32) == 5, Log2_32(1) == 0, Log2_32(0) == -1, Log2_32(6) == 2 +inline unsigned Log2_32(uint32_t Value) { + return static_cast(31 - countLeadingZeros(Value)); +} + +/// Return the floor log base 2 of the specified value, -1 if the value is zero. +/// (64 bit edition.) +inline unsigned Log2_64(uint64_t Value) { + return static_cast(63 - countLeadingZeros(Value)); +} + +/// Return the ceil log base 2 of the specified value, 32 if the value is zero. +/// (32 bit edition). +/// Ex. Log2_32_Ceil(32) == 5, Log2_32_Ceil(1) == 0, Log2_32_Ceil(6) == 3 +inline unsigned Log2_32_Ceil(uint32_t Value) { + return static_cast(32 - countLeadingZeros(Value - 1)); +} + +/// Return the ceil log base 2 of the specified value, 64 if the value is zero. +/// (64 bit edition.) +inline unsigned Log2_64_Ceil(uint64_t Value) { + return static_cast(64 - countLeadingZeros(Value - 1)); +} + +/// Return the greatest common divisor of the values using Euclid's algorithm. +inline uint64_t GreatestCommonDivisor64(uint64_t A, uint64_t B) { + while (B) { + uint64_t T = B; + B = A % B; + A = T; + } + return A; +} + +/// This function takes a 64-bit integer and returns the bit equivalent double. +inline double BitsToDouble(uint64_t Bits) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + double D; + static_assert(sizeof(uint64_t) == sizeof(double), "Unexpected type sizes"); + memcpy(&D, &Bits, sizeof(Bits)); + return D; +} + +/// This function takes a 32-bit integer and returns the bit equivalent float. +inline float BitsToFloat(uint32_t Bits) { + // TODO: Use std::bit_cast once C++20 becomes available. + return standalone::c10::bit_cast(Bits); +} + +/// This function takes a double and returns the bit equivalent 64-bit integer. +/// Note that copying doubles around changes the bits of NaNs on some hosts, +/// notably x86, so this routine cannot be used if these bits are needed. 
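+/// For example, DoubleToBits(1.0) == 0x3FF0000000000000ULL.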
+inline uint64_t DoubleToBits(double Double) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + uint64_t Bits; + static_assert(sizeof(uint64_t) == sizeof(double), "Unexpected type sizes"); + memcpy(&Bits, &Double, sizeof(Double)); + return Bits; +} + +/// This function takes a float and returns the bit equivalent 32-bit integer. +/// Note that copying floats around changes the bits of NaNs on some hosts, +/// notably x86, so this routine cannot be used if these bits are needed. +inline uint32_t FloatToBits(float Float) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + uint32_t Bits; + static_assert(sizeof(uint32_t) == sizeof(float), "Unexpected type sizes"); + memcpy(&Bits, &Float, sizeof(Float)); + return Bits; +} + +/// A and B are either alignments or offsets. Return the minimum alignment that +/// may be assumed after adding the two together. +constexpr inline uint64_t MinAlign(uint64_t A, uint64_t B) { + // The largest power of 2 that divides both A and B. + // + // Replace "-Value" by "1+~Value" in the following commented code to avoid + // MSVC warning C4146 + // return (A | B) & -(A | B); + return (A | B) & (1 + ~(A | B)); +} + +/// Aligns \c Addr to \c Alignment bytes, rounding up. +/// +/// Alignment should be a power of two. This method rounds up, so +/// alignAddr(7, 4) == 8 and alignAddr(8, 4) == 8. +inline uintptr_t alignAddr(const void* Addr, size_t Alignment) { + assert( + Alignment && isPowerOf2_64((uint64_t)Alignment) && + "Alignment is not a power of two!"); + + assert((uintptr_t)Addr + Alignment - 1 >= (uintptr_t)Addr); + + return (((uintptr_t)Addr + Alignment - 1) & ~(uintptr_t)(Alignment - 1)); +} + +/// Returns the necessary adjustment for aligning \c Ptr to \c Alignment +/// bytes, rounding up. +inline size_t alignmentAdjustment(const void* Ptr, size_t Alignment) { + return alignAddr(Ptr, Alignment) - (uintptr_t)Ptr; +} + +/// Returns the next power of two (in 64-bits) that is strictly greater than A. +/// Returns zero on overflow. +inline uint64_t NextPowerOf2(uint64_t A) { + A |= (A >> 1); + A |= (A >> 2); + A |= (A >> 4); + A |= (A >> 8); + A |= (A >> 16); + A |= (A >> 32); + return A + 1; +} + +/// Returns the power of two which is less than or equal to the given value. +/// Essentially, it is a floor operation across the domain of powers of two. +inline uint64_t PowerOf2Floor(uint64_t A) { + if (!A) + return 0; + return 1ull << (63 - countLeadingZeros(A, ZB_Undefined)); +} + +/// Returns the power of two which is greater than or equal to the given value. +/// Essentially, it is a ceil operation across the domain of powers of two. +inline uint64_t PowerOf2Ceil(uint64_t A) { + if (!A) + return 0; + return NextPowerOf2(A - 1); +} + +/// Returns the next integer (mod 2**64) that is greater than or equal to +/// \p Value and is a multiple of \p Align. \p Align must be non-zero. +/// +/// If non-zero \p Skew is specified, the return value will be a minimal +/// integer that is greater than or equal to \p Value and equal to +/// \p Align * N + \p Skew for some integer N. If \p Skew is larger than +/// \p Align, its value is adjusted to '\p Skew mod \p Align'. 
+/// +/// Examples: +/// \code +/// alignTo(5, 8) = 8 +/// alignTo(17, 8) = 24 +/// alignTo(~0LL, 8) = 0 +/// alignTo(321, 255) = 510 +/// +/// alignTo(5, 8, 7) = 7 +/// alignTo(17, 8, 1) = 17 +/// alignTo(~0LL, 8, 3) = 3 +/// alignTo(321, 255, 42) = 552 +/// \endcode +inline uint64_t alignTo(uint64_t Value, uint64_t Align, uint64_t Skew = 0) { + assert(Align != 0u && "Align can't be 0."); + Skew %= Align; + return (Value + Align - 1 - Skew) / Align * Align + Skew; +} + +/// Returns the next integer (mod 2**64) that is greater than or equal to +/// \p Value and is a multiple of \c Align. \c Align must be non-zero. +template +constexpr inline uint64_t alignTo(uint64_t Value) { + static_assert(Align != 0u, "Align must be non-zero"); + return (Value + Align - 1) / Align * Align; +} + +/// Returns the integer ceil(Numerator / Denominator). +inline uint64_t divideCeil(uint64_t Numerator, uint64_t Denominator) { + return alignTo(Numerator, Denominator) / Denominator; +} + +/// \c alignTo for contexts where a constant expression is required. +/// \sa alignTo +/// +/// \todo FIXME: remove when \c constexpr becomes really \c constexpr +template +struct AlignTo { + static_assert(Align != 0u, "Align must be non-zero"); + template + struct from_value { + static const uint64_t value = (Value + Align - 1) / Align * Align; + }; +}; + +/// Returns the largest uint64_t less than or equal to \p Value and is +/// \p Skew mod \p Align. \p Align must be non-zero +inline uint64_t alignDown(uint64_t Value, uint64_t Align, uint64_t Skew = 0) { + assert(Align != 0u && "Align can't be 0."); + Skew %= Align; + return (Value - Skew) / Align * Align + Skew; +} + +/// Returns the offset to the next integer (mod 2**64) that is greater than +/// or equal to \p Value and is a multiple of \p Align. \p Align must be +/// non-zero. +inline uint64_t OffsetToAlignment(uint64_t Value, uint64_t Align) { + return alignTo(Value, Align) - Value; +} + +/// Sign-extend the number in the bottom B bits of X to a 32-bit integer. +/// Requires 0 < B <= 32. +template +constexpr inline int32_t SignExtend32(uint32_t X) { + static_assert(B > 0, "Bit width can't be 0."); + static_assert(B <= 32, "Bit width out of range."); + return int32_t(X << (32 - B)) >> (32 - B); +} + +/// Sign-extend the number in the bottom B bits of X to a 32-bit integer. +/// Requires 0 < B < 32. +inline int32_t SignExtend32(uint32_t X, unsigned B) { + assert(B > 0 && "Bit width can't be 0."); + assert(B <= 32 && "Bit width out of range."); + return int32_t(X << (32 - B)) >> (32 - B); +} + +/// Sign-extend the number in the bottom B bits of X to a 64-bit integer. +/// Requires 0 < B < 64. +template +constexpr inline int64_t SignExtend64(uint64_t x) { + static_assert(B > 0, "Bit width can't be 0."); + static_assert(B <= 64, "Bit width out of range."); + return int64_t(x << (64 - B)) >> (64 - B); +} + +/// Sign-extend the number in the bottom B bits of X to a 64-bit integer. +/// Requires 0 < B < 64. +inline int64_t SignExtend64(uint64_t X, unsigned B) { + assert(B > 0 && "Bit width can't be 0."); + assert(B <= 64 && "Bit width out of range."); + return int64_t(X << (64 - B)) >> (64 - B); +} + +/// Subtract two unsigned integers, X and Y, of type T and return the absolute +/// value of the result. +template +std::enable_if_t, T> AbsoluteDifference(T X, T Y) { + return std::max(X, Y) - std::min(X, Y); +} + +/// Add two unsigned integers, X and Y, of type T. Clamp the result to the +/// maximum representable value of T on overflow. 
ResultOverflowed indicates if +/// the result is larger than the maximum representable value of type T. +template +std::enable_if_t, T> +SaturatingAdd(T X, T Y, bool* ResultOverflowed = nullptr) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + bool Dummy; + bool& Overflowed = ResultOverflowed ? *ResultOverflowed : Dummy; + // Hacker's Delight, p. 29 + T Z = X + Y; + Overflowed = (Z < X || Z < Y); + if (Overflowed) + return std::numeric_limits::max(); + else + return Z; +} + +/// Multiply two unsigned integers, X and Y, of type T. Clamp the result to the +/// maximum representable value of T on overflow. ResultOverflowed indicates if +/// the result is larger than the maximum representable value of type T. +template +std::enable_if_t, T> +SaturatingMultiply(T X, T Y, bool* ResultOverflowed = nullptr) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + bool Dummy; + bool& Overflowed = ResultOverflowed ? *ResultOverflowed : Dummy; + + // Hacker's Delight, p. 30 has a different algorithm, but we don't use that + // because it fails for uint16_t (where multiplication can have undefined + // behavior due to promotion to int), and requires a division in addition + // to the multiplication. + + Overflowed = false; + + // Log2(Z) would be either Log2Z or Log2Z + 1. + // Special case: if X or Y is 0, Log2_64 gives -1, and Log2Z + // will necessarily be less than Log2Max as desired. + int Log2Z = Log2_64(X) + Log2_64(Y); + const T Max = std::numeric_limits::max(); + int Log2Max = Log2_64(Max); + if (Log2Z < Log2Max) { + return X * Y; + } + if (Log2Z > Log2Max) { + Overflowed = true; + return Max; + } + + // We're going to use the top bit, and maybe overflow one + // bit past it. Multiply all but the bottom bit then add + // that on at the end. + T Z = (X >> 1) * Y; + if (Z & ~(Max >> 1)) { + Overflowed = true; + return Max; + } + Z <<= 1; + if (X & 1) + return SaturatingAdd(Z, Y, ResultOverflowed); + + return Z; +} + +/// Multiply two unsigned integers, X and Y, and add the unsigned integer, A to +/// the product. Clamp the result to the maximum representable value of T on +/// overflow. ResultOverflowed indicates if the result is larger than the +/// maximum representable value of type T. +template +std::enable_if_t, T> +SaturatingMultiplyAdd(T X, T Y, T A, bool* ResultOverflowed = nullptr) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + bool Dummy; + bool& Overflowed = ResultOverflowed ? *ResultOverflowed : Dummy; + + T Product = SaturatingMultiply(X, Y, &Overflowed); + if (Overflowed) + return Product; + + return SaturatingAdd(A, Product, &Overflowed); +} + +/// Use this rather than HUGE_VALF; the latter causes warnings on MSVC. +extern const float huge_valf; +} // namespace standalone::c10::llvm diff --git a/backends/aoti/slim/c10/util/overflows.h b/backends/aoti/slim/c10/util/overflows.h new file mode 100644 index 00000000000..5f636cd1a75 --- /dev/null +++ b/backends/aoti/slim/c10/util/overflows.h @@ -0,0 +1,100 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include + +namespace standalone::c10 { +// In some versions of MSVC, there will be a compiler error when building. +// C4146: unary minus operator applied to unsigned type, result still unsigned +// C4804: unsafe use of type 'bool' in operation +// It can be addressed by disabling the following warning. 
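+// C4018 (signed/unsigned mismatch) is silenced as well for the comparisons
+// against numeric limits below.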
+#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4146) +#pragma warning(disable : 4804) +#pragma warning(disable : 4018) +#endif + +// The overflow checks may involve float to int conversion which may +// trigger precision loss warning. Re-enable the warning once the code +// is fixed. See T58053069. +STANDALONE_CLANG_DIAGNOSTIC_PUSH() +#if STANDALONE_CLANG_HAS_WARNING("-Wimplicit-float-conversion") +STANDALONE_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") +#endif + +// bool can be converted to any type. +// Without specializing on bool, in pytorch_linux_trusty_py2_7_9_build: +// `error: comparison of constant '255' with boolean expression is always false` +// for `f > limit::max()` below +template +std::enable_if_t, bool> overflows( + From /*f*/, + bool strict_unsigned [[maybe_unused]] = false) { + return false; +} + +// skip isnan and isinf check for integral types +template +std::enable_if_t && !std::is_same_v, bool> +overflows(From f, bool strict_unsigned = false) { + using limit = std::numeric_limits::type>; + if constexpr (!limit::is_signed && std::numeric_limits::is_signed) { + // allow for negative numbers to wrap using two's complement arithmetic. + // For example, with uint8, this allows for `a - b` to be treated as + // `a + 255 * b`. + if (!strict_unsigned) { + return greater_than_max(f) || + (standalone::c10::is_negative(f) && + -static_cast(f) > static_cast(limit::max())); + } + } + return standalone::c10::less_than_lowest(f) || greater_than_max(f); +} + +template +std::enable_if_t, bool> overflows( + From f, + bool strict_unsigned [[maybe_unused]] = false) { + using limit = std::numeric_limits::type>; + if (limit::has_infinity && std::isinf(static_cast(f))) { + return false; + } + if (!limit::has_quiet_NaN && (f != f)) { + return true; + } + return f < limit::lowest() || f > limit::max(); +} + +STANDALONE_CLANG_DIAGNOSTIC_POP() + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +template +std::enable_if_t::value, bool> overflows( + From f, + bool strict_unsigned = false) { + // casts from complex to real are considered to overflow if the + // imaginary component is non-zero + if (!is_complex::value && f.imag() != 0) { + return true; + } + // Check for overflow componentwise + // (Technically, the imag overflow check is guaranteed to be false + // when !is_complex, but any optimizer worth its salt will be + // able to figure it out.) + return overflows< + typename scalar_value_type::type, + typename From::value_type>(f.real(), strict_unsigned) || + overflows< + typename scalar_value_type::type, + typename From::value_type>(f.imag(), strict_unsigned); +} +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/util/qint32.h b/backends/aoti/slim/c10/util/qint32.h new file mode 100644 index 00000000000..7951bfd240a --- /dev/null +++ b/backends/aoti/slim/c10/util/qint32.h @@ -0,0 +1,18 @@ +#pragma once +#include + +#include + +namespace standalone::c10 { + +/** + * qint32 is for signed 32 bit quantized Tensors + */ +struct alignas(4) qint32 { + using underlying = int32_t; + int32_t val_; + qint32() = default; + STANDALONE_HOST_DEVICE explicit qint32(int32_t val) : val_(val) {} +}; + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/util/qint8.h b/backends/aoti/slim/c10/util/qint8.h new file mode 100644 index 00000000000..53c1fdf465a --- /dev/null +++ b/backends/aoti/slim/c10/util/qint8.h @@ -0,0 +1,20 @@ +#pragma once +#include + +#include + +namespace standalone::c10 { + +/** + * This is the data type for quantized Tensors. 
Right now we only have + * qint8 which is for 8 bit Tensors, and qint32 for 32 bit int Tensors, + * we might have 4 bit, 2 bit or 1 bit data types in the future. + */ +struct alignas(1) qint8 { + using underlying = int8_t; + int8_t val_; + qint8() = default; + STANDALONE_HOST_DEVICE explicit qint8(int8_t val) : val_(val) {} +}; + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/util/quint2x4.h b/backends/aoti/slim/c10/util/quint2x4.h new file mode 100644 index 00000000000..009802be7f2 --- /dev/null +++ b/backends/aoti/slim/c10/util/quint2x4.h @@ -0,0 +1,19 @@ +#pragma once +#include + +#include + +namespace standalone::c10 { + +/** + * quint2x4 is for un-signed 2 bit quantized Tensors that are packed to byte + * boundary. + */ +struct alignas(1) quint2x4 { + using underlying = uint8_t; + uint8_t val_; + quint2x4() = default; + STANDALONE_HOST_DEVICE explicit quint2x4(uint8_t val) : val_(val) {} +}; + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/util/quint4x2.h b/backends/aoti/slim/c10/util/quint4x2.h new file mode 100644 index 00000000000..b6812ab8fde --- /dev/null +++ b/backends/aoti/slim/c10/util/quint4x2.h @@ -0,0 +1,19 @@ +#pragma once +#include + +#include + +namespace standalone::c10 { + +/** + * quint4x2 is for un-signed 4 bit quantized Tensors that are packed to byte + * boundary. + */ +struct alignas(1) quint4x2 { + using underlying = uint8_t; + uint8_t val_; + quint4x2() = default; + STANDALONE_HOST_DEVICE explicit quint4x2(uint8_t val) : val_(val) {} +}; + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/util/quint8.h b/backends/aoti/slim/c10/util/quint8.h new file mode 100644 index 00000000000..4019765ca4a --- /dev/null +++ b/backends/aoti/slim/c10/util/quint8.h @@ -0,0 +1,18 @@ +#pragma once +#include + +#include + +namespace standalone::c10 { + +/** + * quint8 is for unsigned 8 bit quantized Tensors + */ +struct alignas(1) quint8 { + using underlying = uint8_t; + uint8_t val_; + quint8() = default; + STANDALONE_HOST_DEVICE explicit quint8(uint8_t val) : val_(val) {} +}; + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/c10/util/safe_numerics.h b/backends/aoti/slim/c10/util/safe_numerics.h new file mode 100644 index 00000000000..26a05c636aa --- /dev/null +++ b/backends/aoti/slim/c10/util/safe_numerics.h @@ -0,0 +1,94 @@ +#pragma once +#include + +#include + +// GCC has __builtin_mul_overflow from before it supported __has_builtin +#ifdef _MSC_VER +#define STANDALONE_HAS_BUILTIN_OVERFLOW() (0) +#include +#include +#else +#define STANDALONE_HAS_BUILTIN_OVERFLOW() (1) +#endif + +namespace standalone::c10 { + +STANDALONE_ALWAYS_INLINE bool +add_overflows(uint64_t a, uint64_t b, uint64_t* out) { +#if STANDALONE_HAS_BUILTIN_OVERFLOW() + return __builtin_add_overflow(a, b, out); +#else + unsigned long long tmp; +#if defined(_M_IX86) || defined(_M_X64) + auto carry = _addcarry_u64(0, a, b, &tmp); +#else + tmp = a + b; + unsigned long long vector = (a & b) ^ ((a ^ b) & ~tmp); + auto carry = vector >> 63; +#endif + *out = tmp; + return carry; +#endif +} + +STANDALONE_ALWAYS_INLINE bool +mul_overflows(uint64_t a, uint64_t b, uint64_t* out) { +#if STANDALONE_HAS_BUILTIN_OVERFLOW() + return __builtin_mul_overflow(a, b, out); +#else + *out = a * b; + // This test isnt exact, but avoids doing integer division + return ( + (standalone::c10::llvm::countLeadingZeros(a) + + standalone::c10::llvm::countLeadingZeros(b)) < 64); +#endif +} + +STANDALONE_ALWAYS_INLINE bool +mul_overflows(int64_t a, int64_t b, int64_t* out) { 
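+  // Signed overload: uses the compiler builtin when available; the portable
+  // path below multiplies, then verifies the product by dividing it back.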
+#if STANDALONE_HAS_BUILTIN_OVERFLOW() + return __builtin_mul_overflow(a, b, out); +#else + volatile int64_t tmp = a * b; + *out = tmp; + if (a == 0 || b == 0) { + return false; + } + return !(a == tmp / b); +#endif +} + +template +bool safe_multiplies_u64(It first, It last, uint64_t* out) { +#if STANDALONE_HAS_BUILTIN_OVERFLOW() + uint64_t prod = 1; + bool overflow = false; + for (; first != last; ++first) { + overflow |= standalone::c10::mul_overflows(prod, *first, &prod); + } + *out = prod; + return overflow; +#else + uint64_t prod = 1; + uint64_t prod_log2 = 0; + bool is_zero = false; + for (; first != last; ++first) { + auto x = static_cast(*first); + prod *= x; + // log2(0) isn't valid, so need to track it specially + is_zero |= (x == 0); + prod_log2 += standalone::c10::llvm::Log2_64_Ceil(x); + } + *out = prod; + // This test isnt exact, but avoids doing integer division + return !is_zero && (prod_log2 >= 64); +#endif +} + +template +bool safe_multiplies_u64(const Container& c, uint64_t* out) { + return safe_multiplies_u64(c.begin(), c.end(), out); +} + +} // namespace standalone::c10 diff --git a/backends/aoti/slim/core/SlimTensor.h b/backends/aoti/slim/core/SlimTensor.h new file mode 100644 index 00000000000..69ac4fec65f --- /dev/null +++ b/backends/aoti/slim/core/SlimTensor.h @@ -0,0 +1,637 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace standalone::slim { + +class SlimTensor { + public: + SlimTensor( + Storage&& storage, + standalone::c10::IntArrayRef sizes, + standalone::c10::IntArrayRef strides, + standalone::c10::ScalarType dtype, + int64_t storage_offset = 0) + : storage_(std::move(storage)), + storage_offset_(storage_offset), + dtype_(dtype) { + set_sizes_and_strides(sizes, strides); + } + + // Default constructor - creates an undefined tensor + SlimTensor() + : storage_(Storage()), + storage_offset_(0), + numel_(0), + dtype_(standalone::c10::ScalarType::Float), + is_contiguous_(true) { + sizes_and_strides_.set_sizes({0}); + sizes_and_strides_.set_strides({1}); + } + + SlimTensor(const SlimTensor&) = default; + SlimTensor& operator=(const SlimTensor&) = default; + SlimTensor(SlimTensor&&) = default; + SlimTensor& operator=(SlimTensor&&) = default; + + ~SlimTensor() = default; + + void reset() { + // Decrement the refcount of the storage + storage_.reset(); + } + + // Accessors + Storage storage() const { + return storage_; + } + + size_t nbytes() const { + return numel() * itemsize(); + } + + size_t itemsize() const { + return standalone::c10::elementSize(dtype_); + } + + standalone::c10::IntArrayRef sizes() const { + return sizes_and_strides_.sizes_arrayref(); + } + + int64_t size(int64_t dim) const { + int64_t wrapped_dim = + standalone::c10::maybe_wrap_dim(dim, static_cast(this->dim())); + return sizes_and_strides_.size_at(static_cast(wrapped_dim)); + } + + standalone::c10::IntArrayRef strides() const { + return sizes_and_strides_.strides_arrayref(); + } + + int64_t stride(int64_t dim) const { + int64_t wrapped_dim = + standalone::c10::maybe_wrap_dim(dim, static_cast(this->dim())); + return sizes_and_strides_.stride_at(static_cast(wrapped_dim)); + } + + standalone::c10::ScalarType dtype() const { + return dtype_; + } + + const standalone::c10::Device& device() const { + return storage_->device(); + } + + standalone::c10::DeviceType device_type() const { + return storage_->device().type(); + } + + standalone::c10::DeviceIndex 
device_index() const { + return storage_->device().index(); + } + + int64_t storage_offset() const { + return storage_offset_; + } + + size_t numel() const { + return numel_; + } + + size_t dim() const { + return sizes_and_strides_.size(); + } + + void* data_ptr() const { + return static_cast(storage_->data()) + storage_offset_ * itemsize(); + } + + bool is_contiguous() const { + return is_contiguous_; + } + + bool is_empty() const { + return numel_ == 0; + } + + bool is_cuda() const { + return device().is_cuda(); + } + + bool is_cpu() const { + return device().is_cpu(); + } + + // Check if tensor is defined (not default-constructed) + bool defined() const { + return storage_.get() != nullptr; + } + + // Setters + void set_storage(Storage&& new_storage) { + storage_ = std::move(new_storage); + } + + void set_sizes_and_strides( + standalone::c10::IntArrayRef sizes, + standalone::c10::IntArrayRef strides, + std::optional storage_offset = std::nullopt) { + const int64_t new_dim = static_cast(sizes.size()); + STANDALONE_CHECK( + new_dim == static_cast(strides.size()), + "dimensionality of sizes (", + new_dim, + ") must match dimensionality of strides (", + strides.size(), + ")"); + + std::vector new_sizes = sizes.vec(); + std::vector new_strides = strides.vec(); + + // stride calculation logic + bool overflowed = false; + if (new_dim > 0) { + for (int64_t dim = new_dim - 1; dim >= 0; dim--) { + if (strides[dim] >= 0) { + new_strides[dim] = strides[dim]; + } else { + // for negative strides + if (dim == new_dim - 1) { + new_strides[dim] = 1; + } else { + overflowed |= standalone::c10::mul_overflows( + new_strides[dim + 1], + std::max(new_sizes[dim + 1], 1), + &new_strides[dim]); + } + } + } + } + STANDALONE_CHECK(!overflowed, "Stride calculation overflowed"); + + sizes_and_strides_.set_sizes(new_sizes); + sizes_and_strides_.set_strides(new_strides); + if (storage_offset.has_value()) { + storage_offset_ = *storage_offset; + } + + refresh_numel(); + refresh_contiguous(); + } + + void set_sizes_contiguous(standalone::c10::IntArrayRef new_size) { + sizes_and_strides_.set_sizes(new_size); + refresh_numel(); + empty_tensor_restride(standalone::c10::MemoryFormat::Contiguous); + } + + void empty_tensor_restride(standalone::c10::MemoryFormat memory_format); + + SlimTensor resize_( + standalone::c10::IntArrayRef sizes, + std::optional optional_memory_format); + + // Conversion operations + SlimTensor to(const standalone::c10::Device& device) const { + if (device == storage_->device()) { + return *this; + } + // Does not mutate the current tensor. Returns a new tensor + Storage new_storage(new MaybeOwningStorage(storage_->clone(device))); + return SlimTensor( + std::move(new_storage), + sizes_and_strides_.sizes_arrayref(), + sizes_and_strides_.strides_arrayref(), + dtype_, + storage_offset_); + } + + SlimTensor cpu() const { + return to(CPU_DEVICE); + } + + SlimTensor cuda() const { + return to(DEFAULT_CUDA_DEVICE); + } + + SlimTensor to(standalone::c10::ScalarType dtype) const { + STANDALONE_CHECK(false, "TBD: to(dtype)"); + } + + SlimTensor& copy_(const SlimTensor& other) { + STANDALONE_CHECK( + this->numel() == other.numel(), "copy_: numel of tensors must match"); + STANDALONE_CHECK(this->dtype() == other.dtype(), "copy_: dtype must match"); + + if (this->numel() == 0) { + return *this; + } + + // Case 1: Both tensors are contiguous. We can do a fast bulk copy. 
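+    // The entire nbytes() payload is handed to the storage copy in one call.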
+ if (this->is_contiguous() && other.is_contiguous()) { + storage_->copy_( + this->data_ptr(), other.data_ptr(), other.nbytes(), other.device()); + return *this; + } + + // Case 2: At least one tensor is non-contiguous, perform element-wise copy + // that respects both source and destination strides. + const size_t elem_size = standalone::c10::elementSize(dtype_); + char* dst_data = static_cast(this->data_ptr()); + const char* src_data = static_cast(other.data_ptr()); + + std::vector counter(this->dim(), 0); + for (size_t i = 0; i < this->numel(); i++) { + // Compute src offset in elements + int64_t src_offset = 0; + for (size_t d = 0; d < other.dim(); d++) { + src_offset += counter[d] * other.stride(d); + } + + // Compute dst offset in elements + int64_t dst_offset = 0; + for (size_t d = 0; d < this->dim(); d++) { + dst_offset += counter[d] * this->stride(d); + } + + // Copy elem_size bytes from src to dst + if (this->device().is_cpu() && other.device().is_cpu()) { + std::memcpy( + dst_data + dst_offset * elem_size, + src_data + src_offset * elem_size, + elem_size); + } else if (this->device().is_cuda() || other.device().is_cuda()) { +#if defined(USE_CUDA) + DeviceTraits::memcpy( + dst_data + dst_offset * elem_size, + src_data + src_offset * elem_size, + elem_size, + device(), // dst device + other.device() // src device + ); +#else + STANDALONE_CHECK(false, "copy_: no CUDA support"); +#endif + } + // Increment the multi-dimensional counter + for (int64_t d = static_cast(this->dim()) - 1; d >= 0; --d) { + counter[d]++; + if (counter[d] < this->size(d)) { + break; + } + counter[d] = 0; + } + } + return *this; + } + + SlimTensor& fill_(const c10::Scalar& value) { + // Fast path for byte patterns on contiguous tensors - use memset + if (value.equal(0) && this->is_contiguous()) { + if (this->device().is_cpu()) { + std::memset(this->data_ptr(), 0, this->nbytes()); + return *this; + } else if (this->device().is_cuda()) { +#ifdef USE_CUDA + cudaError_t err = cudaMemset(this->data_ptr(), 0, this->nbytes()); + STANDALONE_CHECK( + err == cudaSuccess, + "CUDA memset failed: ", + cudaGetErrorString(err)); + return *this; +#else + STANDALONE_CHECK(false, "CUDA support not available"); +#endif + } + } + + // Fallback to type-specific fill implementation + auto fill_value = [&](auto typed_value) { + using SType = decltype(typed_value); + if (this->device().is_cuda()) { +#ifdef USE_CUDA + if (this->is_contiguous()) { + // Fast path for contiguous tensors + if constexpr (std::is_same_v) { + // Special handling for bool since std::vector doesn't have + // data() + std::vector host_data(this->numel(), typed_value ? 
1 : 0); + cudaError_t err = cudaMemcpy( + this->data_ptr(), + host_data.data(), + this->nbytes(), + cudaMemcpyHostToDevice); + STANDALONE_CHECK( + err == cudaSuccess, + "CUDA memcpy failed: ", + cudaGetErrorString(err)); + } else { + std::vector host_data(this->numel(), typed_value); + cudaError_t err = cudaMemcpy( + this->data_ptr(), + host_data.data(), + this->nbytes(), + cudaMemcpyHostToDevice); + STANDALONE_CHECK( + err == cudaSuccess, + "CUDA memcpy failed: ", + cudaGetErrorString(err)); + } + } else { + // Handle non-contiguous tensors by copying to CPU, filling, then + // copying back + SlimTensor cpu_tensor = this->to(CPU_DEVICE); + cpu_tensor.fill_(typed_value); + this->copy_(cpu_tensor); + } +#else + STANDALONE_CHECK(false, "CUDA support not available"); +#endif + } else if (this->device().is_cpu()) { + if (this->is_contiguous()) { + // Fast path for contiguous tensors + SType* data = static_cast(this->data_ptr()); + for (size_t i = 0; i < this->numel(); ++i) { + data[i] = typed_value; + } + } else { + // Handle non-contiguous tensors by respecting strides + const size_t elem_size = standalone::c10::elementSize(this->dtype_); + char* base_data = static_cast(this->data_ptr()); + + std::vector counter(this->dim(), 0); + for (size_t i = 0; i < this->numel(); ++i) { + // Compute offset in elements based on strides + int64_t offset = 0; + for (size_t d = 0; d < this->dim(); d++) { + offset += counter[d] * this->stride(d); + } + + // Set the value at the computed offset + SType* element_ptr = + reinterpret_cast(base_data + offset * elem_size); + *element_ptr = typed_value; + + // Increment the multi-dimensional counter + for (int64_t d = static_cast(this->dim()) - 1; d >= 0; + --d) { + counter[d]++; + if (counter[d] < this->size(d)) { + break; + } + counter[d] = 0; + } + } + } + } + }; + + switch (this->dtype()) { + case standalone::c10::ScalarType::Double: + fill_value(value.to()); + break; + case standalone::c10::ScalarType::Float: + fill_value(value.to()); + break; + case standalone::c10::ScalarType::Half: + fill_value(value.to()); + break; + case standalone::c10::ScalarType::BFloat16: + fill_value(value.to()); + break; + case standalone::c10::ScalarType::Long: + fill_value(value.to()); + break; + case standalone::c10::ScalarType::Int: + fill_value(value.to()); + break; + case standalone::c10::ScalarType::Short: + fill_value(value.to()); + break; + case standalone::c10::ScalarType::Char: + fill_value(value.to()); + break; + case standalone::c10::ScalarType::Byte: + fill_value(value.to()); + break; + case standalone::c10::ScalarType::Bool: + fill_value(value.to()); + break; + case standalone::c10::ScalarType::ComplexFloat: + fill_value(value.to>()); + break; + case standalone::c10::ScalarType::ComplexDouble: + fill_value(value.to>()); + break; + default: + STANDALONE_CHECK(false, "fill_: Unsupported dtype"); + } + return *this; + } + + SlimTensor clone() const { + return _clone_impl( + this->sizes(), this->strides(), this->dtype(), this->device()); + } + + SlimTensor clone_contiguous() const { + std::vector contig_strides = + standalone::slim::compute_contiguous_strides(this->sizes()); + return _clone_impl( + this->sizes(), contig_strides, this->dtype(), this->device()); + } + + // View operations + SlimTensor as_strided( + standalone::c10::IntArrayRef sizes, + standalone::c10::IntArrayRef strides, + int64_t storage_offset) const; + SlimTensor as_strided_( + standalone::c10::IntArrayRef sizes, + standalone::c10::IntArrayRef strides, + int64_t storage_offset); + + SlimTensor 
permute(standalone::c10::IntArrayRef dims) const; + + // Transpose operations + SlimTensor transpose() const; + SlimTensor transpose(int64_t dim0, int64_t dim1) const; + SlimTensor t() const; + + SlimTensor reshape(standalone::c10::IntArrayRef proposed_shape) const; + + SlimTensor narrow(int64_t dim, int64_t start, int64_t length) const; + + // Generic element access returning SlimTensor + SlimTensor operator[](standalone::c10::IntArrayRef indices) const { + STANDALONE_CHECK( + indices.size() <= this->dim(), + "Number of indices (", + indices.size(), + ") cannot exceed tensor dimensions (", + this->dim(), + ")"); + + if (indices.size() == this->dim()) { + // Full indexing - return 0-dimensional tensor + int64_t linear_index = 0; + for (size_t i = 0; i < indices.size(); ++i) { + int64_t idx = indices[i]; + int64_t size = this->size(i); + idx = standalone::c10::maybe_wrap_dim(idx, size); + linear_index += idx * this->stride(i); + } + // Create 0-dimensional tensor pointing to the indexed element + int64_t new_storage_offset = this->storage_offset_ + linear_index; + return SlimTensor( + Storage(this->storage_), {}, {}, this->dtype_, new_storage_offset); + } else { + // Partial indexing - return tensor with reduced dimensions + std::vector new_sizes; + std::vector new_strides; + int64_t offset_adjustment = 0; + + // Calculate offset from the provided indices + for (size_t i = 0; i < indices.size(); ++i) { + int64_t idx = indices[i]; + int64_t size = this->size(i); + idx = standalone::c10::maybe_wrap_dim(idx, size); + offset_adjustment += idx * this->stride(i); + } + + // Copy remaining dimensions + for (size_t i = indices.size(); i < this->dim(); ++i) { + new_sizes.push_back(this->size(i)); + new_strides.push_back(this->stride(i)); + } + + int64_t new_storage_offset = this->storage_offset_ + offset_adjustment; + return SlimTensor( + Storage(this->storage_), + new_sizes, + new_strides, + this->dtype_, + new_storage_offset); + } + } + + // Convenience overload for single index + SlimTensor operator[](int64_t index) const { + return (*this)[standalone::c10::IntArrayRef{index}]; + } + + // Convenience overloads for common multi-dimensional cases + SlimTensor operator[](std::initializer_list indices) const { + return (*this)[standalone::c10::IntArrayRef(indices)]; + } + + // Extract scalar value from 0-dimensional tensor + standalone::c10::Scalar item() const { + switch (this->dtype()) { + case standalone::c10::ScalarType::Double: + return this->item(); + case standalone::c10::ScalarType::Float: + return this->item(); + case standalone::c10::ScalarType::Half: + return this->item(); + case standalone::c10::ScalarType::BFloat16: + return this->item(); + case standalone::c10::ScalarType::Long: + return this->item(); + case standalone::c10::ScalarType::Int: + return this->item(); + case standalone::c10::ScalarType::Short: + return this->item(); + case standalone::c10::ScalarType::Char: + return this->item(); + case standalone::c10::ScalarType::Byte: + return this->item(); + case standalone::c10::ScalarType::Bool: + return this->item(); + case standalone::c10::ScalarType::ComplexFloat: + return this->item>(); + case standalone::c10::ScalarType::ComplexDouble: + return this->item>(); + default: + STANDALONE_CHECK(false, "item(): Unsupported dtype"); + } + } + + // Templated version to access 0-dimensional tensor + template + T item() const { + STANDALONE_CHECK( + this->dim() == 0, "item() can only be called on 0-dimensional tensors"); + STANDALONE_CHECK( + this->numel() == 1, "item() requires tensor to 
have exactly 1 element"); + + // For 0-dimensional tensors, directly access the single element at + // data_ptr() No need to compute linear index since there's only one element + const T* data = static_cast(this->data_ptr()); + return *data; + } + + private: + SlimTensor _clone_impl( + standalone::c10::IntArrayRef sizes, + standalone::c10::IntArrayRef strides, + standalone::c10::ScalarType dtype, + const standalone::c10::Device& device) const { + Storage storage = new_storage(sizes, strides, dtype, device); + SlimTensor result = + SlimTensor(std::move(storage), sizes, strides, dtype, 0); + result.copy_(*this); + return result; + } + + void refresh_numel() { + numel_ = compute_numel(sizes_and_strides_.sizes_arrayref()); + } + + bool compute_is_contiguous() const { + return standalone::c10::_compute_contiguous( + sizes_and_strides_.sizes_arrayref(), + sizes_and_strides_.strides_arrayref(), + numel_); + } + + void refresh_contiguous() { + // In SlimTensor, we only care about the single is_contiguous_ flag. + // (because TensorImpl (aten) implementation has other stuff) + is_contiguous_ = compute_is_contiguous(); + } + + Storage storage_; // device_type_ and device_index_ are stored in storage_ + int64_t storage_offset_{0}; + standalone::c10::SizesAndStrides sizes_and_strides_; + // If sizes and strides are empty, the numel is 1!! However, most of the + // time, we will immediately set sizes to {0} and reset numel to 0. + // (Can't do that in the default initializers, because there's no way to + // spell "allocate a one-element array" for strides_). + size_t numel_{1}; + standalone::c10::ScalarType dtype_; + bool is_contiguous_{true}; + // NOLINTNEXTLINE(clang-diagnostic-unused-private-field) + std::array reserved_{0}; // padding to align to 8 bytes +}; + +} // namespace standalone::slim + +#include +#include diff --git a/backends/aoti/slim/core/SlimTensorResize-incl.h b/backends/aoti/slim/core/SlimTensorResize-incl.h new file mode 100644 index 00000000000..64c976aa5d8 --- /dev/null +++ b/backends/aoti/slim/core/SlimTensorResize-incl.h @@ -0,0 +1,174 @@ +#pragma once + +#include + +#include +#include +#include + +namespace standalone::slim { +inline void SlimTensor::empty_tensor_restride( + standalone::c10::MemoryFormat memory_format) { +#ifdef DEBUG + STANDALONE_INTERNAL_ASSERT( + compute_numel() == numel_, + "If you are seeing this error, that means empty_tensor_restride was " + "called before setting correct numel"); +#endif + switch (memory_format) { + case standalone::c10::MemoryFormat::Contiguous: { + // dim_ is a virtual call, don't repeat it + const auto dim_ = dim(); + sizes_and_strides_.resize(dim_); + if (dim_ > 0) { + bool overflowed = false; + const auto last_idx = dim_ - 1; + sizes_and_strides_.stride_at_unchecked(last_idx) = 1; + for (int64_t i = static_cast(last_idx) - 1; i >= 0; --i) { + overflowed |= standalone::c10::mul_overflows( + sizes_and_strides_.stride_at_unchecked(i + 1), + std::max(sizes_and_strides_.size_at_unchecked(i + 1), 1), + std::addressof(sizes_and_strides_.stride_at_unchecked(i))); + } + STANDALONE_CHECK(!overflowed, "Stride calculation overflowed"); + } + break; + } + case standalone::c10::MemoryFormat::ChannelsLast: { + STANDALONE_CHECK( + dim() == 4, "required rank 4 tensor to use channels_last format"); + set_sizes_and_strides(sizes(), get_channels_last_strides_2d(sizes())); + break; + } + case standalone::c10::MemoryFormat::ChannelsLast3d: { + STANDALONE_CHECK( + dim() == 5, "required rank 5 tensor to use channels_last_3d format"); + 
set_sizes_and_strides(sizes(), get_channels_last_strides_3d(sizes())); + break; + } + case standalone::c10::MemoryFormat::Preserve: + STANDALONE_CHECK(false, "unsupported memory format ", memory_format); + // Cleaning warning messages, no need to break as STANDALONE_CHECK(false) + // terminates flow. + // break; + case standalone::c10::MemoryFormat::NumOptions: + STANDALONE_INTERNAL_ASSERT( + false, "invalid memory format ", memory_format); + } + // recompute contiguous flag, as currently NHWC/NCHW flags are not mutually + // exclusive see #24090 + refresh_contiguous(); +} + +inline void _resize_bytes( + MaybeOwningStorage* storage, + size_t new_size_bytes, + size_t storage_offset_in_bytes) { + STANDALONE_CHECK( + storage->is_resizable(), + "Trying to resize storage that is not resizable"); + + void* new_data = nullptr; + const c10::Device& device = storage->device(); + if (new_size_bytes > 0) { + if (device.is_cpu()) { + new_data = + DeviceTraits::allocate(new_size_bytes, device); + } else if (device.is_cuda()) { + new_data = + DeviceTraits::allocate(new_size_bytes, device); + } + } + + void* old_data = storage->data(); + const size_t old_capacity = storage->nbytes(); + const size_t copy_capacity = std::min(new_size_bytes, old_capacity); + if (old_data != nullptr && copy_capacity > 0) { + if (device.is_cpu()) { + DeviceTraits::memcpy( + static_cast(new_data) + storage_offset_in_bytes, + static_cast(old_data) + storage_offset_in_bytes, + copy_capacity, + device, + device); + } else if (device.is_cuda()) { + DeviceTraits::memcpy( + static_cast(new_data) + storage_offset_in_bytes, + static_cast(old_data) + storage_offset_in_bytes, + copy_capacity, + device, + device); + } + } + + storage->free_data(); + storage->set_data_ptr_noswap(new_data); + storage->set_nbytes(new_size_bytes); +} + +inline void _maybe_resize_storage(SlimTensor* self, int64_t new_size_bytes) { + if (self->numel() == 0) { + return; + } + + const Storage& storage = self->storage(); + if (!storage) { + Storage new_storage(new MaybeOwningStorage(self->device(), new_size_bytes)); + self->set_storage(std::move(new_storage)); + } else if (new_size_bytes > static_cast(self->nbytes())) { + _resize_bytes( + storage.get(), + new_size_bytes, + self->storage_offset() * self->itemsize()); + } +} + +inline SlimTensor* _resize_impl_( + SlimTensor* self, + standalone::c10::IntArrayRef sizes, + std::optional strides, + bool resize_storage) { + if (self->sizes() == sizes && + (!strides || self->strides() == strides.value())) { + return self; + } + + const auto itemsize = self->itemsize(); + const auto storage_offset = self->storage_offset(); + int64_t storage_size = 1; + if (strides) { + self->set_sizes_and_strides(sizes, *strides); + storage_size = + compute_storage_nbytes(sizes, *strides, itemsize, storage_offset); + } else { + self->set_sizes_contiguous(sizes); + storage_size = + compute_storage_nbytes_contiguous(sizes, itemsize, storage_offset); + } + + if (resize_storage) { + _maybe_resize_storage(self, storage_size); + } + + return self; +} + +inline SlimTensor SlimTensor::resize_( + standalone::c10::IntArrayRef sizes, + std::optional optional_memory_format) { + _resize_impl_(this, sizes, /*strides=*/std::nullopt, true); + + if (optional_memory_format.has_value()) { + standalone::c10::MemoryFormat memory_format = + static_cast( + optional_memory_format.value()); + STANDALONE_CHECK( + memory_format != standalone::c10::MemoryFormat::Preserve, + "Unsupported memory format", + memory_format); + this->empty_tensor_restride(memory_format); 
+ } + return *this; +} + +} // namespace standalone::slim diff --git a/backends/aoti/slim/core/SlimTensorView-incl.h b/backends/aoti/slim/core/SlimTensorView-incl.h new file mode 100644 index 00000000000..0df4c4705f1 --- /dev/null +++ b/backends/aoti/slim/core/SlimTensorView-incl.h @@ -0,0 +1,152 @@ +#pragma once + +#include + +#include +#include +#include + +namespace standalone::slim { +inline SlimTensor SlimTensor::as_strided( + standalone::c10::IntArrayRef sizes, + standalone::c10::IntArrayRef strides, + int64_t storage_offset) const { + SlimTensor result = *this; + result.as_strided_(sizes, strides, storage_offset); + return result; +} + +inline SlimTensor SlimTensor::as_strided_( + standalone::c10::IntArrayRef sizes, + standalone::c10::IntArrayRef strides, + int64_t storage_offset) { + STANDALONE_CHECK( + sizes.size() == strides.size(), + "as_strided: number of sizes (", + sizes.size(), + ") must equal number of strides (", + strides.size(), + ")"); + for (size_t i = 0; i < sizes.size(); ++i) { + STANDALONE_CHECK( + sizes[i] >= 0, + "as_strided: size at dimension ", + i, + " is negative: ", + sizes[i]); + } + STANDALONE_CHECK( + storage_offset >= 0, + "as_strided: storage_offset must be non-negative, got: ", + storage_offset); + + this->set_sizes_and_strides(sizes, strides, storage_offset); + return *this; +} + +inline SlimTensor SlimTensor::permute(standalone::c10::IntArrayRef dims) const { + const size_t ndim = this->dim(); + STANDALONE_CHECK( + ndim == static_cast(dims.size()), + "permute: dims length must be equal to tensor.dim()") + + standalone::c10::ArrayRef old_sizes = this->sizes(); + standalone::c10::ArrayRef old_strides = this->strides(); + std::vector new_sizes = old_sizes.vec(); + std::vector new_strides = old_strides.vec(); + std::vector seen_dims(ndim, false); + + for (size_t i = 0; i < ndim; i++) { + int64_t d = standalone::c10::maybe_wrap_dim(dims[i], ndim); + STANDALONE_CHECK(!seen_dims[d], "permute: duplicate dims are not allowed"); + seen_dims[d] = true; + new_sizes[i] = old_sizes[d]; + new_strides[i] = old_strides[d]; + } + + SlimTensor result = *this; + result.as_strided_(new_sizes, new_strides, this->storage_offset()); + return result; +} + +inline SlimTensor SlimTensor::transpose() const { + STANDALONE_CHECK(dim() == 2, "transpose() can only be called on 2D tensors"); + return permute({1, 0}); +} + +inline SlimTensor SlimTensor::transpose(int64_t dim0, int64_t dim1) const { + const size_t ndim = this->dim(); + std::vector dims; + for (size_t i = 0; i < ndim; i++) { + dims.push_back(static_cast(i)); + } + + // Wrap dimensions and swap them + dim0 = standalone::c10::maybe_wrap_dim(dim0, ndim); + dim1 = standalone::c10::maybe_wrap_dim(dim1, ndim); + std::swap(dims[dim0], dims[dim1]); + + return permute(dims); +} + +inline SlimTensor SlimTensor::t() const { + return transpose(); +} + +inline SlimTensor SlimTensor::reshape( + standalone::c10::IntArrayRef proposed_shape) const { + std::vector final_shape_vec = + infer_size(proposed_shape, this->numel()); + + // `compute_stride` return the proper strides to use if this + // `reshape` can be just a view. 
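+  // When no such strides exist, the fallback below materializes a contiguous
+  // clone and only rewrites its size metadata.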
+ std::optional> new_strides_opt = + compute_stride(this->sizes(), this->strides(), final_shape_vec); + + // create a view if possible + if (new_strides_opt.has_value()) { + SlimTensor result = *this; + result.as_strided_( + final_shape_vec, new_strides_opt.value(), this->storage_offset()); + return result; + } + + // if a view is not possible, create a contiguous clone and reshape that + SlimTensor contiguous_clone = this->clone_contiguous(); + // after cloning, the tensor is already contiguous. We just need to update + // its metadata to reflect the new shape. This is effectively a view of + // the new contiguous clone + contiguous_clone.set_sizes_contiguous(final_shape_vec); + return contiguous_clone; +} + +inline SlimTensor SlimTensor::narrow(int64_t dim, int64_t start, int64_t length) + const { + STANDALONE_CHECK( + this->dim() > 0, "narrow() cannot be applied to a 0-dim tensor."); + dim = standalone::c10::maybe_wrap_dim(dim, static_cast(this->dim())); + start = standalone::c10::maybe_wrap_dim( + start, static_cast(this->size(dim))); + + STANDALONE_CHECK(length >= 0, "narrow(): length must be non-negative."); + int64_t end = start + length; + STANDALONE_CHECK( + end <= this->size(dim), + "Invalid range to narrow. range(", + start, + ", ", + start + length, + ") must be a subset of range(0, ", + this->size(dim), + ")."); + + SlimTensor result = *this; + int64_t new_storage_offset = + this->storage_offset() + start * this->stride(dim); + std::vector new_sizes = this->sizes().vec(); + new_sizes[dim] = length; + result.as_strided_(new_sizes, this->strides(), new_storage_offset); + return result; +} + +} // namespace standalone::slim diff --git a/backends/aoti/slim/core/Storage.h b/backends/aoti/slim/core/Storage.h new file mode 100644 index 00000000000..4230a0d2b0a --- /dev/null +++ b/backends/aoti/slim/core/Storage.h @@ -0,0 +1,307 @@ +#pragma once +#include +#include +#include +#include + +#ifdef USE_CUDA +#include +#include +#endif + +#include +#include +#include +#include +#include +#include + +namespace standalone::slim { +using DeleterFn = void (*)(void*); + +namespace detail { +inline void noop(void*) {} +} // namespace detail + +const standalone::c10::Device CPU_DEVICE = + standalone::c10::Device(standalone::c10::DeviceType::CPU, 0); + +const standalone::c10::Device DEFAULT_CUDA_DEVICE = + standalone::c10::Device(standalone::c10::DeviceType::CUDA, 0); + +// standalone::c10::Device traits template for device-specific operations +template +struct DeviceTraits; + +// CPU specialization +template <> +struct DeviceTraits { + static void* allocate( + size_t nbytes, + const standalone::c10::Device& device = CPU_DEVICE) { + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + return malloc(nbytes); + } + + static void free(void* ptr) { + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + std::free(ptr); + } + + static void memcpy( + void* dst, + const void* src, + size_t nbytes, + const standalone::c10::Device& dst_device, + const standalone::c10::Device& src_device) { + std::memcpy(dst, src, nbytes); + } +}; + +// CUDA specialization +#ifdef USE_CUDA +template <> +struct DeviceTraits { + static void* allocate(size_t nbytes, const standalone::c10::Device& device) { + standalone::slim::cuda::CUDAGuard guard(device); + void* data = nullptr; + STANDALONE_CUDA_CHECK(cudaMalloc(&data, nbytes)); + return data; + } + + static void free(void* ptr) { + STANDALONE_CUDA_CHECK_WARN(cudaFree(ptr)); + } + + static void memcpy( + void* dst, + const void* src, + size_t nbytes, + const standalone::c10::Device& 
dst_device, + const standalone::c10::Device& src_device) { + // Determine the direction + cudaMemcpyKind direction = cudaMemcpyDeviceToDevice; + standalone::c10::Device cuda_device = + dst_device; // Default to destination device + + if (src_device.is_cpu()) { + direction = cudaMemcpyHostToDevice; + } else if (dst_device.is_cpu()) { + direction = cudaMemcpyDeviceToHost; + cuda_device = src_device; // Use source CUDA device + } else { + STANDALONE_CHECK( + src_device.index() == dst_device.index(), + "CUDA memcpy failed across different device indices: ", + src_device.index(), + "!=", + dst_device.index()); + } + // Set up CUDA context for the appropriate device + standalone::slim::cuda::CUDAGuard guard(cuda_device); + STANDALONE_CUDA_CHECK(cudaMemcpy(dst, src, nbytes, direction)); + } +}; +#else +template <> +struct DeviceTraits { + static void* allocate(size_t nbytes, const standalone::c10::Device& device) { + STANDALONE_CHECK(false, "Build with USE_CUDA=1 to enable CUDA support"); + } + + static void free(void* ptr) { + STANDALONE_WARN("Build with USE_CUDA=1 to enable CUDA support"); + } + + static void memcpy( + void* dst, + const void* src, + size_t nbytes, + const standalone::c10::Device& dst_device, + const standalone::c10::Device& src_device) { + STANDALONE_CHECK(false, "Build with USE_CUDA=1 to enable CUDA support"); + } +}; +#endif + +// Storage can be either owning or non-owning. For AOTI-generated intermediate +// tensors, the storage is always owning. For constant tensors, the storage is +// non-owning. +class MaybeOwningStorage { + public: + MaybeOwningStorage(const standalone::c10::Device& device, size_t nbytes) + : device_(device), capacity_(nbytes), is_owning_(true) { + // Allocating memory here so owning_ has to be true. + if (device.is_cpu()) { + data_ = DeviceTraits::allocate( + nbytes, device); + deleter_ = DeviceTraits::free; + } else if (device.is_cuda()) { + data_ = DeviceTraits::allocate( + nbytes, device); + deleter_ = DeviceTraits::free; + } else { + STANDALONE_CHECK(false, "Unsupported device type"); + } + } + + MaybeOwningStorage( + const standalone::c10::Device& device, + void* data, + size_t nbytes) + : device_(device), data_(data), capacity_(nbytes), is_owning_(false) { + // data pointer is not owned by this object + } + + MaybeOwningStorage() = delete; + MaybeOwningStorage& operator=(const MaybeOwningStorage&) = delete; + MaybeOwningStorage(const MaybeOwningStorage&) = delete; + + // Move constructor + MaybeOwningStorage(MaybeOwningStorage&& other) noexcept + : device_(other.device_), + data_(other.data_), + capacity_(other.capacity_), + deleter_(other.deleter_), + is_owning_(other.is_owning_) { + // Leave the moved-from object in a safe state + other.data_ = nullptr; + other.capacity_ = 0; + other.deleter_ = detail::noop; + other.is_owning_ = false; + } + + // Move assignment operator + MaybeOwningStorage& operator=(MaybeOwningStorage&& other) noexcept { + if (this != &other) { + // Free current resources + free_data(); + + // Transfer ownership from other + device_ = other.device_; + data_ = other.data_; + capacity_ = other.capacity_; + deleter_ = other.deleter_; + is_owning_ = other.is_owning_; + + // Leave the moved-from object in a safe state + other.data_ = nullptr; + other.capacity_ = 0; + other.deleter_ = detail::noop; + other.is_owning_ = false; + } + return *this; + } + + ~MaybeOwningStorage() { + free_data(); + } + + void copy_( + void* dst_data_ptr, + void* src_data_ptr, + size_t nbytes, + const standalone::c10::Device& src_device) { + 
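+    // Copies `nbytes` from `src_data_ptr` (resident on `src_device`) into
+    // `dst_data_ptr`, which is assumed to live on this storage's device; the
+    // CUDA specialization of DeviceTraits::memcpy derives the cudaMemcpy
+    // direction from that device pair.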
STANDALONE_CHECK( + dst_data_ptr, "Storage clone failed: dst_data_ptr can not be nullptr") + STANDALONE_CHECK( + src_data_ptr, "Storage clone failed: src_data_ptr can not be nullptr") + if (dst_data_ptr == src_data_ptr) { + return; + } + + if (device_.is_cpu() && src_device.is_cpu()) { + // CPU to CPU copy + DeviceTraits::memcpy( + dst_data_ptr, src_data_ptr, nbytes, device_, src_device); + } else { + // At least one of the devices is CUDA + DeviceTraits::memcpy( + dst_data_ptr, src_data_ptr, nbytes, device_, src_device); + } + } + + MaybeOwningStorage clone(const standalone::c10::Device& device) const { + STANDALONE_CHECK( + data_, "Storage clone failed: source data can not be nullptr") + // Create a new owning storage with the specified device and same capacity + MaybeOwningStorage cloned_storage(device, capacity_); + + // Copy the data from the current storage to the new storage + if (device_.is_cpu() && device.is_cpu()) { + // CPU to CPU copy + DeviceTraits::memcpy( + cloned_storage.data_, data_, capacity_, device, device_); + } else { + // At least one of the devices is CUDA + DeviceTraits::memcpy( + cloned_storage.data_, data_, capacity_, device, device_); + } + + return cloned_storage; + } + + void* data() const { + // Always return nullptr for zero-sized storage + if (capacity_ == 0) { + return nullptr; + } + return data_; + } + + const standalone::c10::Device& device() const { + return device_; + } + + size_t nbytes() const { + return this->capacity_; + } + + void unsafe_set_to_non_owning() { + // This is only used when interacting with at::Tensor. When testing + // standalone AOTI from pytorch, we need to convert the output SlimTensor + // into at::Tensor, which means the storage ownership should be stolen by + // at::Tensor. When all the SlimTensors referencing the storage are + // destroyed, the storage should NOT be freed. + deleter_ = detail::noop; + is_owning_ = false; + } + + bool is_resizable() const { + return is_owning_; + } + + void free_data() { + if (data_ != nullptr) { + deleter_(data_); + } + } + + void set_data_ptr_noswap(void* new_data) { + data_ = new_data; + } + + void set_nbytes(size_t new_nbytes) { + capacity_ = new_nbytes; + } + + private: + standalone::c10::Device device_ = CPU_DEVICE; + void* data_ = nullptr; + size_t capacity_ = 0; + DeleterFn deleter_ = detail::noop; + bool is_owning_ = false; +}; + +using Storage = SharedPtr; + +inline Storage new_storage( + standalone::c10::IntArrayRef sizes, + standalone::c10::IntArrayRef strides, + standalone::c10::ScalarType dtype, + const standalone::c10::Device& device = CPU_DEVICE) { + size_t nbytes = compute_storage_nbytes( + sizes, strides, standalone::c10::elementSize(dtype), 0); + return Storage(new MaybeOwningStorage(device, nbytes)); +} +} // namespace standalone::slim diff --git a/backends/aoti/slim/cuda/Exception.h b/backends/aoti/slim/cuda/Exception.h new file mode 100644 index 00000000000..d777352c1d7 --- /dev/null +++ b/backends/aoti/slim/cuda/Exception.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#ifdef USE_CUDA + +#include +#include +#include + +#include +#include +#include + +#include +#include + +#define ET_CUDA_CHECK(EXPR) \ + do { \ + const cudaError_t __err = EXPR; \ + ET_CHECK_MSG(__err == cudaSuccess, "%s", cudaGetErrorString(__err)); \ + } while (0) + +#define ET_CUDA_CHECK_WARN(EXPR) \ + do { \ + const cudaError_t __err = EXPR; \ + if (ET_UNLIKELY(__err != cudaSuccess)) { \ + [[maybe_unused]] auto error_unused = cudaGetLastError(); \ + ET_LOG(Warning, "CUDA warning: %s", cudaGetErrorString(__err)); \ + } \ + } while (0) + +#endif // USE_CUDA diff --git a/backends/aoti/slim/cuda/Guard.h b/backends/aoti/slim/cuda/Guard.h new file mode 100644 index 00000000000..c9b2441b148 --- /dev/null +++ b/backends/aoti/slim/cuda/Guard.h @@ -0,0 +1,174 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace standalone::slim::cuda { + +// Thread-local stream management +namespace detail { +inline thread_local std:: + unordered_map + current_streams_; +} + +/// Set the current CUDA stream for the specified device +inline void setCurrentCUDAStream( + cudaStream_t stream, + standalone::c10::DeviceIndex device_index = -1) { + if (device_index == -1) { + // Get current device if not specified + int current_device; + STANDALONE_CUDA_CHECK(cudaGetDevice(¤t_device)); + device_index = current_device; + } + + detail::current_streams_[device_index] = stream; +} + +/// Get the current CUDA stream for the specified device +inline cudaStream_t getCurrentCUDAStream( + standalone::c10::DeviceIndex device_index = -1) { + if (device_index == -1) { + // Get current device if not specified + int current_device; + STANDALONE_CUDA_CHECK(cudaGetDevice(¤t_device)); + device_index = current_device; + } + + auto it = detail::current_streams_.find(device_index); + if (it != detail::current_streams_.end()) { + return it->second; + } + + // Create a new stream and set it as current + cudaStream_t stream; + STANDALONE_CUDA_CHECK(cudaStreamCreate(&stream)); + setCurrentCUDAStream(stream, device_index); + return stream; +} + +struct CUDAGuard { + /// No default constructor; see Note [Omitted default constructor from RAII] + explicit CUDAGuard() = delete; + + /// Set the current CUDA device to the passed device index. + explicit CUDAGuard(standalone::c10::DeviceIndex device_index) { + set_index(device_index); + } + + /// Sets the current CUDA device to the passed device. Errors if the passed + /// device is not a CUDA device. + explicit CUDAGuard(standalone::c10::Device device) { + STANDALONE_CHECK( + device.is_cuda(), + "Expected a CUDA device for CUDAGuard, but got ", + device); + set_index(device.index()); + } + + // Copy is not allowed + CUDAGuard(const CUDAGuard&) = delete; + CUDAGuard& operator=(const CUDAGuard&) = delete; + + // Move is not allowed (there is no uninitialized state) + CUDAGuard(CUDAGuard&& other) = delete; + CUDAGuard& operator=(CUDAGuard&& other) = delete; + + ~CUDAGuard() { + // Restore the original device if necessary + if (original_device_index_ != current_device_index_) { + STANDALONE_CUDA_CHECK_WARN(cudaSetDevice(original_device_index_)); + } + } + + /// Sets the CUDA device to the given device index. 
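+  /// The previously current device is recorded so that the destructor can
+  /// restore it. Illustrative RAII usage:
+  ///   {
+  ///     CUDAGuard guard(/*device_index=*/1);
+  ///     // allocations and kernel launches here target device 1
+  ///   }  // leaving the scope switches back to the original device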
+ void set_index(standalone::c10::DeviceIndex device_index) { + int orig_index = -1; + STANDALONE_CUDA_CHECK(cudaGetDevice(&orig_index)); + + original_device_index_ = orig_index; + current_device_index_ = device_index; + if (current_device_index_ != original_device_index_) { + STANDALONE_CUDA_CHECK(cudaSetDevice(current_device_index_)); + } + } + + private: + /// The guard for the current device. + standalone::c10::DeviceIndex original_device_index_; + standalone::c10::DeviceIndex current_device_index_; +}; + +struct CUDAStreamGuard { + /// No default constructor; see Note [Omitted default constructor from RAII] + explicit CUDAStreamGuard() = delete; + + /// Set the current CUDA stream to the passed stream on the specified device. + explicit CUDAStreamGuard( + cudaStream_t stream, + standalone::c10::DeviceIndex device_index) + : device_guard_(device_index) { + set_stream(stream, device_index); + } + + // Copy is not allowed + CUDAStreamGuard(const CUDAStreamGuard&) = delete; + CUDAStreamGuard& operator=(const CUDAStreamGuard&) = delete; + + // Move is not allowed (there is no uninitialized state) + CUDAStreamGuard(CUDAStreamGuard&& other) = delete; + CUDAStreamGuard& operator=(CUDAStreamGuard&& other) = delete; + + ~CUDAStreamGuard() { + // Restore the original stream for the device + setCurrentCUDAStream(original_stream_, device_index_); + // Device guard will automatically restore the original device + } + + /// Sets the CUDA stream to the given stream on the specified device. + void set_stream( + cudaStream_t stream, + standalone::c10::DeviceIndex device_index) { + // Store the original stream for this device + original_stream_ = getCurrentCUDAStream(device_index); + current_stream_ = stream; + device_index_ = device_index; + + // Set the new stream as current for this device + setCurrentCUDAStream(stream, device_index); + } + + /// Get the current guarded stream + cudaStream_t stream() const { + return current_stream_; + } + + /// Get the device index being guarded + standalone::c10::DeviceIndex device_index() const { + return device_index_; + } + + private: + /// The device guard that handles device switching + CUDAGuard device_guard_; + /// The original stream that was current before this guard + cudaStream_t original_stream_ = nullptr; + /// The current stream being guarded + cudaStream_t current_stream_ = nullptr; + /// The device index for this stream guard + standalone::c10::DeviceIndex device_index_; +}; + +} // namespace standalone::slim::cuda diff --git a/backends/aoti/slim/factory/Empty.h b/backends/aoti/slim/factory/Empty.h new file mode 100644 index 00000000000..bbd4996b84c --- /dev/null +++ b/backends/aoti/slim/factory/Empty.h @@ -0,0 +1,35 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace standalone::slim { +// The returned SlimTensor owns the underlying storage +inline SlimTensor empty_strided( + standalone::c10::IntArrayRef sizes, + standalone::c10::IntArrayRef strides, + standalone::c10::ScalarType dtype, + const standalone::c10::Device& device = CPU_DEVICE) { + Storage storage = new_storage(sizes, strides, dtype, device); + return SlimTensor(std::move(storage), sizes, strides, dtype, 0); +} + +inline SlimTensor empty( + standalone::c10::IntArrayRef sizes, + standalone::c10::ScalarType dtype, + const standalone::c10::Device& device = CPU_DEVICE) { + std::vector contig_strides = + standalone::slim::compute_contiguous_strides(sizes); + Storage storage = new_storage(sizes, contig_strides, dtype, device); + return 
SlimTensor(std::move(storage), sizes, contig_strides, dtype, 0); +} + +inline SlimTensor empty_like(const SlimTensor& other) { + return empty_strided( + other.sizes(), other.strides(), other.dtype(), other.device()); +} +} // namespace standalone::slim diff --git a/backends/aoti/slim/factory/Factory.h b/backends/aoti/slim/factory/Factory.h new file mode 100644 index 00000000000..5e172bc9f6a --- /dev/null +++ b/backends/aoti/slim/factory/Factory.h @@ -0,0 +1,32 @@ +#pragma once + +#include + +namespace standalone::slim { +inline SlimTensor zeros( + standalone::c10::IntArrayRef sizes, + standalone::c10::ScalarType dtype, + const standalone::c10::Device& device = CPU_DEVICE) { + SlimTensor tensor = empty(sizes, dtype, device); + tensor.fill_(standalone::c10::Scalar(0)); + return tensor; +} + +inline SlimTensor zeros_like(const SlimTensor& other) { + return zeros(other.sizes(), other.dtype(), other.device()); +} + +inline SlimTensor ones( + standalone::c10::IntArrayRef sizes, + standalone::c10::ScalarType dtype, + const standalone::c10::Device& device = CPU_DEVICE) { + SlimTensor tensor = empty(sizes, dtype, device); + tensor.fill_(standalone::c10::Scalar(1)); + return tensor; +} + +inline SlimTensor ones_like(const SlimTensor& other) { + return ones(other.sizes(), other.dtype(), other.device()); +} + +} // namespace standalone::slim diff --git a/backends/aoti/slim/factory/FromBlob.h b/backends/aoti/slim/factory/FromBlob.h new file mode 100644 index 00000000000..d1877f7f31d --- /dev/null +++ b/backends/aoti/slim/factory/FromBlob.h @@ -0,0 +1,36 @@ +#pragma once + +#include + +namespace standalone::slim { + +// The returned SlimTensor does not own the underlying storage +inline SlimTensor from_blob( + void* data, + standalone::c10::IntArrayRef sizes, + standalone::c10::IntArrayRef strides, + standalone::c10::ScalarType dtype, + const standalone::c10::Device& device = CPU_DEVICE, + int64_t storage_offset = 0) { + STANDALONE_CHECK(data != nullptr, "data pointer can not be nullptr"); + + Storage storage(new MaybeOwningStorage( + device, + data, + compute_storage_nbytes( + sizes, strides, elementSize(dtype), storage_offset))); + return SlimTensor(std::move(storage), sizes, strides, dtype, storage_offset); +} + +inline SlimTensor from_blob( + void* data, + standalone::c10::IntArrayRef sizes, + standalone::c10::ScalarType dtype, + const standalone::c10::Device& device = CPU_DEVICE, + int64_t storage_offset = 0) { + std::vector contig_strides = + standalone::slim::compute_contiguous_strides(sizes); + return from_blob(data, sizes, contig_strides, dtype, device, storage_offset); +} + +} // namespace standalone::slim diff --git a/backends/aoti/slim/factory/FromScalar.h b/backends/aoti/slim/factory/FromScalar.h new file mode 100644 index 00000000000..223f734d940 --- /dev/null +++ b/backends/aoti/slim/factory/FromScalar.h @@ -0,0 +1,15 @@ +#pragma once + +#include + +namespace standalone::slim { + +inline SlimTensor scalar_to_tensor( + const standalone::c10::Scalar& s, + const standalone::c10::Device& device = CPU_DEVICE) { + SlimTensor result = empty_strided({}, {}, s.type(), device); + result.fill_(s); + return result; +} + +} // namespace standalone::slim diff --git a/backends/aoti/slim/factory/Pad.h b/backends/aoti/slim/factory/Pad.h new file mode 100644 index 00000000000..4d7fef731bd --- /dev/null +++ b/backends/aoti/slim/factory/Pad.h @@ -0,0 +1,106 @@ +#pragma once + +#include + +namespace standalone::slim { + +inline SlimTensor constant_pad_nd( + const SlimTensor& self, + 
standalone::c10::IntArrayRef pad, + const standalone::c10::Scalar& value) { + STANDALONE_CHECK(pad.size() % 2 == 0, "Length of pad must be even"); + + standalone::c10::IntArrayRef input_sizes = self.sizes(); + int64_t l_inp = self.dim(); + int64_t l_pad = static_cast(pad.size()) / 2; + int64_t l_diff = l_inp - l_pad; + + STANDALONE_CHECK( + l_pad <= l_inp, + "Length of pad should be no more than twice the input's dimension."); + + bool all_pads_non_positive = true; + SlimTensor c_input = self; + for (int64_t i = l_diff; i < l_inp; i++) { + int64_t pad_idx = 2 * (l_inp - i - 1); + + if (pad[pad_idx] < 0) { + c_input = + c_input.narrow(i, -pad[pad_idx], c_input.size(i) + pad[pad_idx]); + } else if (pad[pad_idx] != 0) { + all_pads_non_positive = false; + } + if (pad[pad_idx + 1] < 0) { + c_input = c_input.narrow(i, 0, c_input.size(i) + pad[pad_idx + 1]); + } else if (pad[pad_idx + 1] != 0) { + all_pads_non_positive = false; + } + } + + // if none of the pads are positive we can optimize and just return the result + // of calling .narrow() on the input + if (all_pads_non_positive) { + return c_input.clone_contiguous(); + } + + // calculate the new shape for the output tensor + std::vector new_shape; + new_shape.reserve(l_diff); + for (int64_t i = 0; i < l_diff; i++) { + new_shape.emplace_back(input_sizes[i]); + } + + for (const auto i : standalone::c10::irange((size_t)l_pad)) { + auto pad_idx = pad.size() - ((i + 1) * 2); + auto new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1]; + STANDALONE_CHECK( + new_dim > 0, + "The input size ", + input_sizes[l_diff + i], + ", plus negative padding ", + pad[pad_idx], + " and ", + pad[pad_idx + 1], + " resulted in a negative output size, " + "which is invalid. Check dimension ", + l_diff + i, + " of your input."); + new_shape.emplace_back(new_dim); + } + + SlimTensor output = empty(new_shape, self.dtype(), self.device()); + output.fill_(value); + + // create a view into the center of the output tensor + SlimTensor c_output = output; + for (const auto i : standalone::c10::irange(l_diff, l_inp)) { + auto pad_idx = 2 * (l_inp - i - 1); + if (pad[pad_idx] > 0) { + c_output = + c_output.narrow(i, pad[pad_idx], c_output.size(i) - pad[pad_idx]); + } + if (pad[pad_idx + 1] > 0) { + c_output = c_output.narrow(i, 0, c_output.size(i) - pad[pad_idx + 1]); + } + } + // copy the input data into the center view + c_output.copy_(c_input); + return output; +} + +inline SlimTensor pad( + const SlimTensor& self, + standalone::c10::IntArrayRef pad, + std::string_view mode, + std::optional value) { + if (mode == "constant") { + return constant_pad_nd(self, pad, value.value_or(0.0)); + } + STANDALONE_CHECK( + false, + "Unsupported padding mode: ", + mode, + ". Only constant mode is available."); +} + +} // namespace standalone::slim diff --git a/backends/aoti/slim/targets.bzl b/backends/aoti/slim/targets.bzl new file mode 100644 index 00000000000..62db9452984 --- /dev/null +++ b/backends/aoti/slim/targets.bzl @@ -0,0 +1,81 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + """Define SlimTensor library targets. + + SlimTensor is a lightweight tensor implementation for AOTI (Ahead-of-Time Inference) + that provides a minimal, efficient tensor abstraction for ExecuTorch CUDA backend. + + This is a direct port from torchnative/standalone/slim with minimal modifications. 
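+
+    For example, CPU-only consumers typically depend on
+    //executorch/backends/aoti/slim:slim_tensor_cpu, while CUDA-enabled
+    consumers depend on //executorch/backends/aoti/slim:slim_tensor (as the
+    test targets in tests/targets.bzl do).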
+ """ + + # Utility library (SharedPtr, SizeUtil) + runtime.cxx_library( + name = "util", + exported_headers = glob(["util/*.h"]), + visibility = ["@EXECUTORCH_CLIENTS"], + exported_deps = [ + "//executorch/backends/aoti/slim/c10:c10", + ], + ) + + # Core SlimTensor library (CPU only) + runtime.cxx_library( + name = "core", + exported_headers = glob(["core/*.h"]), + visibility = ["@EXECUTORCH_CLIENTS"], + exported_deps = [ + ":util", + "//executorch/backends/aoti/slim/c10:c10", + ], + ) + + # Factory functions library + runtime.cxx_library( + name = "factory", + exported_headers = glob(["factory/*.h"]), + visibility = ["@EXECUTORCH_CLIENTS"], + exported_deps = [ + ":core", + "//executorch/backends/aoti/slim/c10:c10", + ], + ) + + # CUDA support library + runtime.cxx_library( + name = "cuda", + exported_headers = glob(["cuda/*.h"]), + visibility = ["@EXECUTORCH_CLIENTS"], + exported_preprocessor_flags = ["-DUSE_CUDA"], + exported_deps = [ + ":core", + "//executorch/backends/aoti/slim/c10:c10", + "//executorch/backends/aoti/slim/c10:c10_cuda", + ], + external_deps = [ + ("cuda", None, "cuda-lazy"), + ], + ) + + # CPU-only SlimTensor library (no CUDA dependencies) + runtime.cxx_library( + name = "slim_tensor_cpu", + visibility = ["@EXECUTORCH_CLIENTS"], + exported_deps = [ + ":core", + ":factory", + ":util", + ], + ) + + # Full SlimTensor library (with CUDA support) + runtime.cxx_library( + name = "slim_tensor", + visibility = ["@EXECUTORCH_CLIENTS"], + exported_deps = [ + ":core", + ":factory", + ":cuda", + ":util", + ], + ) diff --git a/backends/aoti/slim/tests/TARGETS b/backends/aoti/slim/tests/TARGETS new file mode 100644 index 00000000000..f91c46c0f20 --- /dev/null +++ b/backends/aoti/slim/tests/TARGETS @@ -0,0 +1,5 @@ +load("targets.bzl", "define_common_targets") + +oncall("executorch") + +define_common_targets() diff --git a/backends/aoti/slim/tests/targets.bzl b/backends/aoti/slim/tests/targets.bzl new file mode 100644 index 00000000000..0f0eb843c7d --- /dev/null +++ b/backends/aoti/slim/tests/targets.bzl @@ -0,0 +1,31 @@ +load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest") + +def slim_tensor_cpp_unittest(name, extra_deps = []): + cpp_unittest( + name = "test_" + name, + srcs = [ + "test_" + name + ".cpp", + ], + deps = [ + "//executorch/backends/aoti/slim:slim_tensor_cpu", + ] + extra_deps, + ) + +def slim_tensor_cuda_cpp_unittest(name): + cpp_unittest( + name = "test_" + name, + srcs = [ + "test_" + name + ".cpp", + ], + deps = [ + "//executorch/backends/aoti/slim:slim_tensor", + ], + external_deps = [ + ("cuda", None, "cuda-lazy"), + ], + ) + +def define_common_targets(): + """Define test targets for SlimTensor library.""" + slim_tensor_cpp_unittest("slim_tensor_basic") + slim_tensor_cuda_cpp_unittest("slim_tensor_cuda") diff --git a/backends/aoti/slim/tests/test_slim_tensor_basic.cpp b/backends/aoti/slim/tests/test_slim_tensor_basic.cpp new file mode 100644 index 00000000000..37b6ccb240d --- /dev/null +++ b/backends/aoti/slim/tests/test_slim_tensor_basic.cpp @@ -0,0 +1,170 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include +#include +#include +#include + +namespace standalone::slim { +namespace { + +TEST(SlimTensorBasicTest, EmptyTensorCreation) { + auto tensor = + empty({2, 3, 4}, standalone::c10::ScalarType::Float, CPU_DEVICE); + EXPECT_EQ(tensor.dim(), 3); + EXPECT_EQ(tensor.size(0), 2); + EXPECT_EQ(tensor.size(1), 3); + EXPECT_EQ(tensor.size(2), 4); + EXPECT_EQ(tensor.numel(), 24); + EXPECT_EQ(tensor.dtype(), standalone::c10::ScalarType::Float); + EXPECT_TRUE(tensor.is_contiguous()); +} + +TEST(SlimTensorBasicTest, EmptyTensorContiguousStrides) { + auto tensor = + empty({2, 3, 4}, standalone::c10::ScalarType::Float, CPU_DEVICE); + EXPECT_EQ(tensor.stride(0), 12); + EXPECT_EQ(tensor.stride(1), 4); + EXPECT_EQ(tensor.stride(2), 1); +} + +TEST(SlimTensorBasicTest, ZerosTensorCreation) { + auto tensor = zeros({3, 3}, standalone::c10::ScalarType::Float, CPU_DEVICE); + EXPECT_EQ(tensor.numel(), 9); + float* data = static_cast(tensor.data_ptr()); + for (int i = 0; i < 9; ++i) { + EXPECT_EQ(data[i], 0.0f); + } +} + +TEST(SlimTensorBasicTest, OnesTensorCreation) { + auto tensor = ones({2, 2}, standalone::c10::ScalarType::Float, CPU_DEVICE); + EXPECT_EQ(tensor.numel(), 4); + float* data = static_cast(tensor.data_ptr()); + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(data[i], 1.0f); + } +} + +TEST(SlimTensorBasicTest, FillTensor) { + auto tensor = empty({2, 3}, standalone::c10::ScalarType::Float, CPU_DEVICE); + tensor.fill_(5.0f); + float* data = static_cast(tensor.data_ptr()); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(data[i], 5.0f); + } +} + +TEST(SlimTensorBasicTest, FromBlobNonOwning) { + std::vector data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + auto tensor = from_blob( + data.data(), {2, 3}, standalone::c10::ScalarType::Float, CPU_DEVICE); + EXPECT_EQ(tensor.dim(), 2); + EXPECT_EQ(tensor.size(0), 2); + EXPECT_EQ(tensor.size(1), 3); + EXPECT_EQ(tensor.numel(), 6); + EXPECT_EQ(tensor.data_ptr(), data.data()); +} + +TEST(SlimTensorBasicTest, Clone) { + auto tensor = empty({2, 3}, standalone::c10::ScalarType::Float, CPU_DEVICE); + tensor.fill_(3.14f); + + auto cloned = tensor.clone(); + EXPECT_NE(cloned.data_ptr(), tensor.data_ptr()); + EXPECT_EQ(cloned.sizes(), tensor.sizes()); + EXPECT_EQ(cloned.strides(), tensor.strides()); + + float* cloned_data = static_cast(cloned.data_ptr()); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(cloned_data[i], 3.14f); + } +} + +TEST(SlimTensorBasicTest, CopyFrom) { + auto src = empty({2, 3}, standalone::c10::ScalarType::Float, CPU_DEVICE); + src.fill_(2.5f); + + auto dst = empty({2, 3}, standalone::c10::ScalarType::Float, CPU_DEVICE); + dst.copy_(src); + + float* dst_data = static_cast(dst.data_ptr()); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(dst_data[i], 2.5f); + } +} + +TEST(SlimTensorBasicTest, Reshape) { + auto tensor = empty({2, 6}, standalone::c10::ScalarType::Float, CPU_DEVICE); + tensor.fill_(1.0f); + + auto reshaped = tensor.reshape({3, 4}); + EXPECT_EQ(reshaped.dim(), 2); + EXPECT_EQ(reshaped.size(0), 3); + EXPECT_EQ(reshaped.size(1), 4); + EXPECT_EQ(reshaped.numel(), 12); +} + +TEST(SlimTensorBasicTest, Transpose) { + auto tensor = empty({2, 3}, standalone::c10::ScalarType::Float, CPU_DEVICE); + auto transposed = tensor.transpose(0, 1); + EXPECT_EQ(transposed.size(0), 3); + EXPECT_EQ(transposed.size(1), 2); +} + +TEST(SlimTensorBasicTest, Permute) { + auto tensor = + empty({2, 3, 4}, standalone::c10::ScalarType::Float, CPU_DEVICE); + auto permuted = tensor.permute({2, 0, 1}); + EXPECT_EQ(permuted.size(0), 4); + 
EXPECT_EQ(permuted.size(1), 2); + EXPECT_EQ(permuted.size(2), 3); +} + +TEST(SlimTensorBasicTest, Narrow) { + auto tensor = empty({10}, standalone::c10::ScalarType::Float, CPU_DEVICE); + for (int i = 0; i < 10; ++i) { + static_cast(tensor.data_ptr())[i] = static_cast(i); + } + + auto narrowed = tensor.narrow(0, 2, 5); + EXPECT_EQ(narrowed.dim(), 1); + EXPECT_EQ(narrowed.size(0), 5); + + float* narrowed_data = static_cast(narrowed.data_ptr()); + for (int i = 0; i < 5; ++i) { + EXPECT_EQ(narrowed_data[i], static_cast(i + 2)); + } +} + +TEST(SlimTensorBasicTest, EmptyLike) { + auto tensor = + empty({2, 3, 4}, standalone::c10::ScalarType::Float, CPU_DEVICE); + auto empty_like_tensor = empty_like(tensor); + EXPECT_EQ(empty_like_tensor.sizes(), tensor.sizes()); + EXPECT_EQ(empty_like_tensor.dtype(), tensor.dtype()); + EXPECT_EQ(empty_like_tensor.device(), tensor.device()); +} + +TEST(SlimTensorBasicTest, ZerosLike) { + auto tensor = empty({2, 3}, standalone::c10::ScalarType::Float, CPU_DEVICE); + auto zeros_tensor = zeros_like(tensor); + EXPECT_EQ(zeros_tensor.sizes(), tensor.sizes()); + + float* data = static_cast(zeros_tensor.data_ptr()); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(data[i], 0.0f); + } +} + +} // namespace +} // namespace standalone::slim diff --git a/backends/aoti/slim/tests/test_slim_tensor_cuda.cpp b/backends/aoti/slim/tests/test_slim_tensor_cuda.cpp new file mode 100644 index 00000000000..571d4f99893 --- /dev/null +++ b/backends/aoti/slim/tests/test_slim_tensor_cuda.cpp @@ -0,0 +1,212 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include +#include +#include + +namespace standalone::slim { +namespace { + +class SlimTensorCUDATest : public ::testing::Test { + protected: + void SetUp() override { + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA device not available"; + } + } +}; + +TEST_F(SlimTensorCUDATest, EmptyCUDATensorCreation) { + auto tensor = + empty({2, 3, 4}, standalone::c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + EXPECT_EQ(tensor.dim(), 3); + EXPECT_EQ(tensor.size(0), 2); + EXPECT_EQ(tensor.size(1), 3); + EXPECT_EQ(tensor.size(2), 4); + EXPECT_EQ(tensor.numel(), 24); + EXPECT_EQ(tensor.device().type(), standalone::c10::DeviceType::CUDA); + EXPECT_TRUE(tensor.is_contiguous()); +} + +TEST_F(SlimTensorCUDATest, ZerosCUDATensor) { + auto tensor = + zeros({3, 3}, standalone::c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + EXPECT_EQ(tensor.numel(), 9); + EXPECT_EQ(tensor.device().type(), standalone::c10::DeviceType::CUDA); + + std::vector host_data(9); + cudaMemcpy( + host_data.data(), + tensor.data_ptr(), + 9 * sizeof(float), + cudaMemcpyDeviceToHost); + + for (int i = 0; i < 9; ++i) { + EXPECT_EQ(host_data[i], 0.0f); + } +} + +TEST_F(SlimTensorCUDATest, OnesCUDATensor) { + auto tensor = + ones({2, 2}, standalone::c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + EXPECT_EQ(tensor.numel(), 4); + + std::vector host_data(4); + cudaMemcpy( + host_data.data(), + tensor.data_ptr(), + 4 * sizeof(float), + cudaMemcpyDeviceToHost); + + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(host_data[i], 1.0f); + } +} + +TEST_F(SlimTensorCUDATest, FillCUDATensor) { + auto tensor = + empty({2, 3}, standalone::c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + 
tensor.fill_(5.0f); + + std::vector host_data(6); + cudaMemcpy( + host_data.data(), + tensor.data_ptr(), + 6 * sizeof(float), + cudaMemcpyDeviceToHost); + + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(host_data[i], 5.0f); + } +} + +TEST_F(SlimTensorCUDATest, CloneCUDATensor) { + auto tensor = + empty({2, 3}, standalone::c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + tensor.fill_(3.14f); + + auto cloned = tensor.clone(); + EXPECT_NE(cloned.data_ptr(), tensor.data_ptr()); + EXPECT_EQ(cloned.sizes(), tensor.sizes()); + EXPECT_EQ(cloned.device(), tensor.device()); + + std::vector host_data(6); + cudaMemcpy( + host_data.data(), + cloned.data_ptr(), + 6 * sizeof(float), + cudaMemcpyDeviceToHost); + + for (int i = 0; i < 6; ++i) { + EXPECT_FLOAT_EQ(host_data[i], 3.14f); + } +} + +TEST_F(SlimTensorCUDATest, CopyCUDAToCUDA) { + auto src = + empty({2, 3}, standalone::c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + src.fill_(2.5f); + + auto dst = + empty({2, 3}, standalone::c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + dst.copy_(src); + + std::vector host_data(6); + cudaMemcpy( + host_data.data(), + dst.data_ptr(), + 6 * sizeof(float), + cudaMemcpyDeviceToHost); + + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(host_data[i], 2.5f); + } +} + +TEST_F(SlimTensorCUDATest, CopyCPUToCUDA) { + auto cpu_tensor = + empty({2, 3}, standalone::c10::ScalarType::Float, CPU_DEVICE); + cpu_tensor.fill_(1.5f); + + auto cuda_tensor = + empty({2, 3}, standalone::c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + cuda_tensor.copy_(cpu_tensor); + + std::vector host_data(6); + cudaMemcpy( + host_data.data(), + cuda_tensor.data_ptr(), + 6 * sizeof(float), + cudaMemcpyDeviceToHost); + + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(host_data[i], 1.5f); + } +} + +TEST_F(SlimTensorCUDATest, CopyCUDAToCPU) { + auto cuda_tensor = + empty({2, 3}, standalone::c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + cuda_tensor.fill_(4.5f); + + auto cpu_tensor = + empty({2, 3}, standalone::c10::ScalarType::Float, CPU_DEVICE); + cpu_tensor.copy_(cuda_tensor); + + float* data = static_cast(cpu_tensor.data_ptr()); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(data[i], 4.5f); + } +} + +TEST_F(SlimTensorCUDATest, CUDAGuard) { + cuda::CUDAGuard guard(0); + auto tensor = + empty({2, 3}, standalone::c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + EXPECT_EQ(tensor.device().type(), standalone::c10::DeviceType::CUDA); +} + +TEST_F(SlimTensorCUDATest, ReshapeCUDATensor) { + auto tensor = + empty({2, 6}, standalone::c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + auto reshaped = tensor.reshape({3, 4}); + EXPECT_EQ(reshaped.dim(), 2); + EXPECT_EQ(reshaped.size(0), 3); + EXPECT_EQ(reshaped.size(1), 4); + EXPECT_EQ(reshaped.device(), tensor.device()); +} + +TEST_F(SlimTensorCUDATest, TransposeCUDATensor) { + auto tensor = + empty({2, 3}, standalone::c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + auto transposed = tensor.transpose(0, 1); + EXPECT_EQ(transposed.size(0), 3); + EXPECT_EQ(transposed.size(1), 2); + EXPECT_EQ(transposed.device(), tensor.device()); +} + +TEST_F(SlimTensorCUDATest, PermuteCUDATensor) { + auto tensor = + empty({2, 3, 4}, standalone::c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + auto permuted = tensor.permute({2, 0, 1}); + EXPECT_EQ(permuted.size(0), 4); + EXPECT_EQ(permuted.size(1), 2); + EXPECT_EQ(permuted.size(2), 3); + EXPECT_EQ(permuted.device(), tensor.device()); +} + +} // namespace +} // namespace standalone::slim diff --git a/backends/aoti/slim/tests/test_type_convert.cpp b/backends/aoti/slim/tests/test_type_convert.cpp new file 
mode 100644 index 00000000000..a93c7d27d70 --- /dev/null +++ b/backends/aoti/slim/tests/test_type_convert.cpp @@ -0,0 +1,83 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +namespace executorch::backends::aoti::slim { +namespace { + +TEST(TypeConvertTest, ToInt32Vec) { + std::vector int64_vec = {1, 2, 3, 4, 5}; + auto int32_vec = to_int32_vec(int64_vec); + + EXPECT_EQ(int32_vec.size(), 5); + EXPECT_EQ(int32_vec[0], 1); + EXPECT_EQ(int32_vec[1], 2); + EXPECT_EQ(int32_vec[2], 3); + EXPECT_EQ(int32_vec[3], 4); + EXPECT_EQ(int32_vec[4], 5); +} + +TEST(TypeConvertTest, ToInt64Vec) { + std::vector int32_vec = {10, 20, 30}; + auto int64_vec = to_int64_vec(int32_vec); + + EXPECT_EQ(int64_vec.size(), 3); + EXPECT_EQ(int64_vec[0], 10); + EXPECT_EQ(int64_vec[1], 20); + EXPECT_EQ(int64_vec[2], 30); +} + +TEST(TypeConvertTest, ToInt32VecEmpty) { + std::vector empty_vec; + auto result = to_int32_vec(empty_vec); + EXPECT_TRUE(result.empty()); +} + +TEST(TypeConvertTest, ToInt64VecEmpty) { + std::vector empty_vec; + auto result = to_int64_vec(empty_vec); + EXPECT_TRUE(result.empty()); +} + +TEST(TypeConvertTest, SafeNarrowInt64ToInt32) { + int64_t value = 42; + int32_t result = safe_narrow(value); + EXPECT_EQ(result, 42); +} + +TEST(TypeConvertTest, SafeNarrowInt32ToInt16) { + int32_t value = 1000; + int16_t result = safe_narrow(value); + EXPECT_EQ(result, 1000); +} + +TEST(TypeConvertTest, ToInt32VecLargeValues) { + std::vector int64_vec = {1000000, 2000000, 3000000}; + auto int32_vec = to_int32_vec(int64_vec); + + EXPECT_EQ(int32_vec.size(), 3); + EXPECT_EQ(int32_vec[0], 1000000); + EXPECT_EQ(int32_vec[1], 2000000); + EXPECT_EQ(int32_vec[2], 3000000); +} + +TEST(TypeConvertTest, ToInt64VecFromUint32) { + std::vector uint32_vec = {100, 200, 300}; + auto int64_vec = to_int64_vec(uint32_vec); + + EXPECT_EQ(int64_vec.size(), 3); + EXPECT_EQ(int64_vec[0], 100); + EXPECT_EQ(int64_vec[1], 200); + EXPECT_EQ(int64_vec[2], 300); +} + +} // namespace +} // namespace executorch::backends::aoti::slim diff --git a/backends/aoti/slim/util/SharedPtr.h b/backends/aoti/slim/util/SharedPtr.h new file mode 100644 index 00000000000..9ad565d9ab9 --- /dev/null +++ b/backends/aoti/slim/util/SharedPtr.h @@ -0,0 +1,222 @@ +#pragma once + +#include +#include +#include +#include + +#include + +namespace standalone::slim { + +/** + * NonAtomicSharedPtr - A lightweight, non-thread-safe shared pointer + * implementation + * + * This class provides shared ownership semantics similar to std::shared_ptr but + * without atomic operations, making it faster in single-threaded contexts where + * thread safety is not required. + * + * Primary Use Cases: + * 1. Intermediate SlimTensor Storage Management: + * - Manages temporary tensors created during model execution + * - These tensors are confined to single-threaded execution contexts + * - Avoids the overhead of atomic reference counting in std::shared_ptr + * + * 2. 
Input/Output Tensor References: + * - Provides reference counting for input/output tensors + * - Tensor lifetimes are externally managed (not by AOTI-generated code) + * - Uses dummy deleters to prevent premature deallocation + * - Reference counting still occurs but actual cleanup is deferred + * + * Performance Benefits: + * - Non-atomic reference counting reduces CPU overhead + * - Smaller memory footprint compared to std::shared_ptr + * - Optimized for single-threaded tensor operations + * + * Thread Safety: NOT THREAD-SAFE + * - Must only be used in single-threaded contexts + * - Concurrent access will result in undefined behavior + * - Define the USE_MULTI_THREAD macro to use std::shared_ptr instead when + * thread safety is required + */ +template +class NonAtomicSharedPtr { + private: + struct ControlBlock { + int count = 1; + T* ptr; + using Deleter = void (*)(T*); + Deleter deleter; + + ControlBlock(T* p, Deleter d) : ptr(p), deleter(d) {} + ControlBlock(const ControlBlock&) = delete; + ControlBlock& operator=(const ControlBlock&) = delete; + ControlBlock(ControlBlock&&) = delete; + ControlBlock& operator=(ControlBlock&&) = delete; + + ~ControlBlock() { + if (ptr) { + deleter(ptr); + } + } + }; + + ControlBlock* cb_; + + static void default_deleter(T* p) { + delete p; + } + + void cleanup() { + if (cb_ && --cb_->count == 0) { + delete cb_; + } + cb_ = nullptr; + } + + public: + // Default constructor + NonAtomicSharedPtr() noexcept : cb_(nullptr) {} + + // Constructor from raw pointer + explicit NonAtomicSharedPtr( + T* p, + typename ControlBlock::Deleter d = default_deleter) + : cb_(p ? new ControlBlock(p, d) : nullptr) {} + + // Copy constructor + NonAtomicSharedPtr(const NonAtomicSharedPtr& other) noexcept + : cb_(other.cb_) { + if (cb_) { + ++cb_->count; + } + } + + // Move constructor + NonAtomicSharedPtr(NonAtomicSharedPtr&& other) noexcept : cb_(other.cb_) { + other.cb_ = nullptr; + } + + // Destructor + ~NonAtomicSharedPtr() { + cleanup(); + } + + // Copy assignment + NonAtomicSharedPtr& operator=(const NonAtomicSharedPtr& other) noexcept { + if (this != &other) { + cleanup(); + cb_ = other.cb_; + if (cb_) { + ++cb_->count; + } + } + return *this; + } + + // Move assignment + NonAtomicSharedPtr& operator=(NonAtomicSharedPtr&& other) noexcept { + if (this != &other) { + cleanup(); + cb_ = other.cb_; + other.cb_ = nullptr; + } + return *this; + } + + // Modifiers + void reset( + T* p = nullptr, + typename ControlBlock::Deleter d = default_deleter) { + *this = NonAtomicSharedPtr(p, d); + } + + void swap(NonAtomicSharedPtr& other) noexcept { + std::swap(cb_, other.cb_); + } + + // Observers + T* get() const noexcept { + return cb_ ? cb_->ptr : nullptr; + } + T& operator*() const { + STANDALONE_CHECK(cb_, "Dereferencing null NonAtomicSharedPtr"); + return *cb_->ptr; + } + T* operator->() const { + STANDALONE_CHECK(cb_, "Accessing member of null NonAtomicSharedPtr"); + return cb_->ptr; + } + long use_count() const noexcept { + return cb_ ? 
cb_->count : 0; + } + explicit operator bool() const noexcept { + return cb_ != nullptr; + } + + // Friend swap for ADL + friend void swap(NonAtomicSharedPtr& a, NonAtomicSharedPtr& b) noexcept { + a.swap(b); + } + + // Comparison operators + friend bool operator==( + const NonAtomicSharedPtr& lhs, + const NonAtomicSharedPtr& rhs) noexcept { + return lhs.get() == rhs.get(); + } + + friend bool operator!=( + const NonAtomicSharedPtr& lhs, + const NonAtomicSharedPtr& rhs) noexcept { + return !(lhs == rhs); + } + + friend bool operator==( + const NonAtomicSharedPtr& lhs, + std::nullptr_t) noexcept { + return lhs.get() == nullptr; + } + + friend bool operator!=( + const NonAtomicSharedPtr& lhs, + std::nullptr_t) noexcept { + return lhs.get() != nullptr; + } + + friend bool operator==( + std::nullptr_t, + const NonAtomicSharedPtr& rhs) noexcept { + return rhs.get() == nullptr; + } + + friend bool operator!=( + std::nullptr_t, + const NonAtomicSharedPtr& rhs) noexcept { + return rhs.get() != nullptr; + } +}; + +#ifdef USE_MULTI_THREAD +template +using SharedPtr = ::std::shared_ptr; + +// make_shared for std::shared_ptr +template +std::shared_ptr make_shared(Args&&... args) { + return std::make_shared(std::forward(args)...); +} + +#else +template +using SharedPtr = ::standalone::slim::NonAtomicSharedPtr; + +// make_shared for NonAtomicSharedPtr +template +NonAtomicSharedPtr make_shared(Args&&... args) { + return NonAtomicSharedPtr(new T(std::forward(args)...)); +} + +#endif // USE_MULTI_THREAD +} // namespace standalone::slim diff --git a/backends/aoti/slim/util/SizeUtil.h b/backends/aoti/slim/util/SizeUtil.h new file mode 100644 index 00000000000..d22416cd176 --- /dev/null +++ b/backends/aoti/slim/util/SizeUtil.h @@ -0,0 +1,283 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace standalone::slim { +#ifndef STANDALONE_MOBILE +inline constexpr uint64_t storage_max() { + // int64_t and size_t are used somewhat inconsistently throughout ATen. + // To be safe, storage size calculations must fit in both types. + constexpr auto int64_max = + static_cast(std::numeric_limits::max()); + constexpr auto size_max = + static_cast(std::numeric_limits::max()); + return std::min(int64_max, size_max); +} + +/** + * Compute the number of elements based on the sizes of a + * tensor. Catches integer overflow that may occur when a tensor + * using a sparse layout has multiple dimensions with large sizes. 
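+ * For example, sizes such as {1 << 31, 1 << 31, 1 << 31} would multiply out
+ * to 2^93 elements; rather than silently wrapping, safe_compute_numel trips
+ * the overflow check below.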
+ */ +inline int64_t safe_compute_numel(standalone::c10::IntArrayRef sizes) { + uint64_t n = 1; + bool overflowed = standalone::c10::safe_multiplies_u64(sizes, &n); + overflowed |= (n > storage_max()); + STANDALONE_CHECK(!overflowed, "numel: integer multiplication overflow"); + return static_cast(n); +} + +inline std::vector safe_compute_contiguous_strides( + c10::IntArrayRef sizes) { + int64_t ndim = static_cast(sizes.size()); + std::vector strides(ndim); + if (ndim > 0) { + uint64_t stride = 1; + bool overflowed = false; + for (int64_t i = ndim - 1; i >= 0; i--) { + strides[i] = static_cast(stride); + if (sizes[i] != 0) { + uint64_t new_stride = 0; + overflowed |= c10::mul_overflows( + stride, static_cast(sizes[i]), &new_stride); + stride = new_stride; + } + } + STANDALONE_CHECK( + !overflowed, "contiguous_strides: stride multiplication overflow"); + } + return strides; +} +#endif // STANDALONE_MOBILE + +inline int64_t compute_numel(standalone::c10::IntArrayRef sizes) { +#ifndef STANDALONE_MOBILE + // Use overflow checks if supported by the compiler + return safe_compute_numel(sizes); +#else + return standalone::c10::multiply_integers(sizes); +#endif +} + +// named computeStorageNbytesContiguous in c10 +inline size_t compute_storage_nbytes_contiguous( + standalone::c10::IntArrayRef sizes, + size_t itemsize_bytes, + size_t storage_offset) { +// Ignore overflow checks on mobile +#ifndef STANDALONE_MOBILE + uint64_t size = 1; + bool overflowed = standalone::c10::safe_multiplies_u64(sizes, &size); + overflowed |= standalone::c10::add_overflows(size, storage_offset, &size); + overflowed |= standalone::c10::mul_overflows(size, itemsize_bytes, &size); + overflowed |= size > storage_max(); + STANDALONE_CHECK( + !overflowed, "Storage size calculation overflowed with sizes=", sizes); + return static_cast(size); +#else + const auto numel = multiply_integers(sizes); + return itemsize_bytes * (storage_offset + numel); +#endif +} + +// named computeStorageNbytes in c10 +inline size_t compute_storage_nbytes( + standalone::c10::IntArrayRef sizes, + standalone::c10::IntArrayRef strides, + size_t itemsize_bytes, + size_t storage_offset) { + STANDALONE_CHECK( + sizes.size() == strides.size(), + "dimensionality of sizes (", + sizes.size(), + ") must match dimensionality of strides (", + strides.size(), + ")"); + +// Ignore overflow checks on mobile +#ifndef STANDALONE_MOBILE + // size of the underlying storage is 1 bigger than the offset + // of the last element according to stride + uint64_t size = storage_offset + 1; + bool overflowed = false; + for (const auto i : standalone::c10::irange(sizes.size())) { + if (sizes[i] == 0) { + return 0; + } + + uint64_t strided_size = 0; + overflowed |= + standalone::c10::mul_overflows(strides[i], sizes[i] - 1, &strided_size); + overflowed |= standalone::c10::add_overflows(size, strided_size, &size); + } + overflowed |= standalone::c10::mul_overflows(size, itemsize_bytes, &size); + overflowed |= size > storage_max(); + STANDALONE_CHECK( + !overflowed, + "Storage size calculation overflowed with sizes=", + sizes, + " and strides=", + strides); + return static_cast(size); +#else + // size of the underlying storage is 1 bigger than the offset + // of the last element according to stride + uint64_t size = 1; + for (const auto i : standalone::c10::irange(sizes.size())) { + if (sizes[i] == 0) { + return 0; + } + + size += strides[i] * (sizes[i] - 1); + } + return itemsize_bytes * (storage_offset + size); +#endif +} + +inline std::vector 
compute_contiguous_strides(c10::IntArrayRef sizes) { +#ifndef STANDALONE_MOBILE + return safe_compute_contiguous_strides(sizes); +#else + int64_t ndim = static_cast(sizes.size()); + std::vector strides(ndim); + if (ndim > 0) { + int64_t stride = 1; + for (int64_t i = ndim - 1; i >= 0; i--) { + strides[i] = stride; + if (sizes[i] != 0) { + stride *= sizes[i]; + } + } + } + return strides; +#endif +} + +// calculates the final concrete shape by also filling in at most one '-1' +// dimension. +inline std::vector infer_size( + standalone::c10::IntArrayRef shape, + int64_t numel) { + int64_t new_size = 1; + std::optional infer_dim; + std::vector result_shape; + result_shape.reserve(shape.size()); + + size_t ndim = shape.size(); + bool overflowed = false; + for (size_t dim = 0; dim < ndim; dim++) { + if (shape[dim] == -1) { + STANDALONE_CHECK( + !infer_dim.has_value(), "only one dimension can be inferred"); + infer_dim = dim; + result_shape.push_back(-1); // placeholder + } else { + STANDALONE_CHECK(shape[dim] >= 0, "invalid shape dimension ", shape[dim]); + overflowed |= + standalone::c10::mul_overflows(new_size, shape[dim], &new_size); + result_shape.push_back(shape[dim]); + } + } + STANDALONE_CHECK(!overflowed, "shape calculation overflowed"); + + if (infer_dim.has_value()) { + STANDALONE_CHECK( + new_size != 0, + "cannot reshape tensor of 0 elements into shape with -1"); + STANDALONE_CHECK( + numel % new_size == 0, "shape is invalid for input size ", numel); + result_shape[*infer_dim] = numel / new_size; + } else { + STANDALONE_CHECK( + numel == new_size, "shape is invalid for input of size ", numel); + } + return result_shape; +} + +// it determines if a reshape is possible as a view. +// If so, it returns the new strides +// If not, it returns an empty optional +inline std::optional> compute_stride( + standalone::c10::IntArrayRef old_sizes, + standalone::c10::IntArrayRef old_strides, + standalone::c10::IntArrayRef new_sizes) { + if (old_sizes.empty()) { + return std::vector(new_sizes.size(), 1); + } + + // NOTE: stride is arbitrary in the numel() == 0 case; + // to match NumPy behavior we copy the strides if the size matches, otherwise + // we use the stride as if it were computed via resize. + // This could perhaps be combined with the below code, but the complexity + // didn't seem worth it. + size_t numel = compute_numel(old_sizes); + if (numel == 0 && old_sizes == new_sizes) { + return old_strides.vec(); + } + + int64_t new_sizes_len = static_cast(new_sizes.size()); + std::vector new_strides(new_sizes_len); + if (numel == 0) { + for (int64_t view_d = new_sizes_len - 1; view_d >= 0; view_d--) { + if (view_d == new_sizes_len - 1) { + new_strides[view_d] = 1; + } else { + new_strides[view_d] = std::max(new_sizes[view_d + 1], 1) * + new_strides[view_d + 1]; + } + } + return new_strides; + } + + int64_t view_d = new_sizes_len - 1; + int64_t chunk_base_stride = old_strides.back(); + int64_t tensor_numel = 1; + int64_t view_numel = 1; + bool overflowed = false; + for (int64_t tensor_d = static_cast(old_sizes.size()) - 1; + tensor_d >= 0; + tensor_d--) { + // TODO: ask if this could lead to overflow by any chance? 
+ // even if so, overflow is not handled in the aten implementation + overflowed |= standalone::c10::mul_overflows( + tensor_numel, old_sizes[tensor_d], &tensor_numel); + + bool is_chunk_end = (tensor_d == 0) || + (old_sizes[tensor_d - 1] != 1 && + old_strides[tensor_d - 1] != tensor_numel * chunk_base_stride); + + if (is_chunk_end) { + while (view_d >= 0 && + (view_numel < tensor_numel || new_sizes[view_d] == 1)) { + new_strides[view_d] = view_numel * chunk_base_stride; + view_numel *= new_sizes[view_d]; + view_d--; + } + if (view_numel != tensor_numel) { + return std::nullopt; // Not viewable + } + if (tensor_d > 0) { + chunk_base_stride = old_strides[tensor_d - 1]; + tensor_numel = 1; + view_numel = 1; + } + } + } + STANDALONE_CHECK(!overflowed, "overflowed while computing strides"); + + if (view_d != -1) { + return std::nullopt; // not viewable + } + return new_strides; +} + +} // namespace standalone::slim
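+
+// Illustrative worked example for the helpers above: for a contiguous tensor
+// with sizes {2, 3, 4} (strides {12, 4, 1}), infer_size({6, -1}, 24) fills in
+// the -1 to produce {6, 4}, and compute_stride({2, 3, 4}, {12, 4, 1}, {6, 4})
+// returns strides {4, 1}, so the reshape can be expressed as a view without
+// copying.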