
Commit bb1ba58

Bam4d and timlacroix authored
[Mistral] Mistral-7B-v0.1 support (#1196)
Co-authored-by: timlacroix <[email protected]>
1 parent 7bedab5 commit bb1ba58

File tree

13 files changed: +571 -25 lines changed

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ sentencepiece # Required for LLaMA tokenizer.
 numpy
 torch >= 2.0.0
 transformers >= 4.33.1 # Required for Code Llama.
-xformers >= 0.0.21
+xformers >= 0.0.22
 fastapi
 uvicorn[standard]
 pydantic < 2 # Required for OpenAI server.

vllm/config.py

Lines changed: 2 additions & 0 deletions
@@ -187,10 +187,12 @@ def __init__(
         block_size: int,
         gpu_memory_utilization: float,
         swap_space: int,
+        sliding_window: Optional[int] = None,
     ) -> None:
         self.block_size = block_size
         self.gpu_memory_utilization = gpu_memory_utilization
         self.swap_space_bytes = swap_space * _GB
+        self.sliding_window = sliding_window
         self._verify_args()

         # Will be set after profiling.

vllm/core/block_manager.py

Lines changed: 30 additions & 11 deletions
@@ -63,10 +63,18 @@ def __init__(
         num_gpu_blocks: int,
         num_cpu_blocks: int,
         watermark: float = 0.01,
+        sliding_window: Optional[int] = None,
     ) -> None:
         self.block_size = block_size
         self.num_total_gpu_blocks = num_gpu_blocks
         self.num_total_cpu_blocks = num_cpu_blocks
+
+        self.block_sliding_window = None
+        if sliding_window is not None:
+            assert sliding_window % block_size == 0, (sliding_window,
+                                                      block_size)
+            self.block_sliding_window = sliding_window // block_size
+
         self.watermark = watermark
         assert watermark >= 0.0

@@ -83,6 +91,9 @@ def can_allocate(self, seq_group: SequenceGroup) -> bool:
         # the same prompt. This may not be true for preempted sequences.
         seq = seq_group.get_seqs()[0]
         num_required_blocks = len(seq.logical_token_blocks)
+        if self.block_sliding_window is not None:
+            num_required_blocks = min(num_required_blocks,
+                                      self.block_sliding_window)
         num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
         # Use watermark to avoid frequent cache eviction.
         return (num_free_gpu_blocks - num_required_blocks >=
@@ -95,8 +106,12 @@ def allocate(self, seq_group: SequenceGroup) -> None:

         # Allocate new physical token blocks that will store the prompt tokens.
         block_table: BlockTable = []
-        for _ in range(len(seq.logical_token_blocks)):
-            block = self.gpu_allocator.allocate()
+        for logical_idx in range(len(seq.logical_token_blocks)):
+            if (self.block_sliding_window is not None
+                    and logical_idx >= self.block_sliding_window):
+                block = block_table[logical_idx % self.block_sliding_window]
+            else:
+                block = self.gpu_allocator.allocate()
             # Set the reference counts of the token blocks.
             block.ref_count = seq_group.num_seqs()
             block_table.append(block)
@@ -118,11 +133,17 @@ def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]:
         block_table = self.block_tables[seq.seq_id]

         if len(block_table) < len(logical_blocks):
-            # The sequence has a new logical block.
-            # Allocate a new physical block.
-            block = self.gpu_allocator.allocate()
-            block_table.append(block)
-            return None
+            if (self.block_sliding_window
+                    and len(block_table) >= self.block_sliding_window):
+                # re-use a block
+                block_table.append(block_table[len(block_table) %
+                                               self.block_sliding_window])
+            else:
+                # The sequence has a new logical block.
+                # Allocate a new physical block.
+                block = self.gpu_allocator.allocate()
+                block_table.append(block)
+            return None

         # We want to append the token to the last physical block.
         last_block = block_table[-1]
@@ -154,9 +175,7 @@ def _get_physical_blocks(
         for seq in seq_group.get_seqs():
             if seq.is_finished():
                 continue
-            block_table = self.block_tables[seq.seq_id]
-            for block in block_table:
-                blocks.add(block)
+            blocks.update(self.block_tables[seq.seq_id])
         return list(blocks)

     def can_swap_in(self, seq_group: SequenceGroup) -> bool:
@@ -224,7 +243,7 @@ def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
         return block_number_mapping

     def _free_block_table(self, block_table: BlockTable) -> None:
-        for block in block_table:
+        for block in set(block_table):
             if block.device == Device.GPU:
                 self.gpu_allocator.free(block)
             else:
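
The block-manager changes above cap each sequence's KV-cache footprint by treating its block table as a circular buffer: once a sequence has more logical blocks than fit in the sliding window, new logical blocks re-use physical blocks modulo block_sliding_window (which is also why _free_block_table now frees set(block_table), since the same physical block can appear several times). A minimal standalone sketch of that mapping, with assumed example values (block_size=16, sliding_window=64) that are not taken from this diff:

# Illustration only, not vLLM code: how a logical block index maps onto
# the circular block table introduced above.
block_size = 16          # assumed example value
sliding_window = 64      # assumed example value
block_sliding_window = sliding_window // block_size  # 4 physical blocks

def physical_slot(logical_idx: int) -> int:
    """Slot in the block table that backs a given logical block."""
    if logical_idx >= block_sliding_window:
        return logical_idx % block_sliding_window  # re-use an old block
    return logical_idx                             # still filling the window

# Logical blocks 0..7 land on slots 0,1,2,3,0,1,2,3: tokens that fall out of
# the window are overwritten instead of consuming new KV-cache blocks.
print([physical_slot(i) for i in range(8)])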

vllm/core/scheduler.py

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ def __init__(
             block_size=self.cache_config.block_size,
             num_gpu_blocks=self.cache_config.num_gpu_blocks,
             num_cpu_blocks=self.cache_config.num_cpu_blocks,
-        )
+            sliding_window=self.cache_config.sliding_window)

         # TODO(zhuohan): Use deque instead of list for better performance.
         # Sequence groups in the WAITING state.

vllm/engine/arg_utils.py

Lines changed: 3 additions & 3 deletions
@@ -176,9 +176,9 @@ def create_engine_configs(
                                    self.download_dir, self.load_format,
                                    self.dtype, self.seed, self.revision,
                                    self.max_model_len, self.quantization)
-        cache_config = CacheConfig(self.block_size,
-                                   self.gpu_memory_utilization,
-                                   self.swap_space)
+        cache_config = CacheConfig(
+            self.block_size, self.gpu_memory_utilization, self.swap_space,
+            getattr(model_config.hf_config, 'sliding_window', None))
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
                                          self.worker_use_ray)
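
The getattr call above is what turns the sliding window on or off per model: Hugging Face configs that do not define sliding_window (e.g. Llama) fall back to None, so CacheConfig keeps its previous behaviour, while Mistral's config carries a window size. A small sketch with stand-in config classes (not transformers code; 4096 is the window Mistral-7B-v0.1 ships with):

# Illustration of the getattr fallback used in create_engine_configs above.
class NoWindowConfig:
    """Stand-in for a config without a sliding_window field (e.g. Llama)."""

class WindowedConfig:
    """Stand-in for a Mistral-style config."""
    sliding_window = 4096

print(getattr(NoWindowConfig(), "sliding_window", None))   # None
print(getattr(WindowedConfig(), "sliding_window", None))   # 4096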

vllm/engine/llm_engine.py

Lines changed: 2 additions & 0 deletions
@@ -86,6 +86,8 @@ def __init__(

         self.model_config = model_config
         self.cache_config = cache_config
+        assert self.cache_config.sliding_window == getattr(
+            self.model_config.hf_config, "sliding_window", None)
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
         self.log_stats = log_stats

vllm/model_executor/input_metadata.py

Lines changed: 20 additions & 1 deletion
@@ -1,4 +1,4 @@
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple

 import torch
 from xformers.ops import AttentionBias
@@ -29,6 +29,7 @@ def __init__(
         context_lens: torch.Tensor,
         max_context_len: int,
         block_tables: torch.Tensor,
+        sliding_window: Optional[int] = None,
     ) -> None:
         self.seq_groups = seq_groups
         self.seq_data = seq_data
@@ -38,6 +39,24 @@ def __init__(
         self.max_context_len = max_context_len
         self.block_tables = block_tables

+        self.to_cache = None
+        if sliding_window is not None:
+            # We need to keep the positions of sliding windows within
+            # the key / value tables, this is helpful to know which
+            # elements we need to cache and where
+            to_cache, start_idx = [], 0
+            for prompt_len in self.prompt_lens:
+                to_cache.extend(
+                    range(
+                        start_idx + max(0, prompt_len - sliding_window),
+                        start_idx + prompt_len,
+                    ))
+                start_idx += prompt_len
+            to_cache.extend(range(start_idx, slot_mapping.shape[0]))
+            self.to_cache = torch.tensor(to_cache,
+                                         dtype=torch.int32,
+                                         device=self.slot_mapping.device)
+
         self.num_prompts = len(prompt_lens)
         self.num_prompt_tokens = sum(prompt_lens)
         self.num_generation_tokens = context_lens.shape[0]
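
The to_cache index built above selects which slots of the flattened prompt batch are actually written to the KV cache: for each prompt only the last sliding_window positions are kept, and every slot after the prompts (generation tokens) is always cached. Re-running that loop standalone with assumed example values (two prompts of lengths 5 and 12, window 8):

# Illustration only: the same index computation as in InputMetadata.__init__,
# without tensors, for two prompts packed back-to-back.
prompt_lens = [5, 12]         # assumed example values
sliding_window = 8            # assumed example value
num_slots = sum(prompt_lens)  # prompt-only batch, no generation tokens here

to_cache, start_idx = [], 0
for prompt_len in prompt_lens:
    # keep only the last `sliding_window` positions of this prompt
    to_cache.extend(range(start_idx + max(0, prompt_len - sliding_window),
                          start_idx + prompt_len))
    start_idx += prompt_len
to_cache.extend(range(start_idx, num_slots))  # everything after the prompts

# Prompt 1 (len 5 <= 8): positions 0..4 are all cached.
# Prompt 2 (len 12 > 8): only its last 8 positions, 9..16, are cached.
print(to_cache)  # [0, 1, 2, 3, 4, 9, 10, 11, 12, 13, 14, 15, 16]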

vllm/model_executor/layers/attention.py

Lines changed: 22 additions & 5 deletions
@@ -58,12 +58,14 @@ def __init__(self,
                  num_heads: int,
                  head_size: int,
                  scale: float,
-                 num_kv_heads: Optional[int] = None) -> None:
+                 num_kv_heads: Optional[int] = None,
+                 sliding_window: Optional[int] = None) -> None:
         super().__init__()
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+        self.sliding_window = sliding_window

         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -86,6 +88,8 @@ def set_attn_bias(
             return
         prompt_lens = input_metadata.prompt_lens
         attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
+        if self.sliding_window is not None:
+            attn_bias = attn_bias.make_local_attention(self.sliding_window)
         input_metadata.attn_bias.append(attn_bias)

     def multi_query_kv_attention(
@@ -223,12 +227,20 @@ def forward(
         if (num_valid_tokens > 0 and key_cache is not None
                 and value_cache is not None):
             # The stride is 3 because the key and value are sliced from qkv.
+            key_to_cache = key[:num_valid_tokens]
+            value_to_cache = value[:num_valid_tokens]
+            slot_mapping = input_metadata.slot_mapping
+            if input_metadata.to_cache is not None:
+                key_to_cache = key_to_cache[input_metadata.to_cache]
+                value_to_cache = value_to_cache[input_metadata.to_cache]
+                slot_mapping = slot_mapping[input_metadata.to_cache]
+
             cache_ops.reshape_and_cache(
-                key[:num_valid_tokens],
-                value[:num_valid_tokens],
+                key_to_cache,
+                value_to_cache,
                 key_cache,
                 value_cache,
-                input_metadata.slot_mapping,
+                slot_mapping,
             )

         if input_metadata.num_generation_tokens > 0:
@@ -262,8 +274,13 @@ def __init__(
         num_kv_heads: Optional[int] = None,
         is_neox_style: bool = True,
         rope_scaling: Optional[Dict[str, Any]] = None,
+        sliding_window: Optional[int] = None,
     ) -> None:
-        super().__init__(num_heads, head_size, scale, num_kv_heads)
+        super().__init__(num_heads,
+                         head_size,
+                         scale,
+                         num_kv_heads,
+                         sliding_window=sliding_window)
         if rope_scaling is None:
             self.rotary_emb = RotaryEmbedding(head_size, rotary_dim,
                                               max_position, base,
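
The make_local_attention(self.sliding_window) call above restricts the prompt-phase causal mask so that each query only sees its most recent sliding_window keys; decode-phase attention is bounded separately by the circular block table and the to_cache filter. A tiny sketch of that mask shape in plain Python, under the assumption that the window includes the current token (i.e. positions j with j <= i and i - j < window), with an assumed example length of 6 and window of 3 not taken from the diff:

# Illustration only: which (query, key) pairs a causal sliding-window mask
# allows. 'x' = attended, '.' = masked.
def allowed(i: int, j: int, window: int) -> bool:
    return j <= i and i - j < window

seq_len, window = 6, 3   # assumed example values
for i in range(seq_len):
    print(" ".join("x" if allowed(i, j, window) else "."
                   for j in range(seq_len)))
# Each row has at most `window` x's ending at the diagonal, so neither the
# attention nor the KV cache ever needs more than `window` past tokens.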

vllm/model_executor/model_loader.py

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@
     "InternLMForCausalLM": InternLMForCausalLM,
     "LlamaForCausalLM": LlamaForCausalLM,
     "LLaMAForCausalLM": LlamaForCausalLM,  # For decapoda-research/llama-*
+    "MistralForCausalLM": MistralForCausalLM,
     "MPTForCausalLM": MPTForCausalLM,
     "OPTForCausalLM": OPTForCausalLM,
     "QWenLMHeadModel": QWenLMHeadModel,

vllm/model_executor/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,7 @@
 from vllm.model_executor.models.mpt import MPTForCausalLM
 from vllm.model_executor.models.opt import OPTForCausalLM
 from vllm.model_executor.models.qwen import QWenLMHeadModel
+from vllm.model_executor.models.mistral import MistralForCausalLM

 __all__ = [
     "AquilaForCausalLM",
@@ -28,4 +29,5 @@
     "MPTForCausalLM",
     "OPTForCausalLM",
     "QWenLMHeadModel",
+    "MistralForCausalLM",
 ]
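
With the registry entry and import above in place, a Mistral checkpoint goes through the existing vLLM entry points unchanged. A minimal usage sketch (not part of this diff), assuming this commit's vLLM is installed and the weights are available as mistralai/Mistral-7B-v0.1 on the Hugging Face Hub:

# Offline inference with the newly registered MistralForCausalLM.
from vllm import LLM, SamplingParams

llm = LLM(model="mistralai/Mistral-7B-v0.1")
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)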
