
Commit a8e98ae

Fix Mistral model (#1220)
1 parent bb1ba58 commit a8e98ae

4 files changed, +27 −14 lines

vllm/model_executor/models/mistral.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,6 @@
 
 import torch
 from torch import nn
-from vllm.transformers_utils.configs.mistral import MistralConfig
 
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -46,6 +45,7 @@
     convert_pyslice_to_tensor, hf_model_weights_iterator,
     load_tensor_parallel_weights, load_padded_tensor_parallel_vocab)
 from vllm.sequence import SamplerOutput
+from vllm.transformers_utils.configs.mistral import MistralConfig
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 

vllm/transformers_utils/config.py

Lines changed: 9 additions & 0 deletions
@@ -17,6 +17,15 @@
 def get_config(model: str,
                trust_remote_code: bool,
                revision: Optional[str] = None) -> PretrainedConfig:
+    # NOTE: Because the Mistral model in HF hub does not have
+    # `configuration_mistral.py`, we cannot use `AutoConfig` to load the
+    # config. Instead, we use `MistralConfig` directly.
+    # NOTE: This is a hack. This does not work for local models.
+    # FIXME: Remove this once the Mistral model is available in the stable
+    # version of HF transformers.
+    if "mistral" in model.lower():
+        return MistralConfig.from_pretrained(model, revision=revision)
+
     try:
         config = AutoConfig.from_pretrained(
             model, trust_remote_code=trust_remote_code, revision=revision)
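
The dispatch added here is purely name-based. Below is a minimal, self-contained sketch of the same pattern; DummyMistralConfig and DummyAutoConfig are hypothetical stand-ins for vLLM's MistralConfig and transformers.AutoConfig, so the routing can be exercised without downloading anything.

from typing import Optional


class DummyMistralConfig:
    """Hypothetical stand-in for vllm.transformers_utils.configs.MistralConfig."""

    @classmethod
    def from_pretrained(cls, model: str, revision: Optional[str] = None):
        return cls()


class DummyAutoConfig:
    """Hypothetical stand-in for transformers.AutoConfig."""

    @classmethod
    def from_pretrained(cls, model: str, **kwargs):
        return cls()


def get_config(model: str,
               trust_remote_code: bool = False,
               revision: Optional[str] = None):
    # Same check as in the patch: any model name containing "mistral"
    # (case-insensitive) is routed to the dedicated config class.
    if "mistral" in model.lower():
        return DummyMistralConfig.from_pretrained(model, revision=revision)
    return DummyAutoConfig.from_pretrained(
        model, trust_remote_code=trust_remote_code, revision=revision)


print(type(get_config("mistralai/Mistral-7B-v0.1")).__name__)  # DummyMistralConfig
print(type(get_config("meta-llama/Llama-2-7b-hf")).__name__)   # DummyAutoConfig
print(type(get_config("/data/my-finetune")).__name__)          # DummyAutoConfig

The last call illustrates the caveat in the FIXME comment: a local checkpoint directory whose path does not contain the substring "mistral" falls through to AutoConfig.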

vllm/transformers_utils/configs/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -6,11 +6,13 @@
 # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
 # `FalconConfig` class from the official HuggingFace transformers library.
 from vllm.transformers_utils.configs.falcon import RWConfig
+from vllm.transformers_utils.configs.mistral import MistralConfig
 
 __all__ = [
     "MPTConfig",
     "BaiChuanConfig",
     "AquilaConfig",
     "QWenConfig",
     "RWConfig",
+    "MistralConfig",
 ]

vllm/worker/worker.py

Lines changed: 15 additions & 13 deletions
@@ -42,6 +42,7 @@ def __init__(
         # self.init_cache_engine().
         self.cache_config = None
         self.block_size = None
+        self.sliding_window = None
         self.cache_engine = None
         self.cache_events = None
         self.gpu_cache = None
@@ -136,10 +137,13 @@ def profile_num_available_blocks(
     def init_cache_engine(self, cache_config: CacheConfig) -> None:
         self.cache_config = cache_config
         self.block_size = cache_config.block_size
+        self.sliding_window = cache_config.sliding_window
 
-        max_seq_len = min(self.scheduler_config.max_model_len,
-                          cache_config.sliding_window or float("inf"))
-
+        if self.sliding_window is None:
+            max_seq_len = self.scheduler_config.max_model_len
+        else:
+            max_seq_len = min(self.scheduler_config.max_model_len,
+                              self.sliding_window)
         _check_if_can_support_max_seq_len(max_seq_len, self.block_size)
 
         self.cache_engine = CacheEngine(self.cache_config, self.model_config,
@@ -151,6 +155,8 @@ def _prepare_inputs(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
+        assert self.block_size is not None
+
         seq_groups: List[Tuple[List[int], SamplingParams]] = []
         input_tokens: List[int] = []
         input_positions: List[int] = []
@@ -193,9 +199,6 @@ def _prepare_inputs(
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)
 
-        sliding_window = getattr(self.model_config.hf_config, "sliding_window",
-                                 float("inf"))
-
         # Add generation tokens.
         max_context_len = 0
         max_num_blocks_per_seq = 0
@@ -216,8 +219,8 @@
 
                 context_len = seq_data.get_len()
                 position = context_len - 1
-                if sliding_window:
-                    context_len = min(context_len, sliding_window)
+                if self.sliding_window is not None:
+                    context_len = min(context_len, self.sliding_window)
                 input_positions.append(position)
 
                 block_table = seq_group_metadata.block_tables[seq_id]
@@ -232,10 +235,9 @@
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)
 
-                if sliding_window:
-                    assert self.cache_config is not None
-                    sliding_window_blocks = (sliding_window //
-                                             self.cache_config.block_size)
+                if self.sliding_window is not None:
+                    sliding_window_blocks = (self.sliding_window //
+                                             self.block_size)
                     block_table = block_table[-sliding_window_blocks:]
                 generation_block_tables.append(block_table)
 
@@ -277,7 +279,7 @@ def _prepare_inputs(
             context_lens=context_lens_tensor,
             max_context_len=max_context_len,
             block_tables=block_tables_tensor,
-            sliding_window=sliding_window,
+            sliding_window=self.sliding_window,
         )
         return tokens_tensor, positions_tensor, input_metadata
 
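
For reference, here is a standalone sketch of the sliding-window bookkeeping this diff centralizes on self.sliding_window: the sequence-length cap from init_cache_engine and the block-table truncation from _prepare_inputs. The function names and the numbers (block_size=16, sliding_window=4096) are illustrative assumptions, not values taken from this commit.

from typing import List, Optional


def capped_max_seq_len(max_model_len: int, sliding_window: Optional[int]) -> int:
    # Mirrors the new init_cache_engine logic: with no window, the model
    # length is the only cap; a window can only shrink it further.
    if sliding_window is None:
        return max_model_len
    return min(max_model_len, sliding_window)


def truncated_block_table(block_table: List[int],
                          sliding_window: Optional[int],
                          block_size: int) -> List[int]:
    # Mirrors the new _prepare_inputs logic: during generation, only the
    # last sliding_window // block_size blocks can still hold tokens that
    # fall inside the attention window.
    if sliding_window is None:
        return block_table
    sliding_window_blocks = sliding_window // block_size
    return block_table[-sliding_window_blocks:]


print(capped_max_seq_len(32768, 4096))   # 4096
print(capped_max_seq_len(2048, None))    # 2048
blocks = list(range(300))                # 300 physical block ids
print(len(truncated_block_table(blocks, 4096, 16)))  # 256 (= 4096 // 16)

Note that the explicit None checks also remove the ambiguity of the old float("inf") sentinel: float("inf") is truthy, so the previous "if sliding_window:" branches ran even for models that define no sliding window.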
