-#include "flash_infer_handler.h"
+#include "scale_attn_handler.h"
 
 #include <torch/torch.h>
 
 
 namespace llm {
 
-FlashInferHandler::FlashInferHandler(float scale,
+ScaleAttnHandler::ScaleAttnHandler(float scale,
                                      int64_t rotary_dim,
                                      int64_t max_position,
                                      float rope_scaling,
@@ -17,13 +17,13 @@ FlashInferHandler::FlashInferHandler(float scale,
   LOG(FATAL) << "Not implemented yet";
 }
 
-FlashInferHandler::FlashInferHandler(
+ScaleAttnHandler::ScaleAttnHandler(
     float scale,
     torch::optional<torch::Tensor> alibi_slopes)
     : scale_(scale), alibi_slopes_(alibi_slopes) {}
 
 // batch prefill for attention, optimized for prefill stage
-void FlashInferHandler::batch_prefill(
+void ScaleAttnHandler::batch_prefill(
     const torch::Tensor& query,  // [n_tokens, n_heads, head_dim]
     const torch::Tensor& key,    // [n_tokens, n_kv_heads, head_dim]
     const torch::Tensor& value,  // [n_tokens, n_kv_heads, head_dim]
@@ -36,7 +36,7 @@ void FlashInferHandler::batch_prefill(
 
 // batch decode for attention, optimized for decode stage
 // support multiple queries: one sequence with multiple query tokens
-void FlashInferHandler::batch_decode(
+void ScaleAttnHandler::batch_decode(
     const torch::Tensor& query,           // [n_tokens, n_heads, head_dim]
     const KVCache& kv_cache,              // where to retrieve key and value
     const InputParameters& input_params,  // input params used for attention
@@ -47,7 +47,7 @@ void FlashInferHandler::batch_decode(
 }
 
 // append key and value to kv_cache
-void FlashInferHandler::append_kv_cache(
+void ScaleAttnHandler::append_kv_cache(
     KVCache& kv_cache,            // where to store key and value
     const torch::Tensor& key,     // [n_tokens, n_kv_heads, head_dim]
     const torch::Tensor& value,   // [n_tokens, n_kv_heads, head_dim]
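
The diff stops at the signatures: the first constructor aborts with "Not implemented yet" and the method bodies are collapsed out of the hunks. As a point of reference only, here is a minimal, unfused sketch of the scaled-dot-product attention that the scale_ and alibi_slopes_ members imply for the prefill path. It assumes a single contiguous sequence, and the free-function form and name naive_prefill_attention are illustrative assumptions, not this handler's actual kernel.

#include <torch/torch.h>

#include <limits>

// Illustrative reference (hypothetical name and free-function form):
// computes softmax(scale * Q K^T + alibi_bias + causal_mask) V for one
// sequence. Tensor shapes match the batch_prefill signature above.
torch::Tensor naive_prefill_attention(
    const torch::Tensor& query,  // [n_tokens, n_heads, head_dim]
    const torch::Tensor& key,    // [n_tokens, n_kv_heads, head_dim]
    const torch::Tensor& value,  // [n_tokens, n_kv_heads, head_dim]
    float scale,
    const torch::optional<torch::Tensor>& alibi_slopes) {  // [n_heads]
  const int64_t n_tokens = query.size(0);
  const int64_t n_heads = query.size(1);
  const int64_t n_kv_heads = key.size(1);

  // Grouped-query attention: replicate each kv head across its query group.
  auto k = key;
  auto v = value;
  if (n_kv_heads != n_heads) {
    k = k.repeat_interleave(n_heads / n_kv_heads, /*dim=*/1);
    v = v.repeat_interleave(n_heads / n_kv_heads, /*dim=*/1);
  }

  // Move heads to the batch dimension: [n_heads, n_tokens, head_dim].
  auto q = query.transpose(0, 1);
  k = k.transpose(0, 1);
  v = v.transpose(0, 1);

  // scores = scale * Q K^T : [n_heads, n_tokens, n_tokens]
  auto scores = torch::matmul(q, k.transpose(-2, -1)) * scale;

  // Optional ALiBi bias: per-head slope times key position. The per-row
  // term this drops (slope * query position) cancels inside softmax.
  if (alibi_slopes.has_value()) {
    auto pos = torch::arange(n_tokens, query.options());
    scores += alibi_slopes.value().view({n_heads, 1, 1}) *
              pos.view({1, 1, n_tokens});
  }

  // Causal mask: -inf strictly above the diagonal, so each query token
  // attends only to itself and earlier tokens.
  auto mask = torch::full({n_tokens, n_tokens},
                          -std::numeric_limits<float>::infinity(),
                          query.options())
                  .triu(/*diagonal=*/1);
  scores += mask;

  auto probs = torch::softmax(scores, /*dim=*/-1);
  // [n_heads, n_tokens, head_dim] -> [n_tokens, n_heads, head_dim]
  return torch::matmul(probs, v).transpose(0, 1);
}

A real batch_prefill would fuse these steps into one kernel and use the sequence boundaries carried by input_params to mask across multiple packed sequences; the single causal mask here covers only the one-sequence case.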
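append_kv_cache's body is likewise elided, and the KVCache layout never appears in this diff, so the following is only a hedged illustration: it assumes a flat, slot-indexed cache and a hypothetical slot_ids tensor mapping each new token to its slot. Both are assumptions, not this repo's API.

#include <torch/torch.h>

// Hypothetical scatter-style append, NOT the repo's KVCache API: the cache
// is assumed to be two flat tensors indexed by an (assumed) per-token slot id.
void append_kv_cache_sketch(
    torch::Tensor& key_cache,         // [n_slots, n_kv_heads, head_dim] (assumed)
    torch::Tensor& value_cache,       // [n_slots, n_kv_heads, head_dim] (assumed)
    const torch::Tensor& key,         // [n_tokens, n_kv_heads, head_dim]
    const torch::Tensor& value,       // [n_tokens, n_kv_heads, head_dim]
    const torch::Tensor& slot_ids) {  // [n_tokens], int64 (hypothetical)
  // Write each token's key/value row into its assigned cache slot.
  key_cache.index_put_({slot_ids}, key);
  value_cache.index_put_({slot_ids}, value);
}

A paged cache would add a block-table indirection from token position to (block, offset), but the append step is still a scatter of [n_tokens, n_kv_heads, head_dim] rows into cache slots.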