 Run `pytest tests/models/test_mamba.py`.
 """
 import pytest
+import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from vllm.engine.arg_utils import EngineArgs
 from vllm.sampling_params import SamplingParams
 
 from ...utils import check_outputs_equal
 
-MODELS = ["state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev"]
+MODELS = [
+    "state-spaces/mamba-130m-hf",
+    "tiiuae/falcon-mamba-tiny-dev",
+    # TODO: Compare to a Mamba2 model. The HF transformers implementation of
+    # Mamba2 is buggy for Codestral as it doesn't handle n_groups.
+    # See https://github.com/huggingface/transformers/pull/35943
+    # "mistralai/Mamba-Codestral-7B-v0.1",
+]
 
 
 # Use lower-level interfaces to create this greedy generator, as mamba will
@@ -21,6 +29,10 @@ def generate_greedy(model_name, example_prompts, max_tokens):
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model = AutoModelForCausalLM.from_pretrained(model_name)
 
+    # Set the device (GPU if available, else CPU)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+
     # Generate texts from the prompts
     outputs = []
     for prompt in example_prompts:
@@ -29,7 +41,9 @@ def generate_greedy(model_name, example_prompts, max_tokens):
         input_ids = inputs["input_ids"].to(model.device)
 
         # Generate text using the model's generate method directly
-        generated_ids = model.generate(input_ids, max_new_tokens=max_tokens)
+        generated_ids = model.generate(input_ids,
+                                       max_new_tokens=max_tokens,
+                                       do_sample=False)
         generated_text = tokenizer.decode(generated_ids[0],
                                           skip_special_tokens=True)
 
@@ -50,7 +64,8 @@ def test_models(
 ) -> None:
     hf_outputs = generate_greedy(model, example_prompts, max_tokens)
 
-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    # Set max_num_seqs to keep Codestral from going OOM at fp32
+    with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
 
     # This test is for verifying whether the model's extra_repr
@@ -81,7 +96,7 @@ def test_batching(
 ) -> None:
     # To pass the small model tests, we need full precision.
     for_loop_outputs = []
-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
         for prompt in example_prompts:
             for_loop_outputs.append(
                 vllm_model.generate_greedy([prompt], max_tokens)[0])
@@ -165,20 +180,22 @@ def test_parallel_sampling(
     max_tokens: int,
 ) -> None:
 
-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    # Numerical differences produce slightly different output for these
+    if 'state-spaces' in model:
+        example_prompts.pop(0)
+        example_prompts.pop(0)
+        example_prompts.pop(0)
+
+    with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
         for_loop_outputs = []
         for _ in range(10):
             for_loop_outputs.append(
-                # using example_prompts index 1 instead of 0 since with 0 the
-                # logprobs get really close and the test doesn't pass
-                vllm_model.generate_greedy([example_prompts[1]], max_tokens)
-                [0])
+                vllm_model.generate_greedy(example_prompts, max_tokens)[0])
         sampling_params = SamplingParams(n=10,
                                          temperature=0.001,
                                          seed=0,
                                          max_tokens=max_tokens)
-        n_lt_1_outputs = vllm_model.generate([example_prompts[1]],
-                                             sampling_params)
+        n_lt_1_outputs = vllm_model.generate(example_prompts, sampling_params)
         token_ids, texts = n_lt_1_outputs[0]
         n_lt_1_outputs = [(token_id, text)
                           for token_id, text in zip(token_ids, texts)]
@@ -232,7 +249,7 @@ def test_models_preemption_recompute(
     # Tests that outputs are identical with and w/o preemtions (recompute)
     assert dtype == "float"
 
-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
         vllm_model.model.llm_engine.scheduler[
             0].ENABLE_ARTIFICIAL_PREEMPT = True
         preempt_vllm_outputs = vllm_model.generate_greedy(
@@ -283,7 +300,7 @@ def test_state_cleanup(
     # This test is for verifying that the Mamba state is cleaned up between
     # steps, If its not cleaned, an error would be expected.
     try:
-        with vllm_runner(model, dtype=dtype) as vllm_model:
+        with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
             for _ in range(10):
                 vllm_model.generate_greedy([example_prompts[0]] * 100, 1)
     except ValueError:
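
Outside the test fixtures, the intent of these changes can be reproduced with a minimal standalone sketch: explicit device placement plus `do_sample=False` make the HF transformers baseline strictly greedy, which is what keeps it comparable to vLLM's `generate_greedy`. The model name, prompt, and token budget below are placeholders, and the `vllm_runner` / `check_outputs_equal` test utilities are not reproduced here.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "state-spaces/mamba-130m-hf"        # placeholder: any MODELS entry
prompt = "The theory of relativity states that"  # placeholder prompt

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# GPU if available, else CPU -- same device-placement pattern as the diff
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(model.device)

# do_sample=False forces greedy decoding, so repeated runs (and comparisons
# against vLLM's greedy outputs) are deterministic
generated_ids = model.generate(input_ids, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))

# On the vLLM side, the rough equivalent of the max_num_seqs=16 change outside
# the vllm_runner fixture would be the offline entrypoint, e.g.:
#   from vllm import LLM
#   llm = LLM(model=model_name, dtype="float32", max_num_seqs=16)
```

Capping `max_num_seqs` bounds how many sequences the engine batches per step, which is what the diff relies on to keep an fp32 run of the larger Codestral model from running out of GPU memory.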