@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import Any, Union
+from typing import Any, Optional, Union
 
 import pytest
 import torch
@@ -15,7 +15,7 @@
 from ..utils import create_new_process_for_each_test
 
 
-def models_list(all: bool):
+def models_list(*, all: bool = True, keywords: Optional[list[str]] = None):
     TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
         ("facebook/opt-125m", {}),
         ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
@@ -32,47 +32,50 @@ def models_list(all: bool):
3232 ("meta-llama/Llama-3.2-1B-Instruct" , {}),
3333 ]
3434
35- if not all :
36- return TEST_MODELS
35+ if all :
36+ if is_quant_method_supported ("aqlm" ):
37+ TEST_MODELS .append (("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf" , {
38+ "quantization" : "aqlm"
39+ }))
40+
41+ # TODO: figure out why this fails.
42+ if False and is_quant_method_supported ("gguf" ): # noqa: SIM223
43+ TEST_MODELS .append (("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF" , {
44+ "quantization" : "gguf"
45+ }))
46+
47+ if is_quant_method_supported ("gptq" ):
48+ TEST_MODELS .append (("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ" , {
49+ "quantization" : "gptq"
50+ }))
51+
52+ if is_quant_method_supported ("gptq_marlin" ):
53+ TEST_MODELS .append (("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ" , {
54+ "quantization" : "gptq_marlin"
55+ }))
3756
38- if is_quant_method_supported ("aqlm" ):
39- TEST_MODELS .append (("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf" , {
40- "quantization" : "aqlm"
41- }))
42-
43- # TODO: figure out why this fails.
44- if False and is_quant_method_supported ("gguf" ): # noqa: SIM223
45- TEST_MODELS .append (("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF" , {
46- "quantization" : "gguf"
47- }))
48-
49- if is_quant_method_supported ("gptq" ):
50- TEST_MODELS .append (("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ" , {
51- "quantization" : "gptq"
52- }))
53-
54- if is_quant_method_supported ("gptq_marlin" ):
55- TEST_MODELS .append (("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ" , {
56- "quantization" : "gptq_marlin"
57- }))
58-
59- if is_quant_method_supported ("gptq_marlin_24" ):
60- TEST_MODELS .append (("alexm-nm/tinyllama-24-marlin24-4bit-g128" , {
61- "quantization" : "gptq_marlin_24"
62- }))
63-
64- if is_quant_method_supported ("marlin" ):
65- TEST_MODELS .append (
66- ("robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin" , {
67- "quantization" : "marlin"
57+ if is_quant_method_supported ("gptq_marlin_24" ):
58+ TEST_MODELS .append (("alexm-nm/tinyllama-24-marlin24-4bit-g128" , {
59+ "quantization" : "gptq_marlin_24"
6860 }))
6961
70- if not current_platform .is_rocm () and is_quant_method_supported ("awq" ):
71- TEST_MODELS .append (("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ" , {
72- "quantization" : "AWQ"
73- }))
62+ if is_quant_method_supported ("marlin" ):
63+ TEST_MODELS .append (
64+ ("robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin" , {
65+ "quantization" : "marlin"
66+ }))
7467
75- return TEST_MODELS
68+ if not current_platform .is_rocm () and is_quant_method_supported ("awq" ):
69+ TEST_MODELS .append (("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ" , {
70+ "quantization" : "AWQ"
71+ }))
72+
73+ if keywords is None :
74+ return TEST_MODELS
75+
76+ # filter by keywords
77+ pred = lambda model : any (keyword in model [0 ] for keyword in keywords )
78+ return list (filter (pred , TEST_MODELS ))
7679
7780
7881@pytest .mark .parametrize (
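How the new `keywords` filter is meant to behave, per the lambda above: a model entry is kept when any keyword appears as a substring of its name (element 0 of the tuple). A standalone sketch with a trimmed-down list (the names are taken from the diff, the two-entry list is illustrative):

```python
# Standalone sketch of the keyword filter: keep a model entry when any
# keyword is a substring of the model name (element 0 of the tuple).
from typing import Any

TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
    ("facebook/opt-125m", {}),
    ("meta-llama/Llama-3.2-1B-Instruct", {}),
]
keywords = ["opt"]

pred = lambda model: any(keyword in model[0] for keyword in keywords)
assert list(filter(pred, TEST_MODELS)) == [("facebook/opt-125m", {})]
```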
@@ -96,20 +99,30 @@ def test_full_graph(
     run_model(optimization_level, model, model_kwargs)
 
 
+PassConfig = CompilationConfig.PassConfig
+
+
 # TODO(luka) add other supported compilation config scenarios here
 @pytest.mark.parametrize(
-    "compilation_config",
-    # additional compile sizes
+    "compilation_config, model_info",
     [
-        CompilationConfig(level=CompilationLevel.PIECEWISE,
-                          compile_sizes=[1, 2])
+        # additional compile sizes, only some of the models
+        (CompilationConfig(level=CompilationLevel.PIECEWISE,
+                           compile_sizes=[1, 2]), model)
+        for model in models_list(all=False)
+    ] + [
+        # RMSNorm + quant fusion, only 8-bit quant models
+        (CompilationConfig(level=CompilationLevel.PIECEWISE,
+                           custom_ops=["+rms_norm"],
+                           pass_config=PassConfig(enable_fusion=True,
+                                                  enable_noop=True)), model)
+        for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
     ])
 # only test some of the models
-@pytest.mark.parametrize("model_info", models_list(all=False))
 @create_new_process_for_each_test()
 def test_custom_compile_config(
-    model_info: tuple[str, dict[str, Any]],
     compilation_config: CompilationConfig,
+    model_info: tuple[str, dict[str, Any]],
 ):
     model, model_kwargs = model_info
     print(f"MODEL={model}")
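The move from two stacked `@pytest.mark.parametrize` decorators to a single list of `(compilation_config, model_info)` pairs is what allows each config to run against a different model subset: stacked decorators always produce the full cartesian product. A minimal illustration of the difference (test and parameter names are invented for the sketch):

```python
import pytest

# Stacked decorators: every config x every model (4 test cases).
@pytest.mark.parametrize("config", ["a", "b"])
@pytest.mark.parametrize("model", ["m1", "m2"])
def test_product(config: str, model: str):
    ...

# Single combined list: each config paired only with its own model (2 cases).
@pytest.mark.parametrize("config, model", [("a", "m1"), ("b", "m2")])
def test_pairs(config: str, model: str):
    ...
```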