3 files changed (+14, -26 lines)
File 1 of 3 (filename not captured in this view):

 from cache_dit.profiler import create_profiler_context
 from cache_dit.profiler import get_profiler_output_dir
 from cache_dit.profiler import set_profiler_output_dir
-
-try:
-    from cache_dit.quantize import quantize
-except ImportError as e:  # noqa: F841
-    err_msg = str(e)
-
-    def _raise_import_error(func_name: str):
-        raise ImportError(
-            f"{func_name} requires additional dependencies. "
-            "Please install cache-dit[quantization] or cache-dit[all] "
-            f"to use this feature. Error message: {err_msg}"
-        )
-
-    def quantize(*args, **kwargs):
-        _raise_import_error("quantize")
-
+from cache_dit.quantize import quantize
 
 NONE = CacheType.NONE
 DBCache = CacheType.DBCache
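With the import-time fallback shim removed, `import cache_dit` no longer needs torchao; the missing-dependency error now surfaces when quantization is actually invoked (see the check added in the third file below). A minimal sketch of the new behavior, assuming `quantize` is re-exported at the package root as this hunk shows; the module is a hypothetical placeholder:

import torch
import cache_dit  # succeeds even if torchao is not installed

model = torch.nn.Linear(8, 8)  # placeholder for a real transformer module
cache_dit.quantize(model, quant_type="float8")  # ImportError raised here if torchao is missing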
File 2 of 3 (filename not captured in this view):

-try:
-    import torchao
-except ImportError:
-    raise ImportError(
-        "Quantization functionality requires the 'quantization' extra dependencies. "
-        "Install with: pip install cache-dit[quantization]"
-    )
 import torch
 from typing import Callable, Optional, List
 from cache_dit.logger import init_logger
@@ -16,6 +9,8 @@ def quantize(
     module: torch.nn.Module,
     quant_type: Optional[str] = None,
     backend: str = "ao",
+    # Parameters specific to the torchao backend
+    per_row: bool = True,
     exclude_layers: List[str] = [
         "embedder",
         "embed",
@@ -35,7 +30,7 @@ def quantize(
         return quantize_ao(
             module,
             quant_type=quant_type,
-            per_row=kwargs.pop("per_row", True),
+            per_row=per_row,
             exclude_layers=exclude_layers,
             filter_fn=filter_fn,
             **kwargs,
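`per_row` is now an explicit keyword argument instead of being popped out of `**kwargs`, so it appears in the signature and in editor tooltips. A hedged usage sketch (the model is a placeholder; the quant_type value "fp8_w8a8_dq" is taken from the alias map in the third file):

import torch
from cache_dit.quantize import quantize

model = torch.nn.Sequential(torch.nn.Linear(64, 64))  # placeholder model
# Use per-tensor FP8 scaling instead of the per-row default:
quantized = quantize(model, quant_type="fp8_w8a8_dq", per_row=False)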
File 3 of 3 (filename not captured in this view):

 def quantize_ao(
     module: torch.nn.Module,
     quant_type: str = "float8_weight_only",
+    # Parameters for FP8 DQ quantization:
+    # whether to quantize per row (True) or per tensor (False).
+    per_row: bool = True,
     exclude_layers: List[str] = [
         "embedder",
         "embed",
     ],
     filter_fn: Optional[Callable] = None,
-    # paramters for fp8 quantization
-    per_row: bool = True,
     **kwargs,
 ) -> torch.nn.Module:
     # Apply FP8 DQ to the module, skipping any `embed` modules by default
     # to avoid a non-trivial precision downgrade. Set `exclude_layers`
     # to `[]` if you don't want this behavior.
     assert isinstance(module, torch.nn.Module)
+    try:
+        import torchao  # noqa: F401
+    except ImportError:
+        raise ImportError(
+            "Quantization functionality requires the 'quantization' extra dependencies. "
+            "Install with: pip install cache-dit[quantization]"
+        )
 
     alias_map = {
         "float8": "fp8_w8a8_dq",