2121"""
2222
2323import copy
24+ import dataclasses
2425import importlib .metadata
2526import inspect
2627import json
2728import os
2829import warnings
29- from dataclasses import dataclass
30+ from dataclasses import dataclass , is_dataclass
3031from enum import Enum
3132from functools import partial
3233from typing import Any , Callable , Dict , List , Optional , Union
@@ -443,7 +444,7 @@ class TorchAoConfig(QuantizationConfigMixin):
443444 """This is a config class for torchao quantization/sparsity techniques.
444445
445446 Args:
446- quant_type (`str`):
447+ quant_type (Union[ `str`, AOBaseConfig] ):
447448 The type of quantization we want to use, currently supporting:
448449 - **Integer quantization:**
449450 - Full function names: `int4_weight_only`, `int8_dynamic_activation_int4_weight`,
@@ -465,6 +466,7 @@ class TorchAoConfig(QuantizationConfigMixin):
465466 - **Unsigned Integer quantization:**
466467 - Full function names: `uintx_weight_only`
467468 - Shorthands: `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo`
469+ - An AOBaseConfig instance: for more advanced configuration options.
468470 modules_to_not_convert (`List[str]`, *optional*, default to `None`):
469471 The list of modules to not quantize, useful for quantizing models that explicitly require to have some
470472 modules left in their original precision.
@@ -478,6 +480,12 @@ class TorchAoConfig(QuantizationConfigMixin):
478480 ```python
479481 from diffusers import FluxTransformer2DModel, TorchAoConfig
480482
483+ # AOBaseConfig-based configuration
484+ from torchao.quantization import Int8WeightOnlyConfig
485+
486+ quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
487+
488+ # String-based config
481489 quantization_config = TorchAoConfig("int8wo")
482490 transformer = FluxTransformer2DModel.from_pretrained(
483491 "black-forest-labs/Flux.1-Dev",
@@ -490,7 +498,7 @@ class TorchAoConfig(QuantizationConfigMixin):
490498
491499 def __init__ (
492500 self ,
493- quant_type : str ,
501+ quant_type : Union [ str , "AOBaseConfig" ], # noqa: F821
494502 modules_to_not_convert : Optional [List [str ]] = None ,
495503 ** kwargs ,
496504 ) -> None :
@@ -504,8 +512,13 @@ def __init__(
504512 else :
505513 self .quant_type_kwargs = kwargs
506514
515+ self .post_init ()
516+
517+ def post_init (self ):
507518 TORCHAO_QUANT_TYPE_METHODS = self ._get_torchao_quant_type_to_method ()
508- if self .quant_type not in TORCHAO_QUANT_TYPE_METHODS .keys ():
519+ AO_VERSION = self ._get_ao_version ()
520+
521+ if isinstance (self .quant_type , str ) and self .quant_type not in TORCHAO_QUANT_TYPE_METHODS .keys ():
509522 is_floating_quant_type = self .quant_type .startswith ("float" ) or self .quant_type .startswith ("fp" )
510523 if is_floating_quant_type and not self ._is_xpu_or_cuda_capability_atleast_8_9 ():
511524 raise ValueError (
@@ -517,22 +530,95 @@ def __init__(
517530 f"Requested quantization type: { self .quant_type } is not supported or is an incorrect `quant_type` name. If you think the "
518531 f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
519532 )
533+ elif AO_VERSION > version .parse ("0.9.0" ):
534+ from torchao .quantization .quant_api import AOBaseConfig
520535
521- method = TORCHAO_QUANT_TYPE_METHODS [self .quant_type ]
522- signature = inspect .signature (method )
523- all_kwargs = {
524- param .name
525- for param in signature .parameters .values ()
526- if param .kind in [inspect .Parameter .KEYWORD_ONLY , inspect .Parameter .POSITIONAL_OR_KEYWORD ]
527- }
528- unsupported_kwargs = list (self .quant_type_kwargs .keys () - all_kwargs )
529-
530- if len (unsupported_kwargs ) > 0 :
536+ if not isinstance (self .quant_type , AOBaseConfig ):
537+ raise TypeError (
538+ f"`quant_type` must be either a string or an `AOBaseConfig` instance, got { type (self .quant_type )} ."
539+ )
540+ else :
531541 raise ValueError (
532- f'The quantization method " { quant_type } " does not support the following keyword arguments: '
533- f"{ unsupported_kwargs } . The following keywords arguments are supported: { all_kwargs } ."
542+ f"In torchao <= 0.9.0, quant_type must be a string. Got { type ( self . quant_type ) } . "
543+ f"Please upgrade to torchao > 0.9.0 to use `AOBaseConfig` instances ."
534544 )
535545
546+ if isinstance (self .quant_type , str ):
547+ method = TORCHAO_QUANT_TYPE_METHODS [self .quant_type ]
548+ signature = inspect .signature (method )
549+ all_kwargs = {
550+ param .name
551+ for param in signature .parameters .values ()
552+ if param .kind in [inspect .Parameter .KEYWORD_ONLY , inspect .Parameter .POSITIONAL_OR_KEYWORD ]
553+ }
554+ unsupported_kwargs = list (self .quant_type_kwargs .keys () - all_kwargs )
555+
556+ if len (unsupported_kwargs ) > 0 :
557+ raise ValueError (
558+ f'The quantization method "{ self .quant_type } " does not support the following keyword arguments: '
559+ f"{ unsupported_kwargs } . The following keywords arguments are supported: { all_kwargs } ."
560+ )
561+
562+ def to_dict (self ):
563+ """Convert configuration to a dictionary."""
564+ d = super ().to_dict ()
565+
566+ if isinstance (self .quant_type , str ):
567+ # Handle layout serialization if present
568+ if "quant_type_kwargs" in d and "layout" in d ["quant_type_kwargs" ]:
569+ if is_dataclass (d ["quant_type_kwargs" ]["layout" ]):
570+ d ["quant_type_kwargs" ]["layout" ] = [
571+ d ["quant_type_kwargs" ]["layout" ].__class__ .__name__ ,
572+ dataclasses .asdict (d ["quant_type_kwargs" ]["layout" ]),
573+ ]
574+ if isinstance (d ["quant_type_kwargs" ]["layout" ], list ):
575+ assert len (d ["quant_type_kwargs" ]["layout" ]) == 2 , "layout saves layout name and layout kwargs"
576+ assert isinstance (d ["quant_type_kwargs" ]["layout" ][0 ], str ), "layout name must be a string"
577+ assert isinstance (d ["quant_type_kwargs" ]["layout" ][1 ], dict ), "layout kwargs must be a dict"
578+ else :
579+ raise ValueError ("layout must be a list" )
580+ else :
581+ # Handle AOBaseConfig serialization
582+ from torchao .core .config import config_to_dict
583+
584+ # For now we assume there is 1 config per Transformer, however in the future
585+ # We may want to support a config per fqn.
586+ d ["quant_type" ] = {"default" : config_to_dict (self .quant_type )}
587+
588+ return d
589+
590+ @classmethod
591+ def from_dict (cls , config_dict , return_unused_kwargs = False , ** kwargs ):
592+ """Create configuration from a dictionary."""
593+ ao_version = cls ._get_ao_version ()
594+ assert ao_version > version .parse ("0.9.0" ), "TorchAoConfig requires torchao > 0.9.0 for construction from dict"
595+ config_dict = config_dict .copy ()
596+ quant_type = config_dict .pop ("quant_type" )
597+
598+ if isinstance (quant_type , str ):
599+ return cls (quant_type = quant_type , ** config_dict )
600+ # Check if we only have one key which is "default"
601+ # In the future we may update this
602+ assert len (quant_type ) == 1 and "default" in quant_type , (
603+ "Expected only one key 'default' in quant_type dictionary"
604+ )
605+ quant_type = quant_type ["default" ]
606+
607+ # Deserialize quant_type if needed
608+ from torchao .core .config import config_from_dict
609+
610+ quant_type = config_from_dict (quant_type )
611+
612+ return cls (quant_type = quant_type , ** config_dict )
613+
614+ @staticmethod
615+ def _get_ao_version () -> version .Version :
616+ """Centralized check for TorchAO availability and version requirements."""
617+ if not is_torchao_available ():
618+ raise ValueError ("TorchAoConfig requires torchao to be installed. Install with `pip install torchao`" )
619+
620+ return version .parse (importlib .metadata .version ("torchao" ))
621+
536622 @classmethod
537623 def _get_torchao_quant_type_to_method (cls ):
538624 r"""
@@ -681,8 +767,38 @@ def _is_xpu_or_cuda_capability_atleast_8_9() -> bool:
681767 raise RuntimeError ("TorchAO requires a CUDA compatible GPU or Intel XPU and installation of PyTorch." )
682768
683769 def get_apply_tensor_subclass (self ):
684- TORCHAO_QUANT_TYPE_METHODS = self ._get_torchao_quant_type_to_method ()
685- return TORCHAO_QUANT_TYPE_METHODS [self .quant_type ](** self .quant_type_kwargs )
770+ """Create the appropriate quantization method based on configuration."""
771+ if isinstance (self .quant_type , str ):
772+ methods = self ._get_torchao_quant_type_to_method ()
773+ quant_type_kwargs = self .quant_type_kwargs .copy ()
774+ if (
775+ not torch .cuda .is_available ()
776+ and is_torchao_available ()
777+ and self .quant_type == "int4_weight_only"
778+ and version .parse (importlib .metadata .version ("torchao" )) >= version .parse ("0.8.0" )
779+ and quant_type_kwargs .get ("layout" , None ) is None
780+ ):
781+ if torch .xpu .is_available ():
782+ if version .parse (importlib .metadata .version ("torchao" )) >= version .parse (
783+ "0.11.0"
784+ ) and version .parse (importlib .metadata .version ("torch" )) > version .parse ("2.7.9" ):
785+ from torchao .dtypes import Int4XPULayout
786+ from torchao .quantization .quant_primitives import ZeroPointDomain
787+
788+ quant_type_kwargs ["layout" ] = Int4XPULayout ()
789+ quant_type_kwargs ["zero_point_domain" ] = ZeroPointDomain .INT
790+ else :
791+ raise ValueError (
792+ "TorchAoConfig requires torchao >= 0.11.0 and torch >= 2.8.0 for XPU support. Please upgrade the version or use run on CPU with the cpu version pytorch."
793+ )
794+ else :
795+ from torchao .dtypes import Int4CPULayout
796+
797+ quant_type_kwargs ["layout" ] = Int4CPULayout ()
798+
799+ return methods [self .quant_type ](** quant_type_kwargs )
800+ else :
801+ return self .quant_type
686802
687803 def __repr__ (self ):
688804 r"""
0 commit comments