@@ -47,6 +47,36 @@ namespace executorch {
4747namespace runtime {
4848namespace etensor {
4949
50+ // Placing a bunch of unused dtypes here as our macros don't make it easy
51+ // to skip scalar types defined in aten that we dont have.
52+ namespace unused_dtype {
53+ struct alignas (1 ) Float8_e5m2 {
54+ uint8_t x;
55+ using underlying = uint8_t ;
56+ Float8_e5m2 () = default ;
57+ explicit Float8_e5m2 (uint8_t val) : x (val) {}
58+ };
59+ struct alignas (1 ) Float8_e4m3fn {
60+ uint8_t x;
61+ using underlying = uint8_t ;
62+ Float8_e4m3fn () = default ;
63+ explicit Float8_e4m3fn (uint8_t val) : x (val) {}
64+ };
65+ struct alignas (1 ) Float8_e5m2fnuz {
66+ uint8_t x;
67+ using underlying = uint8_t ;
68+ Float8_e5m2fnuz () = default ;
69+ explicit Float8_e5m2fnuz (uint8_t val) : x (val) {}
70+ };
71+ struct alignas (1 ) Float8_e4m3fnuz {
72+ uint8_t x;
73+ using underlying = uint8_t ;
74+ Float8_e4m3fnuz () = default ;
75+ explicit Float8_e4m3fnuz (uint8_t val) : x (val) {}
76+ };
77+
78+ } // namespace unused_dtype
79+
5080/* *
5181 * Calls the provided macro on every ScalarType, providing the C type and the
5282 * ScalarType name to each call.
@@ -59,30 +89,42 @@ namespace etensor {
5989 * @param _ A macro that takes two parameters: the name of a C type, and the
6090 * name of the corresponding ScalarType enumerator.
6191 */
62- #define ET_FORALL_SCALAR_TYPES (_ ) \
63- _ (uint8_t , Byte) /* 0 */ \
64- _ (int8_t , Char) /* 1 */ \
65- _ (int16_t , Short) /* 2 */ \
66- _ (int32_t , Int) /* 3 */ \
67- _ (int64_t , Long) /* 4 */ \
68- _ (::torch::executor::Half, Half) /* 5 */ \
69- _ (float , Float) /* 6 */ \
70- _ (double , Double) /* 7 */ \
71- _ (::torch::executor::complex <::torch::executor::Half>, ComplexHalf) /* 8 */ \
72- _ (::torch::executor::complex <float >, ComplexFloat) /* 9 */ \
73- _ (::torch::executor::complex <double >, ComplexDouble) /* 10 */ \
74- _ (bool , Bool) /* 11 */ \
75- _ (::torch::executor::qint8, QInt8) /* 12 */ \
76- _ (::torch::executor::quint8, QUInt8) /* 13 */ \
77- _ (::torch::executor::qint32, QInt32) /* 14 */ \
78- _ (::torch::executor::BFloat16, BFloat16) /* 15 */ \
79- _ (::torch::executor::quint4x2, QUInt4x2) /* 16 */ \
80- _ (::torch::executor::quint2x4, QUInt2x4) /* 17 */ \
81- _ (::torch::executor::bits1x8, Bits1x8) /* 18 */ \
82- _ (::torch::executor::bits2x4, Bits2x4) /* 19 */ \
83- _ (::torch::executor::bits4x2, Bits4x2) /* 20 */ \
84- _ (::torch::executor::bits8, Bits8) /* 21 */ \
85- _ (::torch::executor::bits16, Bits16) /* 22 */
92+ #define ET_FORALL_SCALAR_TYPES (_ ) \
93+ _ (uint8_t , Byte) /* 0 */ \
94+ _ (int8_t , Char) /* 1 */ \
95+ _ (int16_t , Short) /* 2 */ \
96+ _ (int32_t , Int) /* 3 */ \
97+ _ (int64_t , Long) /* 4 */ \
98+ _ (::executorch::runtime::etensor::Half, Half) /* 5 */ \
99+ _ (float , Float) /* 6 */ \
100+ _ (double , Double) /* 7 */ \
101+ _ (::executorch::runtime::etensor::complex <::torch::executor::Half>, \
102+ ComplexHalf) /* 8 */ \
103+ _ (::executorch::runtime::etensor::complex <float >, ComplexFloat) /* 9 */ \
104+ _ (::executorch::runtime::etensor::complex <double >, ComplexDouble) /* 10 */ \
105+ _ (bool , Bool) /* 11 */ \
106+ _ (::executorch::runtime::etensor::qint8, QInt8) /* 12 */ \
107+ _ (::executorch::runtime::etensor::quint8, QUInt8) /* 13 */ \
108+ _ (::executorch::runtime::etensor::qint32, QInt32) /* 14 */ \
109+ _ (::executorch::runtime::etensor::BFloat16, BFloat16) /* 15 */ \
110+ _ (::executorch::runtime::etensor::quint4x2, QUInt4x2) /* 16 */ \
111+ _ (::executorch::runtime::etensor::quint2x4, QUInt2x4) /* 17 */ \
112+ _ (::executorch::runtime::etensor::bits1x8, Bits1x8) /* 18 */ \
113+ _ (::executorch::runtime::etensor::bits2x4, Bits2x4) /* 19 */ \
114+ _ (::executorch::runtime::etensor::bits4x2, Bits4x2) /* 20 */ \
115+ _ (::executorch::runtime::etensor::bits8, Bits8) /* 21 */ \
116+ _ (::executorch::runtime::etensor::bits16, Bits16) /* 22 */ \
117+ _ (::executorch::runtime::etensor::unused_dtype::Float8_e5m2, \
118+ Float8_e5m2) /* 23 */ \
119+ _ (::executorch::runtime::etensor::unused_dtype::Float8_e4m3fn, \
120+ Float8_e4m3fn) /* 24 */ \
121+ _ (::executorch::runtime::etensor::unused_dtype::Float8_e5m2fnuz, \
122+ Float8_e5m2fnuz) /* 25 */ \
123+ _ (::executorch::runtime::etensor::unused_dtype::Float8_e4m3fnuz, \
124+ Float8_e4m3fnuz) /* 26 */ \
125+ _ (uint16_t , UInt16) /* 27 */ \
126+ _ (uint32_t , UInt32) /* 28 */ \
127+ _ (uint64_t , UInt64) /* 29 */
86128
87129/* *
88130 * Data types (dtypes) that can be used as element types in ETensors.
0 commit comments