 from __future__ import annotations
 
-from typing import Any, Union
+from typing import Any, Optional, Union
 
 import pytest
 import torch
@@ -15,7 +15,7 @@
 from ..utils import create_new_process_for_each_test
 
 
-def models_list(all: bool):
+def models_list(*, all: bool = True, keywords: Optional[list[str]] = None):
     TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
         ("facebook/opt-125m", {}),
         ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
@@ -32,47 +32,50 @@ def models_list(all: bool):
         ("meta-llama/Llama-3.2-1B-Instruct", {}),
     ]
 
-    if not all:
-        return TEST_MODELS
+    if all:
+        if is_quant_method_supported("aqlm"):
+            TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", {
+                "quantization": "aqlm"
+            }))
+
+        # TODO: figure out why this fails.
+        if False and is_quant_method_supported("gguf"):  # noqa: SIM223
+            TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {
+                "quantization": "gguf"
+            }))
+
+        if is_quant_method_supported("gptq"):
+            TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {
+                "quantization": "gptq"
+            }))
+
+        if is_quant_method_supported("gptq_marlin"):
+            TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", {
+                "quantization": "gptq_marlin"
+            }))
 
-    if is_quant_method_supported("aqlm"):
-        TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", {
-            "quantization": "aqlm"
-        }))
-
-    # TODO: figure out why this fails.
-    if False and is_quant_method_supported("gguf"):  # noqa: SIM223
-        TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {
-            "quantization": "gguf"
-        }))
-
-    if is_quant_method_supported("gptq"):
-        TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {
-            "quantization": "gptq"
-        }))
-
-    if is_quant_method_supported("gptq_marlin"):
-        TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", {
-            "quantization": "gptq_marlin"
-        }))
-
-    if is_quant_method_supported("gptq_marlin_24"):
-        TEST_MODELS.append(("alexm-nm/tinyllama-24-marlin24-4bit-g128", {
-            "quantization": "gptq_marlin_24"
-        }))
-
-    if is_quant_method_supported("marlin"):
-        TEST_MODELS.append(
-            ("robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", {
-                "quantization": "marlin"
+        if is_quant_method_supported("gptq_marlin_24"):
+            TEST_MODELS.append(("alexm-nm/tinyllama-24-marlin24-4bit-g128", {
+                "quantization": "gptq_marlin_24"
             }))
 
-    if not current_platform.is_rocm() and is_quant_method_supported("awq"):
-        TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
-            "quantization": "AWQ"
-        }))
+        if is_quant_method_supported("marlin"):
+            TEST_MODELS.append(
+                ("robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", {
+                    "quantization": "marlin"
+                }))
 
-    return TEST_MODELS
+        if not current_platform.is_rocm() and is_quant_method_supported("awq"):
+            TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
+                "quantization": "AWQ"
+            }))
+
+    if keywords is None:
+        return TEST_MODELS
+
+    # filter by keywords
+    pred = lambda model: any(keyword in model[0] for keyword in keywords)
+    return list(filter(pred, TEST_MODELS))
 
 
 @pytest.mark.parametrize(
@@ -96,20 +99,30 @@ def test_full_graph(
     run_model(optimization_level, model, model_kwargs)
 
 
+PassConfig = CompilationConfig.PassConfig
+
+
 # TODO(luka) add other supported compilation config scenarios here
 @pytest.mark.parametrize(
-    "compilation_config",
-    # additional compile sizes
+    "compilation_config, model_info",
     [
-        CompilationConfig(level=CompilationLevel.PIECEWISE,
-                          compile_sizes=[1, 2])
+        # additional compile sizes, only some of the models
+        (CompilationConfig(level=CompilationLevel.PIECEWISE,
+                           compile_sizes=[1, 2]), model)
+        for model in models_list(all=False)
+    ] + [
+        # RMSNorm + quant fusion, only 8-bit quant models
+        (CompilationConfig(level=CompilationLevel.PIECEWISE,
+                           custom_ops=["+rms_norm"],
+                           pass_config=PassConfig(enable_fusion=True,
+                                                  enable_noop=True)), model)
+        for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
     ])
 # only test some of the models
-@pytest.mark.parametrize("model_info", models_list(all=False))
 @create_new_process_for_each_test()
 def test_custom_compile_config(
-    model_info: tuple[str, dict[str, Any]],
     compilation_config: CompilationConfig,
+    model_info: tuple[str, dict[str, Any]],
 ):
     model, model_kwargs = model_info
     print(f"MODEL={model}")
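A rough standalone sketch of the keyword filtering that models_list now does (placeholder model names; only the substring-filter logic mirrors the diff):

from typing import Any, Optional

# Placeholder model table; the real TEST_MODELS lives in the diff above.
_MODELS: list[tuple[str, dict[str, Any]]] = [
    ("example/tinyllama-FP8-dynamic", {}),
    ("example/tinyllama-quantized.w8a8", {}),
    ("example/opt-125m-unquantized", {}),
]


def models_list(*, keywords: Optional[list[str]] = None):
    # No keywords: return every entry.
    if keywords is None:
        return _MODELS
    # Otherwise keep entries whose name (tuple element 0) contains any keyword.
    pred = lambda model: any(keyword in model[0] for keyword in keywords)
    return list(filter(pred, _MODELS))


# Mirrors the fusion parametrization above: only the two 8-bit quant
# placeholders survive the filter.
print(models_list(keywords=["FP8-dynamic", "quantized.w8a8"]))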