Commit 23ec72f

[CI/Build][REDO] Add is_quant_method_supported to control quantization test configurations (#5466)
1 parent c2637a6

8 files changed, +32 -71 lines
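Every file below follows the same pattern: the copy-pasted module-level compute-capability check is removed, and the tests are gated with the new shared helper instead. A minimal sketch of the resulting shape, assuming the suite's vllm_runner fixture seen in the diffs below; the test name, body, and assertion are illustrative, not code from this commit:

import pytest

from tests.quantization.utils import is_quant_method_supported


@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="fp8 is not supported on this GPU type.")
def test_fp8_loads_when_supported(vllm_runner) -> None:
    # Placeholder body: construct a small fp8-quantized model and check that
    # it comes up; on GPUs below the method's minimum capability the whole
    # test is skipped by the decorator above.
    with vllm_runner("facebook/opt-125m", quantization="fp8") as llm:
        assert llm is not None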

tests/models/test_aqlm.py

Lines changed: 2 additions & 11 deletions
@@ -4,17 +4,8 @@
 """
 
 import pytest
-import torch
 
-from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
-
-aqlm_not_supported = True
-
-if torch.cuda.is_available():
-    capability = torch.cuda.get_device_capability()
-    capability = capability[0] * 10 + capability[1]
-    aqlm_not_supported = (capability <
-                          QUANTIZATION_METHODS["aqlm"].get_min_capability())
+from tests.quantization.utils import is_quant_method_supported
 
 # In this test we hardcode prompts and generations for the model so we don't
 # need to require the AQLM package as a dependency
@@ -67,7 +58,7 @@
 ]
 
 
-@pytest.mark.skipif(aqlm_not_supported,
+@pytest.mark.skipif(not is_quant_method_supported("aqlm"),
                     reason="AQLM is not supported on this GPU type.")
 @pytest.mark.parametrize("model", ["ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"])
 @pytest.mark.parametrize("dtype", ["half"])

tests/models/test_fp8.py

Lines changed: 2 additions & 10 deletions
@@ -8,8 +8,8 @@
 import torch
 from transformers import AutoTokenizer
 
+from tests.quantization.utils import is_quant_method_supported
 from vllm import LLM, SamplingParams
-from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 
 os.environ["TOKENIZERS_PARALLELISM"] = "true"
 
@@ -67,16 +67,8 @@
     },
 }
 
-fp8_not_supported = True
 
-if torch.cuda.is_available():
-    capability = torch.cuda.get_device_capability()
-    capability = capability[0] * 10 + capability[1]
-    fp8_not_supported = (capability <
-                         QUANTIZATION_METHODS["fp8"].get_min_capability())
-
-
-@pytest.mark.skipif(fp8_not_supported,
+@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                     reason="fp8 is not supported on this GPU type.")
 @pytest.mark.parametrize("model_name", MODELS)
 @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])

tests/models/test_gptq_marlin.py

Lines changed: 2 additions & 11 deletions
@@ -11,9 +11,8 @@
 import os
 
 import pytest
-import torch
 
-from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+from tests.quantization.utils import is_quant_method_supported
 from vllm.model_executor.layers.rotary_embedding import _ROPE_DICT
 
 from .utils import check_logprobs_close
@@ -22,14 +21,6 @@
 
 MAX_MODEL_LEN = 1024
 
-gptq_marlin_not_supported = True
-
-if torch.cuda.is_available():
-    capability = torch.cuda.get_device_capability()
-    capability = capability[0] * 10 + capability[1]
-    gptq_marlin_not_supported = (
-        capability < QUANTIZATION_METHODS["gptq_marlin"].get_min_capability())
-
 MODELS = [
     # act_order==False, group_size=channelwise
     ("robertgshaw2/zephyr-7b-beta-channelwise-gptq", "main"),
@@ -53,7 +44,7 @@
 
 
 @pytest.mark.flaky(reruns=3)
-@pytest.mark.skipif(gptq_marlin_not_supported,
+@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                     reason="gptq_marlin is not supported on this GPU type.")
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half", "bfloat16"])

tests/models/test_gptq_marlin_24.py

Lines changed: 2 additions & 11 deletions
@@ -9,18 +9,9 @@
 from dataclasses import dataclass
 
 import pytest
-import torch
 
 from tests.models.utils import check_logprobs_close
-from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
-
-marlin_not_supported = True
-
-if torch.cuda.is_available():
-    capability = torch.cuda.get_device_capability()
-    capability = capability[0] * 10 + capability[1]
-    marlin_not_supported = (
-        capability < QUANTIZATION_METHODS["marlin"].get_min_capability())
+from tests.quantization.utils import is_quant_method_supported
 
 
 @dataclass
@@ -47,7 +38,7 @@ class ModelPair:
 
 
 @pytest.mark.flaky(reruns=2)
-@pytest.mark.skipif(marlin_not_supported,
+@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin_24"),
                     reason="Marlin24 is not supported on this GPU type.")
 @pytest.mark.parametrize("model_pair", model_pairs)
 @pytest.mark.parametrize("dtype", ["half"])

tests/models/test_marlin.py

Lines changed: 2 additions & 11 deletions
@@ -13,20 +13,11 @@
 from dataclasses import dataclass
 
 import pytest
-import torch
 
-from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+from tests.quantization.utils import is_quant_method_supported
 
 from .utils import check_logprobs_close
 
-marlin_not_supported = True
-
-if torch.cuda.is_available():
-    capability = torch.cuda.get_device_capability()
-    capability = capability[0] * 10 + capability[1]
-    marlin_not_supported = (
-        capability < QUANTIZATION_METHODS["marlin"].get_min_capability())
-
 
 @dataclass
 class ModelPair:
@@ -45,7 +36,7 @@ class ModelPair:
 
 
 @pytest.mark.flaky(reruns=2)
-@pytest.mark.skipif(marlin_not_supported,
+@pytest.mark.skipif(not is_quant_method_supported("marlin"),
                     reason="Marlin is not supported on this GPU type.")
 @pytest.mark.parametrize("model_pair", model_pairs)
 @pytest.mark.parametrize("dtype", ["half"])

tests/quantization/test_bitsandbytes.py

Lines changed: 3 additions & 7 deletions
@@ -5,16 +5,12 @@
 import pytest
 import torch
 
+from tests.quantization.utils import is_quant_method_supported
 from vllm import SamplingParams
-from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 
-capability = torch.cuda.get_device_capability()
-capability = capability[0] * 10 + capability[1]
 
-
-@pytest.mark.skipif(
-    capability < QUANTIZATION_METHODS['bitsandbytes'].get_min_capability(),
-    reason='bitsandbytes is not supported on this GPU type.')
+@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
+                    reason='bitsandbytes is not supported on this GPU type.')
 def test_load_bnb_model(vllm_runner) -> None:
     with vllm_runner('huggyllama/llama-7b',
                      quantization='bitsandbytes',

tests/quantization/test_fp8.py

Lines changed: 5 additions & 10 deletions
@@ -5,17 +5,13 @@
 import pytest
 import torch
 
+from tests.quantization.utils import is_quant_method_supported
 from vllm._custom_ops import scaled_fp8_quant
-from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
 
-capability = torch.cuda.get_device_capability()
-capability = capability[0] * 10 + capability[1]
 
-
-@pytest.mark.skipif(
-    capability < QUANTIZATION_METHODS["fp8"].get_min_capability(),
-    reason="FP8 is not supported on this GPU type.")
+@pytest.mark.skipif(not is_quant_method_supported("fp8"),
+                    reason="FP8 is not supported on this GPU type.")
 def test_load_fp16_model(vllm_runner) -> None:
     with vllm_runner("facebook/opt-125m", quantization="fp8") as llm:
 
@@ -25,9 +21,8 @@ def test_load_fp16_model(vllm_runner) -> None:
         assert fc1.weight.dtype == torch.float8_e4m3fn
 
 
-@pytest.mark.skipif(
-    capability < QUANTIZATION_METHODS["fp8"].get_min_capability(),
-    reason="FP8 is not supported on this GPU type.")
+@pytest.mark.skipif(not is_quant_method_supported("fp8"),
+                    reason="FP8 is not supported on this GPU type.")
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 def test_scaled_fp8_quant(dtype) -> None:
 
tests/quantization/utils.py

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+import torch
+
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+
+
+def is_quant_method_supported(quant_method: str) -> bool:
+    # Currently, all quantization methods require Nvidia or AMD GPUs
+    if not torch.cuda.is_available():
+        return False
+
+    capability = torch.cuda.get_device_capability()
+    capability = capability[0] * 10 + capability[1]
+    return (capability >=
+            QUANTIZATION_METHODS[quant_method].get_min_capability())
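The helper looks the method name up in QUANTIZATION_METHODS, so the string passed in must match a registered method ("aqlm", "fp8", "gptq_marlin", "gptq_marlin_24", "marlin", "bitsandbytes" in the tests above). A quick way to see which registered methods the current GPU can run is to loop over the registry; a rough sketch, assuming a machine with vLLM installed, not part of this commit:

from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

from tests.quantization.utils import is_quant_method_supported

# Print one line per registered quantization method, saying whether the
# current device meets that method's minimum compute capability (always
# False when no CUDA device is available).
for name in QUANTIZATION_METHODS:
    print(f"{name}: supported={is_quant_method_supported(name)}")

Centralizing the check also means any future change to how capability is derived, or support for non-CUDA devices, only has to be made in this one helper rather than in every test module.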
