Merged
Changes from 13 commits
82 changes: 38 additions & 44 deletions tests/models/language/generation/test_hybrid.py
@@ -100,21 +100,21 @@ def test_models(
     else:
         hf_outputs = None
 
-    if model not in V0_UNSUPPORTED_MODELS:
-        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
-            vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
-                example_prompts, max_tokens, num_logprobs)
-    else:
-        vllm_v0_outputs = None
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "0")
+        if model not in V0_UNSUPPORTED_MODELS:
+            with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+                vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
+                    example_prompts, max_tokens, num_logprobs)
+        else:
+            vllm_v0_outputs = None
 
     if model in V1_SUPPORTED_MODELS:
-        with monkeypatch.context() as m:
-            m.setenv("VLLM_USE_V1", "1")
-            with vllm_runner(model,
-                             max_num_seqs=MAX_NUM_SEQS,
-                             enable_prefix_caching=False) as vllm_model:
-                vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
-                    example_prompts, max_tokens, num_logprobs)
+        with vllm_runner(model,
+                         max_num_seqs=MAX_NUM_SEQS,
+                         enable_prefix_caching=False) as vllm_model:
+            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, num_logprobs)
     else:
         vllm_v1_outputs = None
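The new V0 branch above pins the engine through an environment variable with pytest's monkeypatch, so only the V1 path now runs with the defaults. A minimal, self-contained sketch of that gating pattern (the `pick_engine` helper and the assumption that V1 is the default are illustrative only, not vLLM's actual selection logic):

```python
import os

import pytest


def pick_engine() -> str:
    # Hypothetical stand-in for the engine selection this PR changes: V1 is
    # assumed to be the default unless VLLM_USE_V1 is explicitly set to "0".
    return "v0" if os.environ.get("VLLM_USE_V1", "1") == "0" else "v1"


def test_engine_env_gating(monkeypatch: pytest.MonkeyPatch) -> None:
    # Start from a clean environment so the ambient shell cannot interfere.
    monkeypatch.delenv("VLLM_USE_V1", raising=False)
    # Pin the V0 path only inside this context, as the updated test does.
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "0")
        assert pick_engine() == "v0"
    # On exit the context restores the variable, so the default (V1) wins.
    assert pick_engine() == "v1"
```

The context manager undoes its changes on exit, which is why each engine run in the updated tests sits in its own `with monkeypatch.context()` block.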

@@ -137,7 +137,7 @@ def test_models(
     )
 
 
-@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
+@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("num_logprobs", [5])
 def test_batching(
@@ -402,24 +402,21 @@ def test_full_cuda_graph(
     else:
         hf_outputs = None
 
-    if model not in V0_UNSUPPORTED_MODELS:
-        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
-            vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
-                example_prompts, max_tokens, num_logprobs)
-    else:
-        vllm_v0_outputs = None
-
     with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-        if model in HYBRID_MODELS:
-            # required due to reorder_batch behaviour
-            m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
-        with vllm_runner(model,
-                         max_num_seqs=MAX_NUM_SEQS,
-                         compilation_config={'full_cuda_graph': True},
-                         enable_prefix_caching=False) as vllm_model:
-            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
-                example_prompts, max_tokens, num_logprobs)
+        m.setenv("VLLM_USE_V1", "0")
+        if model not in V0_UNSUPPORTED_MODELS:
+            with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+                vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
+                    example_prompts, max_tokens, num_logprobs)
+        else:
+            vllm_v0_outputs = None
+
+    with vllm_runner(model,
+                     max_num_seqs=MAX_NUM_SEQS,
+                     compilation_config={'full_cuda_graph': True},
+                     enable_prefix_caching=False) as vllm_model:
+        vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
 
     if hf_outputs is not None and vllm_v0_outputs is not None:
         check_logprobs_close(
@@ -466,24 +463,21 @@ def test_fp32_state(
     else:
         hf_outputs = None
 
-    with vllm_runner(model,
-                     max_num_seqs=MAX_NUM_SEQS,
-                     mamba_ssm_cache_dtype="float32") as vllm_model:
-        vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
-            example_prompts, max_tokens, num_logprobs)
-
     with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-        if model in HYBRID_MODELS:
-            # required due to reorder_batch behaviour
-            m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
+        m.setenv("VLLM_USE_V1", "0")
         with vllm_runner(model,
                          max_num_seqs=MAX_NUM_SEQS,
-                         mamba_ssm_cache_dtype="float32",
-                         enable_prefix_caching=False) as vllm_model:
-            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+                         mamba_ssm_cache_dtype="float32") as vllm_model:
+            vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
                 example_prompts, max_tokens, num_logprobs)
+
+    with vllm_runner(model,
+                     max_num_seqs=MAX_NUM_SEQS,
+                     mamba_ssm_cache_dtype="float32",
+                     enable_prefix_caching=False) as vllm_model:
+        vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
 
     if hf_outputs is not None:
         check_logprobs_close(
             outputs_0_lst=hf_outputs,
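All of these tests end by feeding the HF, V0 and V1 runs into `check_logprobs_close`. That helper lives in vLLM's test utilities; the sketch below only illustrates the kind of comparison such a check performs (greedy tokens must match, or each run's token must at least appear in the other run's top-k logprobs). It is not the real implementation, and the `Output` shape is an assumption:

```python
from typing import Optional

# Assumed output shape, mirroring generate_greedy_logprobs in these tests:
# (token_ids, text, per-position dict of candidate token id -> logprob).
Output = tuple[list[int], str, Optional[list[dict[int, float]]]]


def check_logprobs_close_sketch(outputs_0_lst: list[Output],
                                outputs_1_lst: list[Output],
                                name_0: str,
                                name_1: str) -> None:
    for i, (out_0, out_1) in enumerate(zip(outputs_0_lst, outputs_1_lst)):
        ids_0, _, logprobs_0 = out_0
        ids_1, _, logprobs_1 = out_1
        for pos, (tok_0, tok_1) in enumerate(zip(ids_0, ids_1)):
            if tok_0 == tok_1:
                continue  # greedy choices agree at this position
            # Tolerate a divergence only if each token still appears in the
            # other run's top-k candidates at the same position.
            ok_0 = logprobs_1 is None or tok_0 in logprobs_1[pos]
            ok_1 = logprobs_0 is None or tok_1 in logprobs_0[pos]
            assert ok_0 and ok_1, (
                f"prompt {i}, pos {pos}: {name_0}={tok_0} vs {name_1}={tok_1}")
            break  # after a tolerated split the sequences may differ freely
```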
5 changes: 0 additions & 5 deletions vllm/engine/arg_utils.py
@@ -1452,11 +1452,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                 recommend_to_remove=False)
             return False
 
-        # V1 mamba models are unoptimized.
-        if model_config.has_inner_state and _warn_or_fallback(
-                feature_name="Mamba"):
-            return False
-
         # No Concurrent Partial Prefills so far.
         if (self.max_num_partial_prefills
                 != SchedulerConfig.max_num_partial_prefills
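The deleted block was the part of the V1-support oracle that kept models with recurrent inner state (Mamba-style) on V0 by default. A schematic reconstruction of that behaviour, under the assumption that `_warn_or_fallback` warns when V1 was explicitly requested and otherwise falls back; this is illustrative only, not vLLM's actual implementation:

```python
import logging
import os

logger = logging.getLogger("v1_oracle_sketch")


def _warn_or_fallback_sketch(feature_name: str) -> bool:
    """Return True if the engine should fall back to V0 for this feature.

    Schematic only: assumes an explicit VLLM_USE_V1=1 means the user opted
    into V1 (warn but keep V1); otherwise silently fall back to V0.
    """
    if os.environ.get("VLLM_USE_V1") == "1":
        logger.warning("%s is not fully optimized on V1; proceeding anyway.",
                       feature_name)
        return False
    logger.info("Falling back to V0 because %s is flagged as unsupported.",
                feature_name)
    return True


def is_v1_supported_sketch(has_inner_state: bool) -> bool:
    # Before this PR: any model with recurrent inner state (Mamba-style)
    # hit this check and defaulted to V0. The PR deletes the check, so such
    # models now go through the V1 path like everything else.
    if has_inner_state and _warn_or_fallback_sketch(feature_name="Mamba"):
        return False
    return True
```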
1 change: 1 addition & 0 deletions vllm/model_executor/models/config.py
@@ -417,4 +417,5 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
     "GptOssForCausalLM": GptOssForCausalLMConfig,
     "MambaForCausalLM": MambaModelConfig,
     "Mamba2ForCausalLM": MambaModelConfig,
+    "FalconMambaForCausalLM": MambaModelConfig,
Collaborator
@tdoublep fine for this PR, but I think this line makes vLLM not that pluggable for new models.

}
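For context on the comment above: `verify_and_update_config` dispatches on the model class name through this hard-coded mapping, so every new Mamba-style architecture needs its own entry like the `FalconMambaForCausalLM` line added here. A stripped-down sketch of that shape, using stub config classes in place of vLLM's real `VllmConfig`; the prefix-caching side effect is a guess based on the `enable_prefix_caching=False` arguments in the tests above, not taken from the source:

```python
from dataclasses import dataclass, field


@dataclass
class CacheConfigStub:
    enable_prefix_caching: bool = True


@dataclass
class VllmConfigStub:
    cache_config: CacheConfigStub = field(default_factory=CacheConfigStub)


class MambaModelConfig:
    @classmethod
    def verify_and_update_config(cls, vllm_config: VllmConfigStub) -> None:
        # Assumed effect: Mamba-style caches don't support prefix caching,
        # so it gets switched off here (a guess, not confirmed from source).
        vllm_config.cache_config.enable_prefix_caching = False


# Hard-coded dispatch table keyed by model class name, mirroring config.py.
# A new architecture only benefits if someone adds a line here, which is
# the pluggability concern raised in the review comment above.
MODELS_CONFIG_MAP: dict[str, type] = {
    "MambaForCausalLM": MambaModelConfig,
    "Mamba2ForCausalLM": MambaModelConfig,
    "FalconMambaForCausalLM": MambaModelConfig,
}


def apply_model_config_overrides(arch: str, cfg: VllmConfigStub) -> None:
    override = MODELS_CONFIG_MAP.get(arch)
    if override is not None:
        override.verify_and_update_config(cfg)
```

A registry hook that out-of-tree models could call to add their own entry would close roughly the gap the reviewer is pointing at.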