Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions backends/aoti/slim/c10/core/Contiguity.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include <algorithm>
#include <cstdint>

namespace standalone::c10 {
namespace executorch::backends::aoti::slim::c10 {

template <typename T>
bool _compute_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) {
Expand Down Expand Up @@ -148,4 +148,4 @@ bool _compute_non_overlapping_and_dense(
return true;
}

} // namespace standalone::c10
} // namespace executorch::backends::aoti::slim::c10
15 changes: 9 additions & 6 deletions backends/aoti/slim/c10/core/Device.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

// Copied from c10/core/DeviceType.h with some modifications

namespace standalone::c10 {
namespace executorch::backends::aoti::slim::c10 {
namespace detail {
enum class DeviceStringParsingState {
kSTART,
Expand Down Expand Up @@ -341,18 +341,21 @@ inline std::ostream& operator<<(std::ostream& stream, const Device& device) {
stream << device.str();
return stream;
}
} // namespace standalone::c10
} // namespace executorch::backends::aoti::slim::c10

namespace std {
template <>
struct hash<standalone::c10::Device> {
size_t operator()(standalone::c10::Device d) const noexcept {
struct hash<executorch::backends::aoti::slim::c10::Device> {
size_t operator()(
executorch::backends::aoti::slim::c10::Device d) const noexcept {
// Are you here because this static assert failed? Make sure you ensure
// that the bitmasking code below is updated accordingly!
static_assert(
sizeof(standalone::c10::DeviceType) == 1, "DeviceType is not 8-bit");
sizeof(executorch::backends::aoti::slim::c10::DeviceType) == 1,
"DeviceType is not 8-bit");
static_assert(
sizeof(standalone::c10::DeviceIndex) == 1, "DeviceIndex is not 8-bit");
sizeof(executorch::backends::aoti::slim::c10::DeviceIndex) == 1,
"DeviceIndex is not 8-bit");
// Note [Hazard when concatenating signed integers]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// We must first convert to a same-sized unsigned type, before promoting to
Expand Down
9 changes: 5 additions & 4 deletions backends/aoti/slim/c10/core/DeviceType.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

#include <executorch/backends/aoti/slim/c10/util/Exception.h>

namespace standalone::c10 {
namespace executorch::backends::aoti::slim::c10 {
enum class DeviceType : int8_t {
CPU = 0,
CUDA = 1, // CUDA.
Expand Down Expand Up @@ -122,12 +122,13 @@ inline std::ostream& operator<<(std::ostream& stream, DeviceType type) {
stream << DeviceTypeName(type, /* lower case */ true);
return stream;
}
} // namespace standalone::c10
} // namespace executorch::backends::aoti::slim::c10

namespace std {
template <>
struct hash<standalone::c10::DeviceType> {
std::size_t operator()(standalone::c10::DeviceType k) const {
struct hash<executorch::backends::aoti::slim::c10::DeviceType> {
std::size_t operator()(
executorch::backends::aoti::slim::c10::DeviceType k) const {
return std::hash<int>()(static_cast<int>(k));
}
};
Expand Down
4 changes: 2 additions & 2 deletions backends/aoti/slim/c10/core/Layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include <cstdint>
#include <ostream>

namespace standalone::c10 {
namespace executorch::backends::aoti::slim::c10 {
enum class Layout : int8_t {
Strided,
Sparse,
Expand Down Expand Up @@ -50,4 +50,4 @@ inline std::ostream& operator<<(std::ostream& stream, c10::Layout layout) {
}
}

} // namespace standalone::c10
} // namespace executorch::backends::aoti::slim::c10
6 changes: 3 additions & 3 deletions backends/aoti/slim/c10/core/MemoryFormat.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
// Regardless of input tensors format, the output should be in channels_last
// format.

namespace standalone::c10 {
namespace executorch::backends::aoti::slim::c10 {
enum class MemoryFormat : int8_t {
Contiguous,
Preserve,
Expand All @@ -38,7 +38,7 @@ enum class MemoryFormat : int8_t {
// the memory format could be preserved, and it was switched to old default
// behaviour of contiguous
#define LEGACY_CONTIGUOUS_MEMORY_FORMAT \
::standalone::c10::get_contiguous_memory_format()
::executorch::backends::aoti::slim::c10::get_contiguous_memory_format()

inline MemoryFormat get_contiguous_memory_format() {
return MemoryFormat::Contiguous;
Expand Down Expand Up @@ -288,4 +288,4 @@ inline bool is_channels_last_strides_3d(
return is_channels_last_strides_3d<int64_t>(sizes, strides);
}

} // namespace standalone::c10
} // namespace executorch::backends::aoti::slim::c10
68 changes: 37 additions & 31 deletions backends/aoti/slim/c10/core/Scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

// Copy-pasted from c10/core/Scalar.h, but dropping SymScalar support

namespace standalone::c10 {
namespace executorch::backends::aoti::slim::c10 {

/**
* Scalar represents a 0-dimensional tensor which contains a single element.
Expand Down Expand Up @@ -86,22 +86,23 @@ class Scalar {
v.i = convert<int64_t, bool>(vv);
}

#define DEFINE_ACCESSOR(type, name) \
type to##name() const { \
if (Tag::HAS_d == tag) { \
return checked_convert<type, double>(v.d, #type); \
} else if (Tag::HAS_z == tag) { \
return checked_convert<type, standalone::c10::complex<double>>( \
v.z, #type); \
} \
if (Tag::HAS_b == tag) { \
return checked_convert<type, bool>(v.i, #type); \
} else if (Tag::HAS_i == tag) { \
return checked_convert<type, int64_t>(v.i, #type); \
} else if (Tag::HAS_u == tag) { \
return checked_convert<type, uint64_t>(v.u, #type); \
} \
STANDALONE_CHECK(false) \
#define DEFINE_ACCESSOR(type, name) \
type to##name() const { \
if (Tag::HAS_d == tag) { \
return checked_convert<type, double>(v.d, #type); \
} else if (Tag::HAS_z == tag) { \
return checked_convert< \
type, \
executorch::backends::aoti::slim::c10::complex<double>>(v.z, #type); \
} \
if (Tag::HAS_b == tag) { \
return checked_convert<type, bool>(v.i, #type); \
} else if (Tag::HAS_i == tag) { \
return checked_convert<type, int64_t>(v.i, #type); \
} else if (Tag::HAS_u == tag) { \
return checked_convert<type, uint64_t>(v.u, #type); \
} \
STANDALONE_CHECK(false) \
}

// TODO: Support ComplexHalf accessor
Expand Down Expand Up @@ -193,8 +194,9 @@ class Scalar {

template <
typename T,
typename std::enable_if_t<!standalone::c10::is_complex<T>::value, int> =
0>
typename std::enable_if_t<
!executorch::backends::aoti::slim::c10::is_complex<T>::value,
int> = 0>
bool equal(T num) const {
if (isComplex()) {
auto val = v.z;
Expand Down Expand Up @@ -223,7 +225,9 @@ class Scalar {

template <
typename T,
typename std::enable_if_t<standalone::c10::is_complex<T>::value, int> = 0>
typename std::enable_if_t<
executorch::backends::aoti::slim::c10::is_complex<T>::value,
int> = 0>
bool equal(T num) const {
if (isComplex()) {
return v.z == num;
Expand Down Expand Up @@ -257,20 +261,20 @@ class Scalar {
}
}

standalone::c10::ScalarType type() const {
executorch::backends::aoti::slim::c10::ScalarType type() const {
if (isComplex()) {
return standalone::c10::ScalarType::ComplexDouble;
return executorch::backends::aoti::slim::c10::ScalarType::ComplexDouble;
} else if (isFloatingPoint()) {
return standalone::c10::ScalarType::Double;
return executorch::backends::aoti::slim::c10::ScalarType::Double;
} else if (isIntegral(/*includeBool=*/false)) {
// Represent all integers as long, UNLESS it is unsigned and therefore
// unrepresentable as long
if (Tag::HAS_u == tag) {
return standalone::c10::ScalarType::UInt64;
return executorch::backends::aoti::slim::c10::ScalarType::UInt64;
}
return standalone::c10::ScalarType::Long;
return executorch::backends::aoti::slim::c10::ScalarType::Long;
} else if (isBoolean()) {
return standalone::c10::ScalarType::Bool;
return executorch::backends::aoti::slim::c10::ScalarType::Bool;
} else {
throw std::runtime_error("Unknown scalar type.");
}
Expand Down Expand Up @@ -313,7 +317,7 @@ class Scalar {
int64_t i;
// See Note [Meaning of HAS_u]
uint64_t u;
standalone::c10::complex<double> z;
executorch::backends::aoti::slim::c10::complex<double> z;
// NOLINTNEXTLINE(modernize-use-equals-default)
v_t() {} // default constructor
} v;
Expand All @@ -330,16 +334,18 @@ class Scalar {
template <
typename T,
typename std::enable_if_t<
!std::is_integral_v<T> && !standalone::c10::is_complex<T>::value,
!std::is_integral_v<T> &&
!executorch::backends::aoti::slim::c10::is_complex<T>::value,
bool>* = nullptr>
Scalar(T vv, bool) : tag(Tag::HAS_d) {
v.d = convert<decltype(v.d), T>(vv);
}

template <
typename T,
typename std::enable_if_t<standalone::c10::is_complex<T>::value, bool>* =
nullptr>
typename std::enable_if_t<
executorch::backends::aoti::slim::c10::is_complex<T>::value,
bool>* = nullptr>
Scalar(T vv, bool) : tag(Tag::HAS_z) {
v.z = convert<decltype(v.z), T>(vv);
}
Expand All @@ -357,4 +363,4 @@ DEFINE_TO(uint32_t, UInt32)
DEFINE_TO(uint64_t, UInt64)
#undef DEFINE_TO

} // namespace standalone::c10
} // namespace executorch::backends::aoti::slim::c10
Loading
Loading