Skip to content

Commit d87745b

Browse files
committed
Add numeric_conversion_hip header
1 parent 66ba4ed commit d87745b

File tree

2 files changed

+52
-1
lines changed

2 files changed

+52
-1
lines changed
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// Compatibility header for CUTLASS numeric conversion on HIP/ROCm
2+
// This provides a minimal subset of CUTLASS functionality needed for TurboDiffusion
3+
4+
#pragma once
5+
6+
#include <hip/hip_runtime.h>
7+
#include <cstdint>
8+
9+
namespace cutlass {
10+
11+
// FloatRoundStyle enum (subset of CUTLASS)
12+
enum class FloatRoundStyle {
13+
round_to_nearest = 0,
14+
round_toward_zero = 1,
15+
round_toward_infinity = 2,
16+
round_toward_neg_infinity = 3,
17+
};
18+
19+
// NumericConverter template - provides float to int8 conversion with rounding
20+
template <typename To, typename From, FloatRoundStyle Round = FloatRoundStyle::round_to_nearest>
21+
struct NumericConverter {
22+
__device__ __host__ __forceinline__
23+
To operator()(From const& val) const {
24+
return static_cast<To>(val);
25+
}
26+
};
27+
28+
// Specialization for float to int8_t with round_to_nearest
29+
template <>
30+
struct NumericConverter<int8_t, float, FloatRoundStyle::round_to_nearest> {
31+
__device__ __host__ __forceinline__
32+
int8_t operator()(float val) const {
33+
// Round to nearest and clamp to int8 range [-128, 127]
34+
val = fmaxf(-128.0f, fminf(127.0f, rintf(val)));
35+
return static_cast<int8_t>(val);
36+
}
37+
};
38+
39+
// Specialization for float to int8_t with round_toward_zero
40+
template <>
41+
struct NumericConverter<int8_t, float, FloatRoundStyle::round_toward_zero> {
42+
__device__ __host__ __forceinline__
43+
int8_t operator()(float val) const {
44+
// Truncate and clamp to int8 range [-128, 127]
45+
val = fmaxf(-128.0f, fminf(127.0f, truncf(val)));
46+
return static_cast<int8_t>(val);
47+
}
48+
};
49+
50+
} // namespace cutlass
51+

turbodiffusion/ops/quant/quant_hip.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
#include <hip/hip_runtime.h>
2020
#include <hip/hip_runtime.h>
21-
#include "cutlass/numeric_conversion_hip.h"
21+
#include "common/numeric_conversion_hip.hpp"
2222

2323
#include "common/load.hpp"
2424
#include "common/store_hip.hpp"

0 commit comments

Comments
 (0)