-
Notifications
You must be signed in to change notification settings - Fork 996
Expand file tree
/
Copy pathpatch_kv_cache_interface.py
More file actions
142 lines (118 loc) · 5.97 KB
/
patch_kv_cache_interface.py
File metadata and controls
142 lines (118 loc) · 5.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
import vllm.v1.kv_cache_interface
from typing_extensions import Self
from vllm.utils.torch_utils import get_dtype_size
from vllm.v1.kv_cache_interface import MLAAttentionSpec
@dataclass(frozen=True)
class AscendMLAAttentionSpec(MLAAttentionSpec):
"""MLAAttentionSpec extended to support DSA models, with optional Sparse C8 support.
When Sparse C8 is enabled, the KV cache tuple changes from
(kv_cache[0]: bfloat16, kv_cache[1]: bfloat16, kv_cache[2]: bfloat16)
to
(kv_cache[0]: bfloat16, kv_cache[1]: bfloat16, kv_cache[2]: int8, kv_cache[3]: float16).
The semantic meaning of each KV cache entry is as follows:
1. kv_cache[0] stores kv_lora.
2. kv_cache[1] stores k_rope.
3. kv_cache[2] stores the key tensor from the indexer module.
4. kv_cache[3] stores the key scale tensor from the indexer module,
and exists only when Sparse C8 is enabled.
The main changes are as follows:
1. The key tensor from the indexer module stored in kv_cache[2] is
converted from bf16 to int8 to reduce memory usage. It is then
processed with int8 precision in Lightning_indexer computation
to improve computational efficiency.
2. The quantization scale of the key tensor in the indexer module
must also be stored for the Lightning_indexer_quant operator,
and is therefore saved in kv_cache[3].
"""
sparse_head_dim: tuple[int, ...] | None = None
cache_sparse_c8: bool = False
c8_k_cache_dtype: torch.dtype = torch.int8
c8_k_scale_cache_dtype: torch.dtype = torch.float16
@property
def page_size_bytes(self) -> int:
if self.cache_sparse_c8:
assert self.sparse_head_dim is not None
assert len(self.sparse_head_dim) == 3
num_heads_per_page = self.block_size * self.num_kv_heads
# kv_cache[0]: bfloat16, kv_cache[1]: bfloat16
kv_lora_rank, qk_rope_head_dim = self.sparse_head_dim[:2]
k_pe_nope_bytes = num_heads_per_page * (kv_lora_rank + qk_rope_head_dim) * get_dtype_size(self.dtype)
# kv_cache[2]: int8
index_head_dim = self.sparse_head_dim[-1]
indexer_k_bytes = num_heads_per_page * index_head_dim * get_dtype_size(self.c8_k_cache_dtype)
# kv_cache[3]: float16
# since the scale is stored per token, head_dim is set to 1.
index_scale_head_dim = 1
indexer_k_scale_bytes = (
num_heads_per_page * index_scale_head_dim * get_dtype_size(self.c8_k_scale_cache_dtype)
)
return k_pe_nope_bytes + indexer_k_bytes + indexer_k_scale_bytes
return self.block_size * self.num_kv_heads * self.head_size * get_dtype_size(self.dtype)
@property
def sparse_kv_cache_ratio(self) -> tuple[float, float, float, float | None]:
"""
Compute the relative byte share of each KV cache entry.
Returns:
A tuple containing the ratios for:
- kv_cache[0]
- kv_cache[1]
- kv_cache[2]
- kv_cache[3] (None if Sparse C8 is disabled)
"""
assert self.sparse_head_dim is not None
def get_sparse_head_dim_virtual() -> tuple[int, int, int, int]:
assert self.sparse_head_dim is not None
assert self.cache_sparse_c8 is True
kv_lora_rank, qk_rope_head_dim, index_k_head_dim = self.sparse_head_dim
factor = get_dtype_size(self.dtype) // get_dtype_size(self.c8_k_cache_dtype)
index_k_head_dim_virtual = index_k_head_dim // factor
assert get_dtype_size(self.dtype) == get_dtype_size(self.c8_k_scale_cache_dtype)
index_k_scale_head_dim_virtual = 1
return (
kv_lora_rank,
qk_rope_head_dim,
index_k_head_dim_virtual,
index_k_scale_head_dim_virtual,
)
if self.cache_sparse_c8:
virtual_dims = get_sparse_head_dim_virtual()
total_virtual_head_dim = sum(virtual_dims)
return (
total_virtual_head_dim / virtual_dims[0], # kv_cache[0]
total_virtual_head_dim / virtual_dims[1], # kv_cache[1]
total_virtual_head_dim / virtual_dims[2], # kv_cache[2]
total_virtual_head_dim / virtual_dims[3], # kv_cache[3]
)
return (
self.head_size / self.sparse_head_dim[0], # kv_cache[0]
self.head_size / self.sparse_head_dim[1], # kv_cache[1]
self.head_size / self.sparse_head_dim[2], # kv_cache[2]
None, # kv_cache[3] does not exist
)
@classmethod
def merge(cls, specs: list[Self]) -> Self:
assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), (
"All attention layers in the same KV cache group must be MLAAttentionSpec."
)
cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs)
assert len(cache_dtype_str_set) == 1, (
"All attention layers in the same KV cache group must use the same quantization method."
)
cache_sparse_c8_set = set(spec.cache_sparse_c8 for spec in specs)
assert len(cache_sparse_c8_set) == 1, (
"All attention layers in the same KV cache group must use the same sparse C8 setting."
)
return cls(
block_size=specs[0].block_size,
num_kv_heads=specs[0].num_kv_heads,
head_size=specs[0].head_size,
sparse_head_dim=specs[0].sparse_head_dim,
dtype=specs[0].dtype,
cache_dtype_str=cache_dtype_str_set.pop(),
cache_sparse_c8=cache_sparse_c8_set.pop(),
)
vllm.v1.kv_cache_interface.MLAAttentionSpec = AscendMLAAttentionSpec