Skip to content

Commit 059a372

Browse files
committed
Remove redefined CopyWithScaleOp struct
Signed-off-by: Zhu, Zufang <[email protected]>
1 parent 5bbb0cc commit 059a372

File tree

1 file changed

+2
-21
lines changed

1 file changed

+2
-21
lines changed

csrc/xpu/cache.cpp

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -90,25 +90,6 @@ void call_reshape_and_cache(
9090
});
9191
}
9292

93-
// Used by vectorization_utils to copy/convert one element.
//
// Callable that writes one source element `src` into `dst`, selecting the
// conversion at compile time from `kv_dt`:
//   - kAuto:    dst = static_cast<OutT>(src); `scale` is not read.
//   - kFp8E4M3: src is converted to float, divided by `scale`, and cast to
//               at::Float8_e4m3fn before being assigned to dst.
//   - kFp8E5M2: same, but the cast is to at::Float8_e5m2.
// NOTE(review): for any other non-kAuto `kv_dt` value the quotient is
// computed but `dst` is never assigned — confirm no such enumerator exists.
94-
template <typename OutT, typename InT, Fp8KVCacheDataType kv_dt>
95-
struct CopyWithScaleOp {
96-
// Divisor applied to the source value on the non-kAuto paths.
float scale;
97-
98-
// Converts `src` and stores the result in `dst` (see struct comment).
inline void operator()(OutT& dst, const InT src) const {
99-
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
100-
// No scaling: plain element-type conversion.
dst = static_cast<OutT>(src);
101-
} else {
102-
// Scale in float before narrowing to the FP8 storage type.
float x = (float)src / scale;
103-
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
104-
dst = static_cast<at::Float8_e4m3fn>(x);
105-
} else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
106-
dst = static_cast<at::Float8_e5m2>(x);
107-
}
108-
}
109-
}
110-
};
111-
11293
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
11394
void reshape_and_cache_flash_kernel(
11495
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
@@ -146,8 +127,8 @@ void reshape_and_cache_flash_kernel(
146127
float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale;
147128
float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *v_scale;
148129

149-
CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
150-
CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};
130+
fp8::CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
131+
fp8::CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};
151132
fp8::scaled_convert_vec(key_src, key_dst, n, local_idx, local_range, k_op);
152133
fp8::scaled_convert_vec(value_src, value_dst, n, local_idx, local_range,
153134
v_op);

0 commit comments

Comments
 (0)