Commit a93c429

feat: support aobaseconfig classes.
1 parent 4acbfbf commit a93c429

3 files changed: +228 −39 lines changed


src/diffusers/quantizers/quantization_config.py

Lines changed: 134 additions & 18 deletions
@@ -21,12 +21,13 @@
 """
 
 import copy
+import dataclasses
 import importlib.metadata
 import inspect
 import json
 import os
 import warnings
-from dataclasses import dataclass
+from dataclasses import dataclass, is_dataclass
 from enum import Enum
 from functools import partial
 from typing import Any, Callable, Dict, List, Optional, Union
@@ -443,7 +444,7 @@ class TorchAoConfig(QuantizationConfigMixin):
     """This is a config class for torchao quantization/sparsity techniques.
 
     Args:
-        quant_type (`str`):
+        quant_type (Union[`str`, AOBaseConfig]):
             The type of quantization we want to use, currently supporting:
             - **Integer quantization:**
                 - Full function names: `int4_weight_only`, `int8_dynamic_activation_int4_weight`,
@@ -465,6 +466,7 @@ class TorchAoConfig(QuantizationConfigMixin):
             - **Unsigned Integer quantization:**
                 - Full function names: `uintx_weight_only`
                 - Shorthands: `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo`
+            - An AOBaseConfig instance: for more advanced configuration options.
         modules_to_not_convert (`List[str]`, *optional*, default to `None`):
             The list of modules to not quantize, useful for quantizing models that explicitly require to have some
             modules left in their original precision.
@@ -478,6 +480,12 @@ class TorchAoConfig(QuantizationConfigMixin):
         ```python
         from diffusers import FluxTransformer2DModel, TorchAoConfig
 
+        # AOBaseConfig-based configuration
+        from torchao.quantization import Int8WeightOnlyConfig
+
+        quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
+
+        # String-based config
         quantization_config = TorchAoConfig("int8wo")
         transformer = FluxTransformer2DModel.from_pretrained(
             "black-forest-labs/Flux.1-Dev",
@@ -490,7 +498,7 @@ class TorchAoConfig(QuantizationConfigMixin):
 
     def __init__(
         self,
-        quant_type: str,
+        quant_type: Union[str, "AOBaseConfig"],  # noqa: F821
         modules_to_not_convert: Optional[List[str]] = None,
         **kwargs,
     ) -> None:
@@ -504,8 +512,13 @@ def __init__(
         else:
             self.quant_type_kwargs = kwargs
 
+        self.post_init()
+
+    def post_init(self):
         TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
-        if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
+        AO_VERSION = self._get_ao_version()
+
+        if isinstance(self.quant_type, str) and self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
             is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
             if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
                 raise ValueError(
@@ -517,22 +530,95 @@ def __init__(
                 f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
                 f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
             )
+        elif AO_VERSION > version.parse("0.9.0"):
+            from torchao.quantization.quant_api import AOBaseConfig
 
-        method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type]
-        signature = inspect.signature(method)
-        all_kwargs = {
-            param.name
-            for param in signature.parameters.values()
-            if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD]
-        }
-        unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs)
-
-        if len(unsupported_kwargs) > 0:
+            if not isinstance(self.quant_type, AOBaseConfig):
+                raise TypeError(
+                    f"`quant_type` must be either a string or an `AOBaseConfig` instance, got {type(self.quant_type)}."
+                )
+        else:
             raise ValueError(
-                f'The quantization method "{quant_type}" does not support the following keyword arguments: '
-                f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}."
+                f"In torchao <= 0.9.0, quant_type must be a string. Got {type(self.quant_type)}. "
+                f"Please upgrade to torchao > 0.9.0 to use `AOBaseConfig` instances."
             )
 
+        if isinstance(self.quant_type, str):
+            method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type]
+            signature = inspect.signature(method)
+            all_kwargs = {
+                param.name
+                for param in signature.parameters.values()
+                if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD]
+            }
+            unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs)
+
+            if len(unsupported_kwargs) > 0:
+                raise ValueError(
+                    f'The quantization method "{self.quant_type}" does not support the following keyword arguments: '
+                    f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}."
+                )
+
+    def to_dict(self):
+        """Convert configuration to a dictionary."""
+        d = super().to_dict()
+
+        if isinstance(self.quant_type, str):
+            # Handle layout serialization if present
+            if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]:
+                if is_dataclass(d["quant_type_kwargs"]["layout"]):
+                    d["quant_type_kwargs"]["layout"] = [
+                        d["quant_type_kwargs"]["layout"].__class__.__name__,
+                        dataclasses.asdict(d["quant_type_kwargs"]["layout"]),
+                    ]
+                if isinstance(d["quant_type_kwargs"]["layout"], list):
+                    assert len(d["quant_type_kwargs"]["layout"]) == 2, "layout saves layout name and layout kwargs"
+                    assert isinstance(d["quant_type_kwargs"]["layout"][0], str), "layout name must be a string"
+                    assert isinstance(d["quant_type_kwargs"]["layout"][1], dict), "layout kwargs must be a dict"
+                else:
+                    raise ValueError("layout must be a list")
+        else:
+            # Handle AOBaseConfig serialization
+            from torchao.core.config import config_to_dict
+
+            # For now we assume there is 1 config per Transformer, however in the future
+            # We may want to support a config per fqn.
+            d["quant_type"] = {"default": config_to_dict(self.quant_type)}
+
+        return d
+
+    @classmethod
+    def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
+        """Create configuration from a dictionary."""
+        ao_version = cls._get_ao_version()
+        assert ao_version > version.parse("0.9.0"), "TorchAoConfig requires torchao > 0.9.0 for construction from dict"
+        config_dict = config_dict.copy()
+        quant_type = config_dict.pop("quant_type")
+
+        if isinstance(quant_type, str):
+            return cls(quant_type=quant_type, **config_dict)
+        # Check if we only have one key which is "default"
+        # In the future we may update this
+        assert len(quant_type) == 1 and "default" in quant_type, (
+            "Expected only one key 'default' in quant_type dictionary"
+        )
+        quant_type = quant_type["default"]
+
+        # Deserialize quant_type if needed
+        from torchao.core.config import config_from_dict
+
+        quant_type = config_from_dict(quant_type)
+
+        return cls(quant_type=quant_type, **config_dict)
+
+    @staticmethod
+    def _get_ao_version() -> version.Version:
+        """Centralized check for TorchAO availability and version requirements."""
+        if not is_torchao_available():
+            raise ValueError("TorchAoConfig requires torchao to be installed. Install with `pip install torchao`")
+
+        return version.parse(importlib.metadata.version("torchao"))
+
     @classmethod
     def _get_torchao_quant_type_to_method(cls):
         r"""
@@ -681,8 +767,38 @@ def _is_xpu_or_cuda_capability_atleast_8_9() -> bool:
             raise RuntimeError("TorchAO requires a CUDA compatible GPU or Intel XPU and installation of PyTorch.")
 
     def get_apply_tensor_subclass(self):
-        TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
-        return TORCHAO_QUANT_TYPE_METHODS[self.quant_type](**self.quant_type_kwargs)
+        """Create the appropriate quantization method based on configuration."""
+        if isinstance(self.quant_type, str):
+            methods = self._get_torchao_quant_type_to_method()
+            quant_type_kwargs = self.quant_type_kwargs.copy()
+            if (
+                not torch.cuda.is_available()
+                and is_torchao_available()
+                and self.quant_type == "int4_weight_only"
+                and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
+                and quant_type_kwargs.get("layout", None) is None
+            ):
+                if torch.xpu.is_available():
+                    if version.parse(importlib.metadata.version("torchao")) >= version.parse(
+                        "0.11.0"
+                    ) and version.parse(importlib.metadata.version("torch")) > version.parse("2.7.9"):
+                        from torchao.dtypes import Int4XPULayout
+                        from torchao.quantization.quant_primitives import ZeroPointDomain
+
+                        quant_type_kwargs["layout"] = Int4XPULayout()
+                        quant_type_kwargs["zero_point_domain"] = ZeroPointDomain.INT
+                    else:
+                        raise ValueError(
+                            "TorchAoConfig requires torchao >= 0.11.0 and torch >= 2.8.0 for XPU support. Please upgrade the version or run on CPU with the CPU version of PyTorch."
+                        )
+                else:
+                    from torchao.dtypes import Int4CPULayout
+
+                    quant_type_kwargs["layout"] = Int4CPULayout()
+
+            return methods[self.quant_type](**quant_type_kwargs)
+        else:
+            return self.quant_type
 
     def __repr__(self):
         r"""

src/diffusers/quantizers/torchao/torchao_quantizer.py

Lines changed: 71 additions & 21 deletions
@@ -18,9 +18,10 @@
 """
 
 import importlib
+import re
 import types
 from fnmatch import fnmatch
-from typing import TYPE_CHECKING, Any, Dict, List, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 from packaging import version
 
@@ -107,6 +108,21 @@ def _update_torch_safe_globals():
 _update_torch_safe_globals()
 
 
+def fuzzy_match_size(config_name: str) -> Optional[str]:
+    """
+    Extract the size digit from strings like "4weight", "8weight". Returns the digit as a string if found, otherwise
+    None.
+    """
+    config_name = config_name.lower()
+
+    str_match = re.search(r"(\d)weight", config_name)
+
+    if str_match:
+        return str_match.group(1)
+
+    return None
+
+
 logger = logging.get_logger(__name__)
 
 
@@ -176,8 +192,7 @@ def validate_environment(self, *args, **kwargs):
 
     def update_torch_dtype(self, torch_dtype):
         quant_type = self.quantization_config.quant_type
-
-        if quant_type.startswith("int") or quant_type.startswith("uint"):
+        if isinstance(quant_type, str) and (quant_type.startswith("int") or quant_type.startswith("uint")):
             if torch_dtype is not None and torch_dtype != torch.bfloat16:
                 logger.warning(
                     f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
@@ -197,24 +212,44 @@ def update_torch_dtype(self, torch_dtype):
 
     def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
         quant_type = self.quantization_config.quant_type
-
-        if quant_type.startswith("int8") or quant_type.startswith("int4"):
-            # Note that int4 weights are created by packing into torch.int8, but since there is no torch.int4, we use torch.int8
-            return torch.int8
-        elif quant_type == "uintx_weight_only":
-            return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8)
-        elif quant_type.startswith("uint"):
-            return {
-                1: torch.uint1,
-                2: torch.uint2,
-                3: torch.uint3,
-                4: torch.uint4,
-                5: torch.uint5,
-                6: torch.uint6,
-                7: torch.uint7,
-            }[int(quant_type[4])]
-        elif quant_type.startswith("float") or quant_type.startswith("fp"):
-            return torch.bfloat16
+        from accelerate.utils import CustomDtype
+
+        if isinstance(quant_type, str):
+            if quant_type.startswith("int8"):
+                # Note that int4 weights are created by packing into torch.int8, but since there is no torch.int4, we use torch.int8
+                return torch.int8
+            elif quant_type.startswith("int4"):
+                return CustomDtype.INT4
+            elif quant_type == "uintx_weight_only":
+                return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8)
+            elif quant_type.startswith("uint"):
+                return {
+                    1: torch.uint1,
+                    2: torch.uint2,
+                    3: torch.uint3,
+                    4: torch.uint4,
+                    5: torch.uint5,
+                    6: torch.uint6,
+                    7: torch.uint7,
+                }[int(quant_type[4])]
+            elif quant_type.startswith("float") or quant_type.startswith("fp"):
+                return torch.bfloat16
+
+        elif self.quantization_config._get_ao_version() > version.Version("0.9.0"):
+            from torchao.core.config import AOBaseConfig
+
+            quant_type = self.quantization_config.quant_type
+            if isinstance(quant_type, AOBaseConfig):
+                # Extract size digit using fuzzy match on the class name
+                config_name = quant_type.__class__.__name__
+                size_digit = fuzzy_match_size(config_name)
+
+                # Map the extracted digit to appropriate dtype
+                if size_digit == "4":
+                    return CustomDtype.INT4
+                else:
+                    # Default to int8
+                    return torch.int8
 
         if isinstance(target_dtype, SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION):
             return target_dtype
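
So for `AOBaseConfig` instances the target dtype used during memory estimation is derived from the class name alone: a `4` maps to accelerate's `CustomDtype.INT4`, anything else falls back to `torch.int8`. A standalone sketch of that mapping; the `_target_dtype_for` helper is hypothetical, shown only to isolate the logic:

```python
import torch
from accelerate.utils import CustomDtype

from diffusers.quantizers.torchao.torchao_quantizer import fuzzy_match_size


def _target_dtype_for(config):
    # Mirrors the new AOBaseConfig branch of adjust_target_dtype:
    # size digit "4" -> CustomDtype.INT4, everything else -> torch.int8.
    return CustomDtype.INT4 if fuzzy_match_size(config.__class__.__name__) == "4" else torch.int8
```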
@@ -297,6 +332,21 @@ def get_cuda_warm_up_factor(self):
         # Original mapping for non-AOBaseConfig types
         # For the uint types, this is a best guess. Once these types become more used
         # we can look into their nuances.
+        if self.quantization_config._get_ao_version() > version.Version("0.9.0"):
+            from torchao.core.config import AOBaseConfig
+
+            quant_type = self.quantization_config.quant_type
+            # For autoquant case, it will be treated in the string implementation below in map_to_target_dtype
+            if isinstance(quant_type, AOBaseConfig):
+                # Extract size digit using fuzzy match on the class name
+                config_name = quant_type.__class__.__name__
+                size_digit = fuzzy_match_size(config_name)
+
+                if size_digit == "4":
+                    return 8
+                else:
+                    return 4
+
         map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "uint*": 8, "float8*": 4}
         quant_type = self.quantization_config.quant_type
         for pattern, target_dtype in map_to_target_dtype.items():
