Commit c53c427

chore: lazy import check for quantize api (#630)
1 parent ca27b04 commit c53c427

3 files changed: 14 additions & 26 deletions

src/cache_dit/__init__.py

Lines changed: 1 addition & 16 deletions
@@ -38,22 +38,7 @@
 from cache_dit.profiler import create_profiler_context
 from cache_dit.profiler import get_profiler_output_dir
 from cache_dit.profiler import set_profiler_output_dir
-
-try:
-    from cache_dit.quantize import quantize
-except ImportError as e:  # noqa: F841
-    err_msg = str(e)
-
-    def _raise_import_error(func_name: str):
-        raise ImportError(
-            f"{func_name} requires additional dependencies. "
-            "Please install cache-dit[quantization] or cache-dit[all] "
-            f"to use this feature. Error message: {err_msg}"
-        )
-
-    def quantize(*args, **kwargs):
-        _raise_import_error("quantize")
-
+from cache_dit.quantize import quantize
 
 NONE = CacheType.NONE
 DBCache = CacheType.DBCache
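With the stub wrapper removed, importing the package binds the real quantize function unconditionally; the missing-dependency check is deferred to the point where quantization is actually invoked (see the quantize_ao diff below). A minimal sketch of the resulting behavior when torchao is not installed; the Linear module is only a placeholder for illustration:

    import torch
    import cache_dit  # succeeds even without the 'quantization' extra installed

    model = torch.nn.Linear(64, 64)  # placeholder module for illustration

    try:
        # After this commit, the ImportError is raised lazily inside quantize_ao,
        # only when the quantize API is actually called.
        cache_dit.quantize(model, quant_type="fp8_w8a8_dq")
    except ImportError as err:
        print(err)  # points to: pip install cache-dit[quantization]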

src/cache_dit/quantize/__init__.py

Lines changed: 3 additions & 8 deletions
@@ -1,10 +1,3 @@
-try:
-    import torchao
-except ImportError:
-    raise ImportError(
-        "Quantization functionality requires the 'quantization' extra dependencies. "
-        "Install with: pip install cache-dit[quantization]"
-    )
 import torch
 from typing import Callable, Optional, List
 from cache_dit.logger import init_logger
@@ -16,6 +9,8 @@ def quantize(
     module: torch.nn.Module,
     quant_type: Optional[str] = None,
     backend: str = "ao",
+    # Specific parameters for torchao backend
+    per_row: bool = True,
     exclude_layers: List[str] = [
         "embedder",
         "embed",
@@ -35,7 +30,7 @@ def quantize(
     return quantize_ao(
         module,
         quant_type=quant_type,
-        per_row=kwargs.pop("per_row", True),
+        per_row=per_row,
         exclude_layers=exclude_layers,
         filter_fn=filter_fn,
         **kwargs,
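Because per_row is now declared in the signature of quantize() instead of being popped out of **kwargs, it is visible in the public API and keeps its default of True (per-row scaling). A short usage sketch, assuming the 'quantization' extra is installed; the toy Sequential stands in for a real model:

    import torch
    import cache_dit

    module = torch.nn.Sequential(
        torch.nn.Linear(256, 256),
        torch.nn.Linear(256, 256),
    )

    # "float8" is an alias for "fp8_w8a8_dq" in quantize_ao's alias_map;
    # per_row=False requests per-tensor scales instead of the default per-row scales.
    module = cache_dit.quantize(
        module,
        quant_type="float8",
        backend="ao",
        per_row=False,
    )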

src/cache_dit/quantize/torchao/quantize_ao.py

Lines changed: 10 additions & 2 deletions
@@ -10,19 +10,27 @@
 def quantize_ao(
     module: torch.nn.Module,
     quant_type: str = "float8_weight_only",
+    # Paramters for FP8 DQ quantization
+    # Whether to quantize per row (True) or per tensor (False)
+    per_row: bool = True,
     exclude_layers: List[str] = [
         "embedder",
         "embed",
     ],
     filter_fn: Optional[Callable] = None,
-    # paramters for fp8 quantization
-    per_row: bool = True,
     **kwargs,
 ) -> torch.nn.Module:
     # Apply FP8 DQ for module and skip any `embed` modules
     # by default to avoid non-trivial precision downgrade. Please
     # set `exclude_layers` as `[]` if you don't want this behavior.
     assert isinstance(module, torch.nn.Module)
+    try:
+        import torchao  # noqa: F401
+    except ImportError:
+        raise ImportError(
+            "Quantization functionality requires the 'quantization' extra dependencies. "
+            "Install with: pip install cache-dit[quantization]"
+        )
 
     alias_map = {
         "float8": "fp8_w8a8_dq",
