@@ -90,25 +90,6 @@ void call_reshape_and_cache(
90
90
});
91
91
}
92
92
93
- // Used by vectorization_utils to copy/convert one element
94
- template <typename OutT, typename InT, Fp8KVCacheDataType kv_dt>
95
- struct CopyWithScaleOp {
96
- float scale;
97
-
98
- inline void operator ()(OutT& dst, const InT src) const {
99
- if constexpr (kv_dt == Fp8KVCacheDataType::kAuto ) {
100
- dst = static_cast <OutT>(src);
101
- } else {
102
- float x = (float )src / scale;
103
- if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3 ) {
104
- dst = static_cast <at::Float8_e4m3fn>(x);
105
- } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2 ) {
106
- dst = static_cast <at::Float8_e5m2>(x);
107
- }
108
- }
109
- }
110
- };
111
-
112
93
template <typename scalar_t , typename cache_t , Fp8KVCacheDataType kv_dt>
113
94
void reshape_and_cache_flash_kernel (
114
95
const scalar_t * __restrict__ key, // [num_tokens, num_heads, head_size]
@@ -146,8 +127,8 @@ void reshape_and_cache_flash_kernel(
146
127
float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto ) ? 0 .f : *k_scale;
147
128
float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto ) ? 0 .f : *v_scale;
148
129
149
- CopyWithScaleOp<cache_t , scalar_t , kv_dt> k_op{k_scale_val};
150
- CopyWithScaleOp<cache_t , scalar_t , kv_dt> v_op{v_scale_val};
130
+ fp8:: CopyWithScaleOp<cache_t , scalar_t , kv_dt> k_op{k_scale_val};
131
+ fp8:: CopyWithScaleOp<cache_t , scalar_t , kv_dt> v_op{v_scale_val};
151
132
fp8::scaled_convert_vec (key_src, key_dst, n, local_idx, local_range, k_op);
152
133
fp8::scaled_convert_vec (value_src, value_dst, n, local_idx, local_range,
153
134
v_op);
0 commit comments