|
| 1 | +// Compatibility header for CUTLASS numeric conversion on HIP/ROCm |
| 2 | +// This provides a minimal subset of CUTLASS functionality needed for TurboDiffusion |
| 3 | + |
| 4 | +#pragma once |
| 5 | + |
| 6 | +#include <hip/hip_runtime.h> |
| 7 | +#include <cstdint> |
| 8 | + |
namespace cutlass {

/// Floating-point rounding modes accepted by NumericConverter.
///
/// Only the enumerators below are provided by this compatibility header.
/// NOTE(review): the underlying integer values are assigned locally and are
/// not guaranteed to match upstream CUTLASS; refer to enumerators by name.
enum class FloatRoundStyle {
  round_to_nearest = 0,
  round_toward_zero = 1,
  round_toward_infinity = 2,
  round_toward_neg_infinity = 3,
};

/// Generic numeric converter: a plain `static_cast` from `From` to `To`.
///
/// The `Round` parameter is accepted so that call sites written against the
/// CUTLASS interface compile, but it is not consulted by this primary
/// template. No rounding control or saturation is applied here; only the
/// float -> int8_t specializations below implement those.
///
/// @tparam To     destination type
/// @tparam From   source type
/// @tparam Round  requested rounding mode (unused in the primary template)
template <typename To, typename From, FloatRoundStyle Round = FloatRoundStyle::round_to_nearest>
struct NumericConverter {
  using result_type = To;
  using source_type = From;

  /// Converts `val` with `static_cast`; usable without an instance.
  __device__ __host__ __forceinline__
  static result_type convert(source_type const& val) {
    return static_cast<result_type>(val);
  }

  /// Function-call form; forwards to convert().
  __device__ __host__ __forceinline__
  result_type operator()(source_type const& val) const {
    return convert(val);
  }
};

/// float -> int8_t, rounding to the nearest integer and saturating to
/// [-128, 127].
///
/// Rounding is done by `rintf`, which follows the current floating-point
/// rounding mode (ties-to-even under the default mode). Infinities saturate
/// to the nearest bound. NaN converts to 0.
template <>
struct NumericConverter<int8_t, float, FloatRoundStyle::round_to_nearest> {
  using result_type = int8_t;
  using source_type = float;

  /// @param val value to quantize
  /// @return rounded, saturated int8 value; 0 when `val` is NaN
  __device__ __host__ __forceinline__
  static result_type convert(source_type val) {
    // NaN is the only value that compares unequal to itself. Without this
    // guard, fminf(127.0f, NaN) returns 127 and NaN would silently quantize
    // to the largest positive value instead of 0.
    if (val != val) {
      return 0;
    }
    // Clamp in float first so the final cast is always in range.
    val = fmaxf(-128.0f, fminf(127.0f, rintf(val)));
    return static_cast<result_type>(val);
  }

  /// Function-call form; forwards to convert().
  __device__ __host__ __forceinline__
  int8_t operator()(float val) const {
    return convert(val);
  }
};

/// float -> int8_t, truncating toward zero and saturating to [-128, 127].
///
/// Infinities saturate to the nearest bound. NaN converts to 0.
template <>
struct NumericConverter<int8_t, float, FloatRoundStyle::round_toward_zero> {
  using result_type = int8_t;
  using source_type = float;

  /// @param val value to quantize
  /// @return truncated, saturated int8 value; 0 when `val` is NaN
  __device__ __host__ __forceinline__
  static result_type convert(source_type val) {
    // See the round_to_nearest specialization: NaN must not fall through the
    // fminf/fmaxf clamp, which would turn it into 127.
    if (val != val) {
      return 0;
    }
    // Clamp in float first so the final cast is always in range.
    val = fmaxf(-128.0f, fminf(127.0f, truncf(val)));
    return static_cast<result_type>(val);
  }

  /// Function-call form; forwards to convert().
  __device__ __host__ __forceinline__
  int8_t operator()(float val) const {
    return convert(val);
  }
};

}  // namespace cutlass
| 51 | + |
0 commit comments