From b139f85719437686625b76dfe0f7685797722a7c Mon Sep 17 00:00:00 2001 From: NJX-njx <3771829673@qq.com> Date: Wed, 4 Mar 2026 18:28:22 +0800 Subject: [PATCH] refactor: modernize type hints to Python 3.10+ syntax Ref #1927 Apply Python 3.10+ type hint modernization across 58 source files using pyupgrade --py310-plus: - Replace Union[X, Y] with X | Y (PEP 604) - Replace Optional[X] with X | None (PEP 604) - Replace List[X] with list[X] (PEP 585) - Replace Dict[K, V] with dict[K, V] (PEP 585) - Replace Tuple[X, ...] with tuple[X, ...] (PEP 585) - Replace Set[X] with set[X] (PEP 585) - Replace Type[X] with type[X] (PEP 585) - Remove now-unused typing imports This is the first batch targeting the src/llmcompressor/ directory. Tests directory will follow in a separate PR. --- src/llmcompressor/args/dataset_arguments.py | 2 +- src/llmcompressor/core/events/event.py | 8 ++--- src/llmcompressor/core/session.py | 4 ++- src/llmcompressor/core/session_functions.py | 8 +++-- src/llmcompressor/datasets/utils.py | 6 ++-- .../entrypoints/model_free/__init__.py | 6 ++-- .../entrypoints/model_free/helpers.py | 8 +++-- .../entrypoints/model_free/process.py | 2 +- .../model_free/reindex_fused_weights.py | 2 +- .../entrypoints/model_free/save_utils.py | 2 +- .../entrypoints/model_free/validate.py | 2 +- src/llmcompressor/entrypoints/oneshot.py | 4 ++- src/llmcompressor/logger.py | 10 +++--- src/llmcompressor/modeling/fuse.py | 2 +- src/llmcompressor/modeling/gpt_oss.py | 10 +++--- src/llmcompressor/modeling/llama4.py | 2 +- src/llmcompressor/modeling/qwen3_moe.py | 1 - src/llmcompressor/modifiers/autoround/base.py | 12 +++---- src/llmcompressor/modifiers/awq/base.py | 8 +++-- src/llmcompressor/modifiers/gptq/base.py | 14 ++++---- .../modifiers/pruning/constant/base.py | 4 +-- .../modifiers/pruning/helpers.py | 6 ++-- .../modifiers/pruning/magnitude/base.py | 6 ++-- .../modifiers/pruning/sparsegpt/sgpt_base.py | 4 +-- .../pruning/sparsegpt/sgpt_sparsify.py | 6 ++-- .../modifiers/pruning/wanda/wanda_sparsify.py | 4 +-- .../modifiers/quantization/calibration.py | 2 +- .../quantization/quantization/mixin.py | 34 +++++++++---------- .../modifiers/transform/quip/base.py | 10 +++--- .../modifiers/transform/smoothquant/base.py | 6 ++-- .../modifiers/transform/spinquant/base.py | 14 ++++---- .../modifiers/transform/spinquant/mappings.py | 8 ++--- .../transform/spinquant/norm_mappings.py | 6 ++-- src/llmcompressor/modifiers/utils/hooks.py | 4 ++- src/llmcompressor/observers/base.py | 14 ++++---- src/llmcompressor/observers/helpers.py | 4 +-- src/llmcompressor/observers/moving_base.py | 2 +- src/llmcompressor/observers/mse.py | 2 +- src/llmcompressor/pipelines/cache.py | 12 ++++--- .../pipelines/data_free/pipeline.py | 2 +- .../pipelines/sequential/helpers.py | 22 ++++++------ .../pipelines/sequential/pipeline.py | 4 ++- .../pytorch/model_load/helpers.py | 4 +-- .../pytorch/utils/sparsification.py | 2 +- .../utils/sparsification_info/configs.py | 26 +++++++------- .../utils/sparsification_info/helpers.py | 8 ++--- .../module_sparsification_info.py | 8 +++-- src/llmcompressor/recipe/recipe.py | 34 +++++++++---------- src/llmcompressor/recipe/utils.py | 8 ++--- src/llmcompressor/transformers/data/base.py | 10 +++--- .../transformers/tracing/debug.py | 4 +-- .../transformers/utils/helpers.py | 2 -- src/llmcompressor/typing.py | 2 +- src/llmcompressor/utils/dev.py | 2 +- src/llmcompressor/utils/dist.py | 4 ++- src/llmcompressor/utils/metric_logging.py | 12 +++---- src/llmcompressor/utils/pytorch/module.py | 8 ++--- src/llmcompressor/utils/transformers.py | 2 +- 58 files changed, 228 insertions(+), 197 deletions(-) diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py index 60705744e2..0e2408e937 100644 --- a/src/llmcompressor/args/dataset_arguments.py +++ b/src/llmcompressor/args/dataset_arguments.py @@ -8,7 +8,7 @@ """ from dataclasses import dataclass, field -from typing import Callable +from collections.abc import Callable from datasets import Dataset, DatasetDict from torch.utils.data import DataLoader diff --git a/src/llmcompressor/core/events/event.py b/src/llmcompressor/core/events/event.py index a73903d4cf..4df4260dc4 100644 --- a/src/llmcompressor/core/events/event.py +++ b/src/llmcompressor/core/events/event.py @@ -84,9 +84,9 @@ class Event: :type global_batch: int """ - type_: Optional[EventType] = None - steps_per_epoch: Optional[int] = None - batches_per_step: Optional[int] = None + type_: EventType | None = None + steps_per_epoch: int | None = None + batches_per_step: int | None = None invocations_per_step: int = 1 global_step: int = 0 global_batch: int = 0 @@ -206,7 +206,7 @@ def current_index(self, value: float): ) def should_update( - self, start: Optional[float], end: Optional[float], update: Optional[float] + self, start: float | None, end: float | None, update: float | None ) -> bool: """ Determines if the event should trigger an update. diff --git a/src/llmcompressor/core/session.py b/src/llmcompressor/core/session.py index e86996a84c..aaabc62629 100644 --- a/src/llmcompressor/core/session.py +++ b/src/llmcompressor/core/session.py @@ -7,7 +7,9 @@ """ from dataclasses import dataclass -from typing import Any, Callable +from typing import Any + +from collections.abc import Callable from loguru import logger diff --git a/src/llmcompressor/core/session_functions.py b/src/llmcompressor/core/session_functions.py index e17ce4a818..51ebf139bd 100644 --- a/src/llmcompressor/core/session_functions.py +++ b/src/llmcompressor/core/session_functions.py @@ -7,7 +7,9 @@ import threading from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Generator, Optional +from typing import TYPE_CHECKING, Any, Optional + +from collections.abc import Generator from loguru import logger @@ -91,7 +93,7 @@ def event(cls, event_type: EventType, **kwargs) -> ModifiedState: return active_session().event(event_type, **kwargs) @classmethod - def batch_start(cls, batch_data: Optional[Any] = None, **kwargs) -> ModifiedState: + def batch_start(cls, batch_data: Any | None = None, **kwargs) -> ModifiedState: """ Invoke a batch start event for the active session @@ -102,7 +104,7 @@ def batch_start(cls, batch_data: Optional[Any] = None, **kwargs) -> ModifiedStat return cls.event(EventType.BATCH_START, batch_data=batch_data, **kwargs) @classmethod - def loss_calculated(cls, loss: Optional[Any] = None, **kwargs) -> ModifiedState: + def loss_calculated(cls, loss: Any | None = None, **kwargs) -> ModifiedState: """ Invoke a loss calculated event for the active session diff --git a/src/llmcompressor/datasets/utils.py b/src/llmcompressor/datasets/utils.py index 3478f08055..8c966e11d1 100644 --- a/src/llmcompressor/datasets/utils.py +++ b/src/llmcompressor/datasets/utils.py @@ -10,7 +10,9 @@ import math import re from collections.abc import Iterator, Sized -from typing import Any, Callable, Optional +from typing import Any, Optional + +from collections.abc import Callable import torch from datasets import Dataset @@ -334,7 +336,7 @@ class LengthAwareSampler(Sampler[int]): def __init__( self, data_source: Dataset, - num_samples: Optional[int] = None, + num_samples: int | None = None, batch_size: int = 1, ) -> None: self.data_source = data_source diff --git a/src/llmcompressor/entrypoints/model_free/__init__.py b/src/llmcompressor/entrypoints/model_free/__init__.py index 745ce86076..c299406f57 100644 --- a/src/llmcompressor/entrypoints/model_free/__init__.py +++ b/src/llmcompressor/entrypoints/model_free/__init__.py @@ -2,7 +2,9 @@ import shutil from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path -from typing import Iterable, Optional +from typing import Optional + +from collections.abc import Iterable import torch import tqdm @@ -40,7 +42,7 @@ def model_free_ptq( scheme: QuantizationScheme | str, ignore: Iterable[str] = tuple(), max_workers: int = 1, - device: Optional[torch.device | str] = None, + device: torch.device | str | None = None, ): """ Quantize a model without the need for a model definition. This function operates on diff --git a/src/llmcompressor/entrypoints/model_free/helpers.py b/src/llmcompressor/entrypoints/model_free/helpers.py index ef45f09346..15b28649f5 100644 --- a/src/llmcompressor/entrypoints/model_free/helpers.py +++ b/src/llmcompressor/entrypoints/model_free/helpers.py @@ -1,7 +1,9 @@ import os import re from collections import defaultdict -from typing import Mapping, TypeVar +from typing import TypeVar + +from collections.abc import Mapping import torch from compressed_tensors.utils.match import _match_name @@ -95,11 +97,11 @@ def natural_key(s: str) -> list[str | int]: ) # once we have a full set, yield and reset - if all((matches[target] is not None for target in targets)): + if all(matches[target] is not None for target in targets): matched_sets.append(matches) matches = dict.fromkeys(targets, None) - unmatched_set = matches if any((v is not None for v in matches.values())) else None + unmatched_set = matches if any(v is not None for v in matches.values()) else None if return_unmatched: return matched_sets, unmatched_set diff --git a/src/llmcompressor/entrypoints/model_free/process.py b/src/llmcompressor/entrypoints/model_free/process.py index 0a3d86efec..a947db9d23 100644 --- a/src/llmcompressor/entrypoints/model_free/process.py +++ b/src/llmcompressor/entrypoints/model_free/process.py @@ -1,7 +1,7 @@ import os from collections import defaultdict from collections.abc import Iterator, Mapping -from typing import Iterable +from collections.abc import Iterable import torch from compressed_tensors.quantization import QuantizationScheme diff --git a/src/llmcompressor/entrypoints/model_free/reindex_fused_weights.py b/src/llmcompressor/entrypoints/model_free/reindex_fused_weights.py index a5b5bbd2d5..e38a657cdb 100644 --- a/src/llmcompressor/entrypoints/model_free/reindex_fused_weights.py +++ b/src/llmcompressor/entrypoints/model_free/reindex_fused_weights.py @@ -77,7 +77,7 @@ def reindex_fused_weights( shutil.copyfile(resolved_path, save_path) # read index file - with open(index_file, "r") as file: + with open(index_file) as file: index_file_data = json.load(file) weight_map: dict[str, str] = index_file_data["weight_map"] diff --git a/src/llmcompressor/entrypoints/model_free/save_utils.py b/src/llmcompressor/entrypoints/model_free/save_utils.py index 6d7ad2908b..49c5fc3d38 100644 --- a/src/llmcompressor/entrypoints/model_free/save_utils.py +++ b/src/llmcompressor/entrypoints/model_free/save_utils.py @@ -50,7 +50,7 @@ def update_config( # write results to config.json file config_file_path = find_config_path(save_directory) if config_file_path is not None: - with open(config_file_path, "r") as file: + with open(config_file_path) as file: config_data = json.load(file) config_data[QUANTIZATION_CONFIG_NAME] = qconfig_data diff --git a/src/llmcompressor/entrypoints/model_free/validate.py b/src/llmcompressor/entrypoints/model_free/validate.py index c27b782987..b180cae2a9 100644 --- a/src/llmcompressor/entrypoints/model_free/validate.py +++ b/src/llmcompressor/entrypoints/model_free/validate.py @@ -57,7 +57,7 @@ def validate_safetensors_index(model_files: dict[str, str], scheme: Quantization return if is_microscale_scheme(scheme): - with open(index_file_path, "r") as file: + with open(index_file_path) as file: weight_map: dict[str, str] = json.load(file)["weight_map"] file_map = invert_mapping(weight_map) diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py index 9ab1df68e9..019a9da43f 100644 --- a/src/llmcompressor/entrypoints/oneshot.py +++ b/src/llmcompressor/entrypoints/oneshot.py @@ -12,7 +12,9 @@ import os from datetime import datetime from pathlib import Path -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING + +from collections.abc import Callable from loguru import logger from torch.utils.data import DataLoader diff --git a/src/llmcompressor/logger.py b/src/llmcompressor/logger.py index 2e7c378067..38c3b47c60 100644 --- a/src/llmcompressor/logger.py +++ b/src/llmcompressor/logger.py @@ -54,13 +54,13 @@ class LoggerConfig: disabled: bool = False clear_loggers: bool = True - console_log_level: Optional[str] = "INFO" - log_file: Optional[str] = None - log_file_level: Optional[str] = None + console_log_level: str | None = "INFO" + log_file: str | None = None + log_file_level: str | None = None metrics_disabled: bool = False -def configure_logger(config: Optional[LoggerConfig] = None) -> None: +def configure_logger(config: LoggerConfig | None = None) -> None: """ Configure the logger for LLM Compressor. @@ -122,7 +122,7 @@ def configure_logger(config: Optional[LoggerConfig] = None) -> None: logger.level("METRIC", no=38, color="", icon="📈") -def support_log_once(record: Dict[str, Any]) -> bool: +def support_log_once(record: dict[str, Any]) -> bool: """ Support logging only once using `.bind(log_once=True)` diff --git a/src/llmcompressor/modeling/fuse.py b/src/llmcompressor/modeling/fuse.py index c3d26bb3d4..f51b9f0cfe 100644 --- a/src/llmcompressor/modeling/fuse.py +++ b/src/llmcompressor/modeling/fuse.py @@ -1,4 +1,4 @@ -from typing import Iterable +from collections.abc import Iterable import torch from compressed_tensors import ( diff --git a/src/llmcompressor/modeling/gpt_oss.py b/src/llmcompressor/modeling/gpt_oss.py index 44258a3927..a480f51565 100644 --- a/src/llmcompressor/modeling/gpt_oss.py +++ b/src/llmcompressor/modeling/gpt_oss.py @@ -109,10 +109,10 @@ def copy_from_fused_weights( def forward( self, hidden_states: torch.Tensor, # [B, T, H] - router_indices: Optional[ + router_indices: None | ( torch.Tensor - ] = None, # [B, T, top_k] or [tokens, top_k] - routing_weights: Optional[torch.Tensor] = None, # [B, T, E] or [tokens, E] + ) = None, # [B, T, top_k] or [tokens, top_k] + routing_weights: torch.Tensor | None = None, # [B, T, E] or [tokens, E] ) -> torch.Tensor: """ Implements the MoE computation using the router outputs. @@ -192,11 +192,11 @@ def set_module_by_path(root: nn.Module, dotpath: str, new_module: nn.Module) -> setattr(parent, parts[-1], new_module) -def find_experts(model: nn.Module) -> List[ExpertMeta]: +def find_experts(model: nn.Module) -> list[ExpertMeta]: """ Locate GPT-OSS MoE expert modules under model.model.layers[*].mlp.experts. """ - metas: List[ExpertMeta] = [] + metas: list[ExpertMeta] = [] for li, layer in enumerate(model.model.layers): experts = layer.mlp.experts device = next(experts.parameters(), torch.zeros(())).device diff --git a/src/llmcompressor/modeling/llama4.py b/src/llmcompressor/modeling/llama4.py index 9145c66a60..14d0cb6cf5 100644 --- a/src/llmcompressor/modeling/llama4.py +++ b/src/llmcompressor/modeling/llama4.py @@ -46,7 +46,7 @@ def __init__( self.shared_expert = original.shared_expert self.calibrate_all_experts = calibrate_all_experts - def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: hidden_states = hidden_states.reshape(-1, self.hidden_dim) router_scores, router_logits = self.router(hidden_states) out = self.shared_expert(hidden_states) diff --git a/src/llmcompressor/modeling/qwen3_moe.py b/src/llmcompressor/modeling/qwen3_moe.py index 890ac32c98..2442e30bb3 100644 --- a/src/llmcompressor/modeling/qwen3_moe.py +++ b/src/llmcompressor/modeling/qwen3_moe.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. # All rights reserved. # diff --git a/src/llmcompressor/modifiers/autoround/base.py b/src/llmcompressor/modifiers/autoround/base.py index bf9911c514..62829e8fc8 100644 --- a/src/llmcompressor/modifiers/autoround/base.py +++ b/src/llmcompressor/modifiers/autoround/base.py @@ -147,17 +147,17 @@ class AutoRoundModifier(Modifier, QuantizationMixin): Defaults to None. """ - sequential_targets: Union[str, List[str], None] = None + sequential_targets: str | list[str] | None = None # AutoRound modifier arguments iters: int = 200 enable_torch_compile: bool = True batch_size: int = 8 - lr: Optional[float] = None - device_ids: Optional[str] = None + lr: float | None = None + device_ids: str | None = None # private variables - _all_module_input: Dict[str, List[Tuple]] = PrivateAttr(default_factory=dict) - _q_input: Optional[torch.Tensor] = PrivateAttr(default=None) + _all_module_input: dict[str, list[tuple]] = PrivateAttr(default_factory=dict) + _q_input: torch.Tensor | None = PrivateAttr(default=None) def on_initialize(self, state: State, **kwargs) -> bool: """ @@ -338,7 +338,7 @@ def on_finalize(self, state: State, **kwargs) -> bool: return True - def get_unquantized_layer_names(self, wrapped_model: torch.nn.Module) -> List[str]: + def get_unquantized_layer_names(self, wrapped_model: torch.nn.Module) -> list[str]: unquantized_layers = [] for name, module in wrapped_model.named_modules(): diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 7a22d3a169..57ec1b865a 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -1,6 +1,8 @@ import inspect from itertools import product -from typing import Iterator, Literal +from typing import Literal + +from collections.abc import Iterator import torch from compressed_tensors.quantization import ( @@ -335,12 +337,12 @@ def _set_resolved_mappings(self, model: Module) -> None: resolved_mappings: list[ResolvedMapping] = [] module_to_name = get_module_to_name_dict(model) # Get names of modules targeted for quantization (excludes ignored) - targeted_names = set( + targeted_names = { name for name, _ in match_named_modules( model, self.resolved_targets, self.ignore ) - ) + } for mapping in self.mappings: # we deliberately don't use the ignore list when matching mappings, # so that we can handle layers that need smoothing but not quantization diff --git a/src/llmcompressor/modifiers/gptq/base.py b/src/llmcompressor/modifiers/gptq/base.py index 750d51aef6..c03f9b76d4 100644 --- a/src/llmcompressor/modifiers/gptq/base.py +++ b/src/llmcompressor/modifiers/gptq/base.py @@ -118,17 +118,17 @@ class GPTQModifier(Modifier, QuantizationMixin): """ # gptq modifier arguments - sequential_targets: Union[str, List[str], None] = None + sequential_targets: str | list[str] | None = None block_size: int = 128 - dampening_frac: Optional[float] = 0.01 + dampening_frac: float | None = 0.01 # TODO: this does not serialize / will be incorrectly written - actorder: Optional[Union[ActivationOrdering, Sentinel]] = Sentinel("static") + actorder: ActivationOrdering | Sentinel | None = Sentinel("static") offload_hessians: bool = False # private variables - _module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict) - _hessians: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr(default_factory=dict) - _num_samples: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr( + _module_names: dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict) + _hessians: dict[torch.nn.Module, torch.Tensor] = PrivateAttr(default_factory=dict) + _num_samples: dict[torch.nn.Module, torch.Tensor] = PrivateAttr( default_factory=dict ) @@ -235,7 +235,7 @@ def on_event(self, state: State, event: Event, **kwargs): def calibrate_module( self, module: torch.nn.Module, - args: Tuple[torch.Tensor, ...], + args: tuple[torch.Tensor, ...], _output: torch.Tensor, ): """ diff --git a/src/llmcompressor/modifiers/pruning/constant/base.py b/src/llmcompressor/modifiers/pruning/constant/base.py index 94b804deec..1f79530007 100644 --- a/src/llmcompressor/modifiers/pruning/constant/base.py +++ b/src/llmcompressor/modifiers/pruning/constant/base.py @@ -14,8 +14,8 @@ class ConstantPruningModifier(Modifier, LayerParamMasking): - targets: Union[str, List[str]] - parameterized_layers_: Dict[str, ModelParameterizedLayer] = None + targets: str | list[str] + parameterized_layers_: dict[str, ModelParameterizedLayer] = None _epsilon: float = 10e-9 _save_masks: bool = False _use_hooks: bool = False diff --git a/src/llmcompressor/modifiers/pruning/helpers.py b/src/llmcompressor/modifiers/pruning/helpers.py index 39f431383b..afd2bc16c2 100644 --- a/src/llmcompressor/modifiers/pruning/helpers.py +++ b/src/llmcompressor/modifiers/pruning/helpers.py @@ -9,7 +9,9 @@ import math import re from dataclasses import dataclass -from typing import Any, Callable, Dict +from typing import Any, Dict + +from collections.abc import Callable from llmcompressor.core import Event, State @@ -34,7 +36,7 @@ class PruningCreateSettings: update: float init_sparsity: float final_sparsity: float - args: Dict[str, Any] + args: dict[str, Any] SchedulerCalculationType = Callable[[Event, State], float] diff --git a/src/llmcompressor/modifiers/pruning/magnitude/base.py b/src/llmcompressor/modifiers/pruning/magnitude/base.py index 6e873ed3ac..5456a232e1 100644 --- a/src/llmcompressor/modifiers/pruning/magnitude/base.py +++ b/src/llmcompressor/modifiers/pruning/magnitude/base.py @@ -22,16 +22,16 @@ class MagnitudePruningModifier(Modifier, LayerParamMasking): - targets: Union[str, List[str]] + targets: str | list[str] init_sparsity: float final_sparsity: float update_scheduler: str = "cubic" - scheduler_args: Dict[str, Any] = {} + scheduler_args: dict[str, Any] = {} mask_structure: str = "unstructured" leave_enabled: bool = False apply_globally: bool = False - parameterized_layers_: Dict[str, ModelParameterizedLayer] = None + parameterized_layers_: dict[str, ModelParameterizedLayer] = None _save_masks: bool = False _use_hooks: bool = False scheduler_function_: SchedulerCalculationType = None diff --git a/src/llmcompressor/modifiers/pruning/sparsegpt/sgpt_base.py b/src/llmcompressor/modifiers/pruning/sparsegpt/sgpt_base.py index d53a9e0763..4c7abe4b02 100644 --- a/src/llmcompressor/modifiers/pruning/sparsegpt/sgpt_base.py +++ b/src/llmcompressor/modifiers/pruning/sparsegpt/sgpt_base.py @@ -280,11 +280,11 @@ def save_acts(_module, input: tuple[Any, ...] | torch.Tensor, name: str): input = input[0] acts[name] += 1.0 / nsamples * input.pow(2).sum(dim=(0, 1)).sqrt() - hooks = set( + hooks = { self.register_hook(mod, partial(save_acts, name=name), "forward_pre") for name, mod in model.named_modules() if isinstance(mod, torch.nn.Linear) and "lm_head" not in name - ) + } with HooksMixin.disable_hooks(keep=hooks): run_calibration(model, dataloader) self.remove_hooks(hooks) diff --git a/src/llmcompressor/modifiers/pruning/sparsegpt/sgpt_sparsify.py b/src/llmcompressor/modifiers/pruning/sparsegpt/sgpt_sparsify.py index f327a4c34d..5a0ebec136 100644 --- a/src/llmcompressor/modifiers/pruning/sparsegpt/sgpt_sparsify.py +++ b/src/llmcompressor/modifiers/pruning/sparsegpt/sgpt_sparsify.py @@ -9,7 +9,7 @@ def make_empty_hessian( - module: torch.nn.Module, device: Optional[torch.device] = None + module: torch.nn.Module, device: torch.device | None = None ) -> torch.Tensor: weight = module.weight num_columns = weight.shape[1] @@ -22,7 +22,7 @@ def accumulate_hessian( module: torch.nn.Module, H: torch.Tensor, num_samples: int, -) -> Tuple[torch.Tensor, int]: +) -> tuple[torch.Tensor, int]: inp = inp.to(device=H.device) if len(inp.shape) == 2: inp = inp.unsqueeze(0) @@ -58,7 +58,7 @@ def accumulate_hessian( def sparsify_weight( module: torch.nn.Module, - hessians_dict: Dict[torch.nn.Module, torch.Tensor], + hessians_dict: dict[torch.nn.Module, torch.Tensor], sparsity: float, prune_n: int, prune_m: int, diff --git a/src/llmcompressor/modifiers/pruning/wanda/wanda_sparsify.py b/src/llmcompressor/modifiers/pruning/wanda/wanda_sparsify.py index 4147ba525a..8495fbe418 100644 --- a/src/llmcompressor/modifiers/pruning/wanda/wanda_sparsify.py +++ b/src/llmcompressor/modifiers/pruning/wanda/wanda_sparsify.py @@ -7,7 +7,7 @@ def make_empty_row_scalars( - module: torch.nn.Module, device: Optional[torch.device] = None + module: torch.nn.Module, device: torch.device | None = None ) -> torch.Tensor: weight = module.weight num_columns = weight.shape[1] @@ -55,7 +55,7 @@ def accumulate_row_scalars( def sparsify_weight( module: torch.nn.Module, - row_scalars_dict: Dict[torch.nn.Module, torch.Tensor], + row_scalars_dict: dict[torch.nn.Module, torch.Tensor], sparsity: float, prune_n: int, prune_m: int, diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index afc42c0ca6..9455dcd0cf 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -87,7 +87,7 @@ def initialize_observer( def call_observer( module: Module, base_name: str, - value: Optional[torch.Tensor] = None, + value: torch.Tensor | None = None, should_calculate_gparam: bool = False, should_calculate_qparams: bool = True, ): diff --git a/src/llmcompressor/modifiers/quantization/quantization/mixin.py b/src/llmcompressor/modifiers/quantization/quantization/mixin.py index 08f9d75842..816c89e7c6 100644 --- a/src/llmcompressor/modifiers/quantization/quantization/mixin.py +++ b/src/llmcompressor/modifiers/quantization/quantization/mixin.py @@ -112,26 +112,26 @@ class QuantizationMixin(HooksMixin): (e.g. vLLM) supports non-divisible dimensions. Defaults to False. """ - config_groups: Optional[Dict[str, QuantizationScheme]] = None + config_groups: dict[str, QuantizationScheme] | None = None # NOTE: targets is not the sole source of truth for finding all matching target # layers in a model. Additional information can be stored in `config_groups` # Use self.resolved_targets as source of truth. - targets: Union[str, List[str]] = Field(default_factory=lambda: ["Linear"]) - ignore: List[str] = Field(default_factory=list) - scheme: Optional[Union[str, Dict[str, Any]]] = None - kv_cache_scheme: Optional[QuantizationArgs] = None + targets: str | list[str] = Field(default_factory=lambda: ["Linear"]) + ignore: list[str] = Field(default_factory=list) + scheme: str | dict[str, Any] | None = None + kv_cache_scheme: QuantizationArgs | None = None # Observer parameters for easy specification - weight_observer: Optional[str] = None - input_observer: Optional[str] = None - output_observer: Optional[str] = None - observer: Optional[Dict[str, str]] = None + weight_observer: str | None = None + input_observer: str | None = None + output_observer: str | None = None + observer: dict[str, str] | None = None bypass_divisibility_checks: bool = False - _calibration_hooks: Set[RemovableHandle] = PrivateAttr(default_factory=set) - _resolved_config: Optional[QuantizationConfig] = PrivateAttr(None) + _calibration_hooks: set[RemovableHandle] = PrivateAttr(default_factory=set) + _resolved_config: QuantizationConfig | None = PrivateAttr(None) @field_validator("targets", mode="before") - def validate_targets(cls, value: Union[str, List[str]]) -> List[str]: + def validate_targets(cls, value: str | list[str]) -> list[str]: if isinstance(value, str): return [value] @@ -139,8 +139,8 @@ def validate_targets(cls, value: Union[str, List[str]]) -> List[str]: @field_validator("scheme", mode="before") def validate_scheme( - cls, value: Optional[Union[str, Dict[str, Any]]] - ) -> Optional[Union[str, Dict[str, Any]]]: + cls, value: str | dict[str, Any] | None + ) -> str | dict[str, Any] | None: if isinstance(value, str) and not is_preset_scheme(value): raise ValueError( "`scheme` must either be a preset scheme name or a dictionary " @@ -157,7 +157,7 @@ def validate_scheme( return value @field_validator("observer", mode="before") - def validate_observer(cls, value: Any) -> Optional[Dict[str, str]]: + def validate_observer(cls, value: Any) -> dict[str, str] | None: """ Validate observer dictionary format. Accepts keys: 'weights', 'input', 'output' """ @@ -189,7 +189,7 @@ def resolved_config(self) -> QuantizationConfig: return self._resolved_config @property - def resolved_targets(self) -> Set[str]: + def resolved_targets(self) -> set[str]: """ Set of all resolved targets, i.e. all unique targets listed in resolved quantization config. @@ -404,7 +404,7 @@ def _initialize_observers(self, module: torch.nn.Module): if output: initialize_observer(module, base_name="output") - def _initialize_hooks(self, module: torch.nn.Module) -> Set[RemovableHandle]: + def _initialize_hooks(self, module: torch.nn.Module) -> set[RemovableHandle]: hooks = set() if not hasattr(module, "quantization_scheme"): return hooks diff --git a/src/llmcompressor/modifiers/transform/quip/base.py b/src/llmcompressor/modifiers/transform/quip/base.py index 3e58d7fe4a..74dfbf7f30 100644 --- a/src/llmcompressor/modifiers/transform/quip/base.py +++ b/src/llmcompressor/modifiers/transform/quip/base.py @@ -66,20 +66,20 @@ class QuIPModifier(Modifier): :param transform_config: Optional transform config for overriding provided arguments """ # noqa: E501 - rotations: List[Literal["v", "u"]] = Field(default_factory=lambda: ["v", "u"]) + rotations: list[Literal["v", "u"]] = Field(default_factory=lambda: ["v", "u"]) transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field( default="random-hadamard" ) - targets: Union[List[str], str] = Field(default="Linear") + targets: list[str] | str = Field(default="Linear") randomize: bool = Field(default=False) learnable: bool = Field(default=False) precision: TorchDtype = Field(default=torch.float64) - transform_block_size: Optional[int] = Field(default=None) - ignore: Union[str, List[str]] = Field(default="lm_head") + transform_block_size: int | None = Field(default=None) + ignore: str | list[str] = Field(default="lm_head") # optional override for more fine-grained control # also included in recipe serialization - transform_config: Optional[TransformConfig] = Field(default=None, repr=False) + transform_config: TransformConfig | None = Field(default=None, repr=False) @field_validator("randomize", "learnable", mode="before") def validate_not_implemented(cls, value, info: ValidationInfo): diff --git a/src/llmcompressor/modifiers/transform/smoothquant/base.py b/src/llmcompressor/modifiers/transform/smoothquant/base.py index bd3229a652..c73d72f63b 100644 --- a/src/llmcompressor/modifiers/transform/smoothquant/base.py +++ b/src/llmcompressor/modifiers/transform/smoothquant/base.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Callable +from collections.abc import Callable import torch from compressed_tensors.offload import update_offload_parameter @@ -204,9 +204,9 @@ def _resolve_mappings(self, model: Module) -> list[SmoothQuantMapping]: # Get names of modules that are not ignored ignored_names = set() if self.ignore: - ignored_names = set( + ignored_names = { name for name, _ in match_named_modules(model, self.ignore) - ) + } for mapping in self.mappings: # we deliberately don't use the ignore list when matching mappings diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index 4c52f68095..96097485bf 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -1,5 +1,7 @@ from enum import Enum -from typing import Iterable, List, Literal, Optional +from typing import List, Literal, Optional + +from collections.abc import Iterable import torch from compressed_tensors import match_modules_set, match_named_modules @@ -86,23 +88,23 @@ class SpinQuantModifier(Modifier, use_enum_values=True): :param transform_config: Optional transform config for overriding provided arguments """ - rotations: List[SpinquantRotation] = Field(default_factory=lambda: ["R1", "R2"]) + rotations: list[SpinquantRotation] = Field(default_factory=lambda: ["R1", "R2"]) transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field( default="hadamard" ) randomize: bool = Field(default=False) learnable: bool = Field(default=False) precision: TorchDtype = Field(default=torch.float64) - transform_block_size: Optional[int] = Field(default=None) + transform_block_size: int | None = Field(default=None) # norm mappings separate from spinquant mappings to allow users to # override spinquant mappings with transform_config without overriding norms - mappings: Optional[SpinQuantMapping] = Field( + mappings: SpinQuantMapping | None = Field( default=None, repr=False, exclude=True, ) - norm_mappings: Optional[List[NormMapping]] = Field( + norm_mappings: list[NormMapping] | None = Field( default=None, repr=False, exclude=True, @@ -110,7 +112,7 @@ class SpinQuantModifier(Modifier, use_enum_values=True): # optional override for more fine-grained control # also included in recipe serialization - transform_config: Optional[TransformConfig] = Field(default=None, repr=False) + transform_config: TransformConfig | None = Field(default=None, repr=False) @field_validator("randomize", "learnable", mode="before") def validate_not_implemented(cls, value, info: ValidationInfo): diff --git a/src/llmcompressor/modifiers/transform/spinquant/mappings.py b/src/llmcompressor/modifiers/transform/spinquant/mappings.py index da3d76f6c1..68e71270d1 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/mappings.py +++ b/src/llmcompressor/modifiers/transform/spinquant/mappings.py @@ -35,10 +35,10 @@ class SpinQuantMapping(BaseModel): attn_k: str attn_v: str attn_o: str - attn_head_dim: Optional[int] = Field(default=None) + attn_head_dim: int | None = Field(default=None) - mlp_in: List[str] # up_proj, gate_proj - mlp_out: List[str] # down_proj + mlp_in: list[str] # up_proj, gate_proj + mlp_out: list[str] # down_proj lm_head: str @@ -63,7 +63,7 @@ def cast_to_list(cls, value): ) -SPINQUANT_MAPPING_REGISTRY: Dict[str, SpinQuantMapping] = { +SPINQUANT_MAPPING_REGISTRY: dict[str, SpinQuantMapping] = { "LlamaForCausalLM": _default_mappings, } diff --git a/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py b/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py index e60ac0d1af..552e9b1590 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py +++ b/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py @@ -20,7 +20,7 @@ class NormMapping(BaseModel): """ norm: str - linears: List[str] + linears: list[str] @field_validator("linears", mode="before") def cast_to_list(cls, value): @@ -45,12 +45,12 @@ def cast_to_list(cls, value): ), ] -NORM_MAPPING_REGISTRY: Dict[str, NormMapping] = { +NORM_MAPPING_REGISTRY: dict[str, NormMapping] = { "LlamaForCausalLM": _default_mappings, } -def infer_norm_mapping_from_model(model: PreTrainedModel) -> List[NormMapping]: +def infer_norm_mapping_from_model(model: PreTrainedModel) -> list[NormMapping]: architecture = model.__class__.__name__ if architecture not in NORM_MAPPING_REGISTRY: logger.info( diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index c3a81d4344..1cfa43d4a9 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -1,6 +1,8 @@ import contextlib from functools import partial, wraps -from typing import Any, Callable, ClassVar +from typing import Any, ClassVar + +from collections.abc import Callable import torch from compressed_tensors.modeling import ( diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py index 384bbf6ead..9ce64bdef9 100644 --- a/src/llmcompressor/observers/base.py +++ b/src/llmcompressor/observers/base.py @@ -13,8 +13,8 @@ __all__ = ["Observer", "MinMaxTuple", "ScaleZpTuple"] -MinMaxTuple = Tuple[torch.Tensor, torch.Tensor] -ScaleZpTuple = Tuple[torch.Tensor, torch.Tensor] +MinMaxTuple = tuple[torch.Tensor, torch.Tensor] +ScaleZpTuple = tuple[torch.Tensor, torch.Tensor] class Observer(InternalModule, RegistryMixin): @@ -41,7 +41,7 @@ def __init__( self, base_name: str, args: QuantizationArgs, - module: Optional[torch.nn.Module] = None, + module: torch.nn.Module | None = None, **observer_kwargs, ): super().__init__() @@ -100,7 +100,7 @@ def get_global_scale(self, observed: torch.Tensor) -> torch.Tensor: def _forward_with_minmax( self, observed: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: g_idx = self._get_module_param("g_idx") global_scale = self._get_module_param("global_scale") self._check_has_global_scale(global_scale) @@ -118,7 +118,7 @@ def _forward_with_minmax( def _get_global_scale_with_minmax( self, observed: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: observed = observed.reshape((1, 1, -1)) # per tensor reshape global_min_vals, global_max_vals = self.get_global_min_max(observed) @@ -126,14 +126,14 @@ def _get_global_scale_with_minmax( return global_scale, global_min_vals, global_max_vals - def _get_module_param(self, name: str) -> Optional[torch.nn.Parameter]: + def _get_module_param(self, name: str) -> torch.nn.Parameter | None: if self.module is None or (module := self.module()) is None: return None with align_module_device(module): return getattr(module, f"{self.base_name}_{name}", None) - def _check_has_global_scale(self, global_scale: Optional[torch.nn.Parameter]): + def _check_has_global_scale(self, global_scale: torch.nn.Parameter | None): if ( self.args.strategy == QuantizationStrategy.TENSOR_GROUP and global_scale is None diff --git a/src/llmcompressor/observers/helpers.py b/src/llmcompressor/observers/helpers.py index 5a8001e461..f80b404d10 100644 --- a/src/llmcompressor/observers/helpers.py +++ b/src/llmcompressor/observers/helpers.py @@ -23,7 +23,7 @@ def flatten_for_calibration( value: torch.Tensor, base_name: str, args: QuantizationArgs, - g_idx: Optional[torch.Tensor] = None, + g_idx: torch.Tensor | None = None, ) -> torch.Tensor: """ Reshapes the value according to the quantization strategy for the purposes of @@ -57,7 +57,7 @@ def flatten_for_calibration( def _flatten_weight( - value: torch.Tensor, args: QuantizationArgs, g_idx: Optional[torch.Tensor] = None + value: torch.Tensor, args: QuantizationArgs, g_idx: torch.Tensor | None = None ): # value.shape = (num_rows, num_cols) diff --git a/src/llmcompressor/observers/moving_base.py b/src/llmcompressor/observers/moving_base.py index f94c474284..9a6a6a5eca 100644 --- a/src/llmcompressor/observers/moving_base.py +++ b/src/llmcompressor/observers/moving_base.py @@ -24,7 +24,7 @@ def __init__( self, base_name: str, args: QuantizationArgs, - module: Optional[torch.nn.Module] = None, + module: torch.nn.Module | None = None, **observer_kwargs, ): super().__init__(base_name, args, module, **observer_kwargs) diff --git a/src/llmcompressor/observers/mse.py b/src/llmcompressor/observers/mse.py index f21c675ab6..8ff9261838 100644 --- a/src/llmcompressor/observers/mse.py +++ b/src/llmcompressor/observers/mse.py @@ -154,7 +154,7 @@ def _grid_search_mse( patience: float, grid: float, norm: float, - global_scale: Optional[torch.Tensor] = None, + global_scale: torch.Tensor | None = None, optimize_global_scale: bool = False, ) -> MinMaxTuple: """ diff --git a/src/llmcompressor/pipelines/cache.py b/src/llmcompressor/pipelines/cache.py index 62998be410..a09b969d22 100644 --- a/src/llmcompressor/pipelines/cache.py +++ b/src/llmcompressor/pipelines/cache.py @@ -4,7 +4,9 @@ import warnings from collections import defaultdict from dataclasses import dataclass, fields, is_dataclass -from typing import Any, Generator +from typing import Any + +from collections.abc import Generator from weakref import WeakKeyDictionary import torch @@ -22,7 +24,7 @@ class IntermediateValue: otherwise None """ - value: torch.Tensor | "IntermediateValue" | Any + value: torch.Tensor | IntermediateValue | Any device: torch.device | None @@ -162,7 +164,7 @@ def size(self) -> dict[torch.device, int]: :return: dictionary mapping torch device to number of bytes in cache """ - sizes = defaultdict(lambda: 0) + sizes = defaultdict(int) memo = set() def _size_helper(intermediate: IntermediateValue) -> int: @@ -192,11 +194,11 @@ def _size_helper(intermediate: IntermediateValue) -> int: return dict(sizes) - def iter(self, input_names: list[str] | None = None) -> Generator[Any, None, None]: + def iter(self, input_names: list[str] | None = None) -> Generator[Any]: for batch_index in range(len(self.batch_intermediates)): yield self.fetch(batch_index, input_names) - def __iter__(self) -> Generator[Any, None, None]: + def __iter__(self) -> Generator[Any]: yield from self.iter() def __len__(self) -> int: diff --git a/src/llmcompressor/pipelines/data_free/pipeline.py b/src/llmcompressor/pipelines/data_free/pipeline.py index dc295952b8..74b137ab95 100644 --- a/src/llmcompressor/pipelines/data_free/pipeline.py +++ b/src/llmcompressor/pipelines/data_free/pipeline.py @@ -18,7 +18,7 @@ class DataFreePipeline(CalibrationPipeline): @staticmethod def __call__( model: torch.nn.Module, - dataloader: Optional[DataLoader], + dataloader: DataLoader | None, dataset_args: "DatasetArguments", ): """ diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py index fbbd7c9d51..3a0bd5cf61 100644 --- a/src/llmcompressor/pipelines/sequential/helpers.py +++ b/src/llmcompressor/pipelines/sequential/helpers.py @@ -4,7 +4,9 @@ from dataclasses import dataclass from functools import wraps from types import FunctionType, MethodType -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Optional + +from collections.abc import Callable import torch from accelerate.hooks import remove_hook_from_module @@ -75,9 +77,9 @@ def forward(self, *args, **kwargs) -> dict[str, Any]: def submodules(self, model: Module, recurse: bool = False) -> set[Module]: nodes = self.graph.find_nodes(op="call_module") - modules = set(model.get_submodule(node.target) for node in nodes) + modules = {model.get_submodule(node.target) for node in nodes} if recurse: - modules = set(m for module in modules for m in module.modules()) + modules = {m for module in modules for m in module.modules()} return modules @@ -101,9 +103,9 @@ def trace_subgraphs( :return: a list of Subgraphs in order of execution """ # find modules - targets = set( + targets = { module for _, module in match_named_modules(model, sequential_targets) - ) + } ancestors = get_sequential_ancestors(model, targets) offloaded = set() # TODO: cleanup logic @@ -253,11 +255,11 @@ def find_target_nodes(graph: GraphModule, targets: set[Module]) -> set[Node]: :param targets: modules whose nodes are being searched for :return: set of all nodes which call the target modules """ - return set( + return { node for node in graph.graph.nodes if node.op == "call_module" and graph.get_submodule(node.target) in targets - ) + } def topological_partition(graph: GraphModule, targets: set[Module]) -> list[list[Node]]: @@ -371,7 +373,7 @@ def partition_graph(model: Module, partitions: list[list[Node]]) -> list[Subgrap # save the subgraph for this partition graph.lint() - input_names = set(node.name for node in graph.nodes if node.op == "placeholder") + input_names = {node.name for node in graph.nodes if node.op == "placeholder"} subgraphs.append( Subgraph( graph=graph, @@ -518,8 +520,8 @@ def is_ancestor(module: Module) -> bool: def dispatch_for_sequential( model: PreTrainedModel, - onload_device: Optional[torch.device | str] = None, - offload_device: Optional[torch.device | str] = None, + onload_device: torch.device | str | None = None, + offload_device: torch.device | str | None = None, ) -> PreTrainedModel: """ Dispatch a model for sequential calibration using a sequential pipeline. diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index a16693b1a0..6e139bf94d 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -1,6 +1,8 @@ import contextlib from concurrent.futures import ThreadPoolExecutor -from typing import TYPE_CHECKING, Iterator +from typing import TYPE_CHECKING + +from collections.abc import Iterator import torch from compressed_tensors.utils import disable_offloading diff --git a/src/llmcompressor/pytorch/model_load/helpers.py b/src/llmcompressor/pytorch/model_load/helpers.py index 190969085c..ce3507f7d5 100644 --- a/src/llmcompressor/pytorch/model_load/helpers.py +++ b/src/llmcompressor/pytorch/model_load/helpers.py @@ -13,7 +13,7 @@ ] -def parse_dtype(dtype_arg: Union[str, torch.dtype]) -> torch.dtype: +def parse_dtype(dtype_arg: str | torch.dtype) -> torch.dtype: """ :param dtype_arg: dtype or string to parse :return: torch.dtype parsed from input string @@ -30,7 +30,7 @@ def parse_dtype(dtype_arg: Union[str, torch.dtype]) -> torch.dtype: return dtype -def get_session_model() -> Optional[Module]: +def get_session_model() -> Module | None: """ :return: pytorch module stored by the active CompressionSession, or None if no session is active diff --git a/src/llmcompressor/pytorch/utils/sparsification.py b/src/llmcompressor/pytorch/utils/sparsification.py index ccc138308e..24bb1cf1be 100644 --- a/src/llmcompressor/pytorch/utils/sparsification.py +++ b/src/llmcompressor/pytorch/utils/sparsification.py @@ -29,7 +29,7 @@ class ModuleSparsificationInfo: """ def __init__( - self, module: Module, state_dict: Optional[Dict[str, torch.Tensor]] = None + self, module: Module, state_dict: dict[str, torch.Tensor] | None = None ): self.module = module diff --git a/src/llmcompressor/pytorch/utils/sparsification_info/configs.py b/src/llmcompressor/pytorch/utils/sparsification_info/configs.py index 8ab22a6f95..dd5e8c8cb9 100644 --- a/src/llmcompressor/pytorch/utils/sparsification_info/configs.py +++ b/src/llmcompressor/pytorch/utils/sparsification_info/configs.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod from collections import Counter, defaultdict -from typing import Any, Dict, Generator, Tuple, Union +from typing import Any, Dict, Tuple, Union + +from collections.abc import Generator import torch.nn from pydantic import BaseModel, ConfigDict, Field @@ -40,7 +42,7 @@ def from_module( def loggable_items( self, **kwargs, - ) -> Generator[Tuple[str, Union[Dict[str, int], float, int]], None, None]: + ) -> Generator[tuple[str, dict[str, int] | float | int], None, None]: """ Yield the loggable items for SparsificationInfo object. @@ -50,7 +52,7 @@ def loggable_items( @staticmethod def filter_loggable_items_percentages_only( - items_to_log: Generator[Tuple[str, Any], None, None], + items_to_log: Generator[tuple[str, Any], None, None], percentage_only: bool = False, ): """ @@ -128,11 +130,11 @@ class SparsificationSummaries(SparsificationInfo): description="A model that contains the number of " "parameters/the percent of parameters that are pruned." ) - parameter_counts: Dict[str, int] = Field( + parameter_counts: dict[str, int] = Field( description="A dictionary that maps the name of a parameter " "to the number of elements (weights) in that parameter." ) - operation_counts: Dict[str, int] = Field( + operation_counts: dict[str, int] = Field( description="A dictionary that maps the name of an operation " "to the number of times that operation is used in the model." ) @@ -141,7 +143,7 @@ class SparsificationSummaries(SparsificationInfo): def from_module( cls, module=torch.nn.Module, - pruning_thresholds: Tuple[float, float] = (0.05, 1 - 1e-9), + pruning_thresholds: tuple[float, float] = (0.05, 1 - 1e-9), ) -> "SparsificationSummaries": """ Factory method to create a SparsificationSummaries object from a module. @@ -192,7 +194,7 @@ def loggable_items( non_zero_only: bool = False, percentages_only: bool = True, **kwargs, - ) -> Generator[Tuple[str, Union[Dict[str, int], float, int]], None, None]: + ) -> Generator[tuple[str, dict[str, int] | float | int], None, None]: """ Yield the loggable items for SparsificationSummaries object. @@ -227,7 +229,7 @@ class SparsificationPruning(SparsificationInfo): A model that contains the pruning information for a torch module. """ - sparse_parameters: Dict[str, CountAndPercent] = Field( + sparse_parameters: dict[str, CountAndPercent] = Field( description="A dictionary that maps the name of a parameter " "to the number/percent of weights that are zeroed out " "in that layer." @@ -261,7 +263,7 @@ def loggable_items( percentages_only: bool = False, non_zero_only: bool = False, **kwargs, - ) -> Generator[Tuple[str, Union[Dict[str, int], float, int]], None, None]: + ) -> Generator[tuple[str, dict[str, int] | float | int], None, None]: """ Yield the loggable items for SparsificationPruning object. @@ -302,12 +304,12 @@ class SparsificationQuantization(SparsificationInfo): A model that contains the quantization information for a torch module. """ - enabled: Dict[str, bool] = Field( + enabled: dict[str, bool] = Field( description="A dictionary that maps the name of an " "operation to a boolean flag that indicates whether " "the operation is quantized or not." ) - precision: Dict[str, Union[BaseModel, None, int]] = Field( + precision: dict[str, BaseModel | None | int] = Field( description="A dictionary that maps the name of a layer" "to the precision of that layer." ) @@ -344,7 +346,7 @@ def loggable_items( self, enabled_only: bool = False, **kwargs, - ) -> Generator[Tuple[str, Union[Dict[str, int], float, int]], None, None]: + ) -> Generator[tuple[str, dict[str, int] | float | int], None, None]: """ Yield the loggable items for SparsificationQuantization object. diff --git a/src/llmcompressor/pytorch/utils/sparsification_info/helpers.py b/src/llmcompressor/pytorch/utils/sparsification_info/helpers.py index 78a5515cba..ce861811c6 100644 --- a/src/llmcompressor/pytorch/utils/sparsification_info/helpers.py +++ b/src/llmcompressor/pytorch/utils/sparsification_info/helpers.py @@ -9,9 +9,9 @@ def get_leaf_operations( model: torch.nn.Module, - operations_to_skip: Optional[List[torch.nn.Module]] = None, - operations_to_unwrap: Optional[List[torch.nn.Module]] = None, -) -> List[torch.nn.Module]: + operations_to_skip: list[torch.nn.Module] | None = None, + operations_to_unwrap: list[torch.nn.Module] | None = None, +) -> list[torch.nn.Module]: """ Get the leaf operations in the model (those that do not have operations as children) @@ -106,4 +106,4 @@ def _get_num_bits(dtype: torch.dtype) -> int: elif dtype == torch.int64: return 64 else: - raise ValueError("Unknown dtype: {}".format(dtype)) + raise ValueError(f"Unknown dtype: {dtype}") diff --git a/src/llmcompressor/pytorch/utils/sparsification_info/module_sparsification_info.py b/src/llmcompressor/pytorch/utils/sparsification_info/module_sparsification_info.py index 0e7aa0c4b5..8500fcddaa 100644 --- a/src/llmcompressor/pytorch/utils/sparsification_info/module_sparsification_info.py +++ b/src/llmcompressor/pytorch/utils/sparsification_info/module_sparsification_info.py @@ -1,4 +1,6 @@ -from typing import Any, Generator, Tuple +from typing import Any, Tuple + +from collections.abc import Generator import torch from pydantic import Field @@ -36,7 +38,7 @@ def from_module(cls, module: torch.nn.Module) -> "ModuleSparsificationInfo": """ if not isinstance(module, torch.nn.Module): raise ValueError( - "Module must be a torch.nn.Module, not {}".format(type(module)) + f"Module must be a torch.nn.Module, not {type(module)}" ) return cls( @@ -45,7 +47,7 @@ def from_module(cls, module: torch.nn.Module) -> "ModuleSparsificationInfo": quantization_info=SparsificationQuantization.from_module(module), ) - def loggable_items(self, **kwargs) -> Generator[Tuple[str, Any], None, None]: + def loggable_items(self, **kwargs) -> Generator[tuple[str, Any], None, None]: """ A generator that yields the loggable items of the ModuleSparsificationInfo object. diff --git a/src/llmcompressor/recipe/recipe.py b/src/llmcompressor/recipe/recipe.py index 6ddf6cbca0..c456261ce2 100644 --- a/src/llmcompressor/recipe/recipe.py +++ b/src/llmcompressor/recipe/recipe.py @@ -34,17 +34,17 @@ class Recipe(BaseModel): when serializing a recipe, yaml will be used by default. """ - args: Dict[str, Any] = Field(default_factory=dict) + args: dict[str, Any] = Field(default_factory=dict) stage: str = "default" - modifiers: List[Modifier] = Field(default_factory=list) + modifiers: list[Modifier] = Field(default_factory=list) model_config = ConfigDict(arbitrary_types_allowed=True) @classmethod def from_modifiers( cls, - modifiers: Union[Modifier, List[Modifier]], - modifier_group_name: Optional[str] = None, + modifiers: Modifier | list[Modifier], + modifier_group_name: str | None = None, ) -> "Recipe": """ Create a recipe instance from a list of modifiers @@ -84,9 +84,9 @@ def from_modifiers( @classmethod def create_instance( cls, - path_or_modifiers: Union[str, Modifier, List[Modifier], "Recipe"], - modifier_group_name: Optional[str] = None, - target_stage: Optional[str] = None, + path_or_modifiers: Union[str, Modifier, list[Modifier], "Recipe"], + modifier_group_name: str | None = None, + target_stage: str | None = None, ) -> "Recipe": """ Create a recipe instance from a file, string, or RecipeModifier objects @@ -139,7 +139,7 @@ def create_instance( else: logger.info(f"Loading recipe from file {path_or_modifiers}") - with open(path_or_modifiers, "r") as file: + with open(path_or_modifiers) as file: content = file.read().strip() if path_or_modifiers.lower().endswith(".md"): content = _parse_recipe_from_md(path_or_modifiers, content) @@ -160,7 +160,7 @@ def create_instance( return cls.from_dict(filter_dict(obj, target_stage=target_stage)) @classmethod - def from_dict(cls, recipe_dict: Dict[str, Any]) -> "Recipe": + def from_dict(cls, recipe_dict: dict[str, Any]) -> "Recipe": """ Parses a dictionary representing a recipe and returns a Recipe instance. Ensures all modifier entries are instantiated Modifier objects. @@ -169,7 +169,7 @@ def from_dict(cls, recipe_dict: Dict[str, Any]) -> "Recipe": :return: Recipe instance with instantiated Modifier objects. """ args = recipe_dict.get("args", {}) - modifiers: List[Modifier] = [] + modifiers: list[Modifier] = [] stage = "default" if not ModifierFactory._loaded: @@ -198,7 +198,7 @@ def from_dict(cls, recipe_dict: Dict[str, Any]) -> "Recipe": modifiers=modifiers, ) - def dict(self, *args, **kwargs) -> Dict[str, Any]: + def dict(self, *args, **kwargs) -> dict[str, Any]: """ :return: A dictionary representation of the recipe """ @@ -207,8 +207,8 @@ def dict(self, *args, **kwargs) -> Dict[str, Any]: def yaml( self, - file_path: Optional[str] = None, - existing_recipe_path: Optional[str] = None, + file_path: str | None = None, + existing_recipe_path: str | None = None, ) -> str: """ Return a YAML string representation of the recipe, @@ -221,7 +221,7 @@ def yaml( # Load the other recipe from file, if given existing_dict = {} if existing_recipe_path: - with open(existing_recipe_path, "r") as f: + with open(existing_recipe_path) as f: existing_recipe_str = f.read() existing_dict = _load_json_or_yaml_string(existing_recipe_str) @@ -251,6 +251,6 @@ def yaml( return yaml_str -RecipeInput = Union[str, List[str], Recipe, List[Recipe], Modifier, List[Modifier]] -RecipeStageInput = Union[str, List[str], List[List[str]]] -RecipeArgsInput = Union[Dict[str, Any], List[Dict[str, Any]]] +RecipeInput = Union[str, list[str], Recipe, list[Recipe], Modifier, list[Modifier]] +RecipeStageInput = Union[str, list[str], list[list[str]]] +RecipeArgsInput = Union[dict[str, Any], list[dict[str, Any]]] diff --git a/src/llmcompressor/recipe/utils.py b/src/llmcompressor/recipe/utils.py index 8c787ea555..fcf5b2f022 100644 --- a/src/llmcompressor/recipe/utils.py +++ b/src/llmcompressor/recipe/utils.py @@ -7,7 +7,7 @@ from llmcompressor.modifiers import Modifier -def _load_json_or_yaml_string(content: str) -> Dict[str, Any]: +def _load_json_or_yaml_string(content: str) -> dict[str, Any]: # try loading as json first, then yaml # if both fail, raise a ValueError try: @@ -48,12 +48,12 @@ def _parse_recipe_from_md(file_path, yaml_str): else: # fail if we know whe should have extracted front matter out raise RuntimeError( - "Could not extract YAML front matter from recipe card: {}".format(file_path) + f"Could not extract YAML front matter from recipe card: {file_path}" ) return yaml_str -def get_yaml_serializable_dict(modifiers: List[Modifier], stage: str) -> Dict[str, Any]: +def get_yaml_serializable_dict(modifiers: list[Modifier], stage: str) -> dict[str, Any]: """ This function is used to convert a list of modifiers into a dictionary where the keys are the group names and the values are the modifiers @@ -96,7 +96,7 @@ def get_yaml_serializable_dict(modifiers: List[Modifier], stage: str) -> Dict[st return stage_dict -def filter_dict(obj: dict, target_stage: Optional[str] = None) -> dict: +def filter_dict(obj: dict, target_stage: str | None = None) -> dict: """ Filter a dictionary to only include keys that match the target stage. diff --git a/src/llmcompressor/transformers/data/base.py b/src/llmcompressor/transformers/data/base.py index 512500a0eb..bd5b633428 100644 --- a/src/llmcompressor/transformers/data/base.py +++ b/src/llmcompressor/transformers/data/base.py @@ -10,7 +10,9 @@ import inspect from functools import cached_property from inspect import _ParameterKind as Kind -from typing import Any, Callable +from typing import Any + +from collections.abc import Callable from compressed_tensors.registry import RegistryMixin from datasets import Dataset, IterableDataset @@ -241,18 +243,18 @@ def rename_columns(self, dataset: DatasetType) -> DatasetType: def filter_tokenizer_args(self, dataset: DatasetType) -> DatasetType: # assumes that inputs are not passed via self.processor.__call__ args and kwargs signature = inspect.signature(self.processor.__call__) - tokenizer_args = set( + tokenizer_args = { key for key, param in signature.parameters.items() if param.kind not in (Kind.VAR_POSITIONAL, Kind.VAR_KEYWORD) - ) + } logger.debug( f"Found processor args `{tokenizer_args}`. Removing all other columns" ) column_names = get_columns(dataset) return dataset.remove_columns( - list(set(column_names) - set(tokenizer_args) - set([self.PROMPT_KEY])) + list(set(column_names) - set(tokenizer_args) - {self.PROMPT_KEY}) ) def tokenize(self, data: LazyRow) -> dict[str, Any]: diff --git a/src/llmcompressor/transformers/tracing/debug.py b/src/llmcompressor/transformers/tracing/debug.py index 0f6213815c..0e4fa2d64d 100644 --- a/src/llmcompressor/transformers/tracing/debug.py +++ b/src/llmcompressor/transformers/tracing/debug.py @@ -32,14 +32,14 @@ def parse_args(): def trace( model_id: str, - model_class: Type[PreTrainedModel], + model_class: type[PreTrainedModel], sequential_targets: list[str] | str | None = None, ignore: list[str] | str = DatasetArguments().tracing_ignore, modality: str = "text", trust_remote_code: bool = True, skip_weights: bool = True, device_map: str | dict = "cpu", -) -> Tuple[PreTrainedModel, list[Subgraph], dict[str, torch.Tensor]]: +) -> tuple[PreTrainedModel, list[Subgraph], dict[str, torch.Tensor]]: """ Debug traceability by tracing a pre-trained model into subgraphs diff --git a/src/llmcompressor/transformers/utils/helpers.py b/src/llmcompressor/transformers/utils/helpers.py index 67d2f726b2..270c0ed542 100644 --- a/src/llmcompressor/transformers/utils/helpers.py +++ b/src/llmcompressor/transformers/utils/helpers.py @@ -135,11 +135,9 @@ def recipe_from_huggingface_model_id( logger.info(f"Found recipe: {recipe_file_name} for model ID: {hf_stub}.") except Exception as e: # TODO: narrow acceptable exceptions logger.debug( - ( f"Unable to find recipe {recipe_file_name} " f"for model ID: {hf_stub}: {e}." "Skipping recipe resolution." - ) ) recipe = None diff --git a/src/llmcompressor/typing.py b/src/llmcompressor/typing.py index 233f4df56a..586bbf9fa5 100644 --- a/src/llmcompressor/typing.py +++ b/src/llmcompressor/typing.py @@ -2,7 +2,7 @@ Defines type aliases for the llm-compressor library. """ -from typing import Iterable +from collections.abc import Iterable import torch from datasets import Dataset, DatasetDict, IterableDataset diff --git a/src/llmcompressor/utils/dev.py b/src/llmcompressor/utils/dev.py index 244cd1489a..e507454753 100644 --- a/src/llmcompressor/utils/dev.py +++ b/src/llmcompressor/utils/dev.py @@ -24,7 +24,7 @@ @contextlib.contextmanager -def skip_weights_download(model_class: Type[PreTrainedModel] = AutoModelForCausalLM): +def skip_weights_download(model_class: type[PreTrainedModel] = AutoModelForCausalLM): """ Context manager under which models are initialized without having to download the model weight files. This differs from `init_empty_weights` in that weights are diff --git a/src/llmcompressor/utils/dist.py b/src/llmcompressor/utils/dist.py index 2339c24159..2e1682a36c 100644 --- a/src/llmcompressor/utils/dist.py +++ b/src/llmcompressor/utils/dist.py @@ -1,4 +1,6 @@ -from typing import Callable, Hashable, TypeVar +from typing import TypeVar + +from collections.abc import Callable, Hashable import torch.distributed as dist diff --git a/src/llmcompressor/utils/metric_logging.py b/src/llmcompressor/utils/metric_logging.py index d99ce3974d..e2302e741d 100644 --- a/src/llmcompressor/utils/metric_logging.py +++ b/src/llmcompressor/utils/metric_logging.py @@ -98,7 +98,7 @@ def __exit__(self, _exc_type, _exc_val, _exc_tb): if self.loss is not None: patch.log("METRIC", f"error {self.loss:.2f}") - gpu_usage: List[GPUMemory] = self.get_GPU_memory_usage() + gpu_usage: list[GPUMemory] = self.get_GPU_memory_usage() for gpu in gpu_usage: perc = gpu.pct_used * 100 patch.log( @@ -112,14 +112,14 @@ def __exit__(self, _exc_type, _exc_val, _exc_tb): compressed_size = get_layer_size_mb(self.module) patch.log("METRIC", f"Compressed module size: {compressed_size} MB") - def get_GPU_memory_usage(self) -> List[GPUMemory]: + def get_GPU_memory_usage(self) -> list[GPUMemory]: if self.gpu_type == GPUType.amd: return self._get_GPU_usage_amd(self.visible_ids) else: return self._get_GPU_usage_nv(self.visible_ids) @staticmethod - def _get_GPU_usage_nv(visible_ids: List[int]) -> List[GPUMemory]: + def _get_GPU_usage_nv(visible_ids: list[int]) -> list[GPUMemory]: """ get gpu usage for visible Nvidia GPUs using nvml lib @@ -136,7 +136,7 @@ def _get_GPU_usage_nv(visible_ids: List[int]) -> List[GPUMemory]: logger.warning(f"Pynml library error:\n {_err}") return [] - usage: List[GPUMemory] = [] + usage: list[GPUMemory] = [] if len(visible_ids) == 0: visible_ids = range(pynvml.nvmlDeviceGetCount()) @@ -155,14 +155,14 @@ def _get_GPU_usage_nv(visible_ids: List[int]) -> List[GPUMemory]: return [] @staticmethod - def _get_GPU_usage_amd(visible_ids: List[int]) -> List[GPUMemory]: + def _get_GPU_usage_amd(visible_ids: list[int]) -> list[GPUMemory]: """ get gpu usage for AMD GPUs using amdsmi lib :param visible_ids: list of GPUs to monitor. If unset or zero length, defaults to all """ - usage: List[GPUMemory] = [] + usage: list[GPUMemory] = [] try: import amdsmi diff --git a/src/llmcompressor/utils/pytorch/module.py b/src/llmcompressor/utils/pytorch/module.py index 5a8098dd3b..d423ea170c 100644 --- a/src/llmcompressor/utils/pytorch/module.py +++ b/src/llmcompressor/utils/pytorch/module.py @@ -26,7 +26,7 @@ ALL_QUANTIZABLE_TARGET = "__ALL_QUANTIZABLE__" -def expand_special_targets(targets: Union[str, List[str]]) -> List[str]: +def expand_special_targets(targets: str | list[str]) -> list[str]: """ Expand special target constants to explicit class names with backward compatibility. @@ -72,9 +72,9 @@ def expand_special_targets(targets: Union[str, List[str]]) -> List[str]: def build_parameterized_layers( model: Module, - targets: Union[str, List[str]], + targets: str | list[str], param_name: str = "weight", -) -> Dict[str, ModelParameterizedLayer]: +) -> dict[str, ModelParameterizedLayer]: """ Build ModelParameterizedLayer objects for modules matching the given targets. @@ -128,7 +128,7 @@ def qat_active(module: Module) -> bool: return False -def get_no_split_params(model: PreTrainedModel) -> Union[str, List[str]]: +def get_no_split_params(model: PreTrainedModel) -> str | list[str]: """ Get list of module classes that shouldn't be split when sharding. For Hugging Face Transformer models, this is the decoder layer type. For other diff --git a/src/llmcompressor/utils/transformers.py b/src/llmcompressor/utils/transformers.py index 093b95dbad..602da70e82 100644 --- a/src/llmcompressor/utils/transformers.py +++ b/src/llmcompressor/utils/transformers.py @@ -59,7 +59,7 @@ def targets_embeddings( ) return False - targets = set(module for _, module in targets) + targets = {module for _, module in targets} return (check_input and input_embed in targets) or ( check_output and output_embed in targets )