-
Notifications
You must be signed in to change notification settings - Fork 453
refactor: modernize type hints to Python 3.10+ syntax (src/) #2438
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -7,7 +7,9 @@ | |||||||||||
| """ | ||||||||||||
|
|
||||||||||||
| from dataclasses import dataclass | ||||||||||||
| from typing import Any, Callable | ||||||||||||
| from typing import Any | ||||||||||||
|
|
||||||||||||
| from collections.abc import Callable | ||||||||||||
|
|
||||||||||||
|
Comment on lines
+10
to
13
|
||||||||||||
| from typing import Any | |
| from collections.abc import Callable | |
| from collections.abc import Callable | |
| from typing import Any |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,7 +7,9 @@ | |
|
|
||
| import threading | ||
| from contextlib import contextmanager | ||
| from typing import TYPE_CHECKING, Any, Generator, Optional | ||
| from typing import TYPE_CHECKING, Any, Optional | ||
|
|
||
| from collections.abc import Generator | ||
|
|
||
|
Comment on lines
+10
to
13
|
||
| from loguru import logger | ||
|
|
||
|
|
@@ -91,7 +93,7 @@ def event(cls, event_type: EventType, **kwargs) -> ModifiedState: | |
| return active_session().event(event_type, **kwargs) | ||
|
|
||
| @classmethod | ||
| def batch_start(cls, batch_data: Optional[Any] = None, **kwargs) -> ModifiedState: | ||
| def batch_start(cls, batch_data: Any | None = None, **kwargs) -> ModifiedState: | ||
| """ | ||
| Invoke a batch start event for the active session | ||
|
|
||
|
|
@@ -102,7 +104,7 @@ def batch_start(cls, batch_data: Optional[Any] = None, **kwargs) -> ModifiedStat | |
| return cls.event(EventType.BATCH_START, batch_data=batch_data, **kwargs) | ||
|
|
||
| @classmethod | ||
| def loss_calculated(cls, loss: Optional[Any] = None, **kwargs) -> ModifiedState: | ||
| def loss_calculated(cls, loss: Any | None = None, **kwargs) -> ModifiedState: | ||
| """ | ||
| Invoke a loss calculated event for the active session | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,7 +10,9 @@ | |
| import math | ||
| import re | ||
| from collections.abc import Iterator, Sized | ||
| from typing import Any, Callable, Optional | ||
| from typing import Any, Optional | ||
|
|
||
| from collections.abc import Callable | ||
|
|
||
|
Comment on lines
12
to
16
|
||
| import torch | ||
| from datasets import Dataset | ||
|
|
@@ -334,7 +336,7 @@ class LengthAwareSampler(Sampler[int]): | |
| def __init__( | ||
| self, | ||
| data_source: Dataset, | ||
| num_samples: Optional[int] = None, | ||
| num_samples: int | None = None, | ||
| batch_size: int = 1, | ||
|
Comment on lines
336
to
340
|
||
| ) -> None: | ||
| self.data_source = data_source | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -2,7 +2,9 @@ | |||
| import shutil | ||||
| from concurrent.futures import ThreadPoolExecutor, as_completed | ||||
| from pathlib import Path | ||||
| from typing import Iterable, Optional | ||||
| from typing import Optional | ||||
|
||||
| from typing import Optional |
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,7 +1,9 @@ | ||||||||||||
| import os | ||||||||||||
| import re | ||||||||||||
| from collections import defaultdict | ||||||||||||
| from typing import Mapping, TypeVar | ||||||||||||
| from typing import TypeVar | ||||||||||||
|
|
||||||||||||
| from collections.abc import Mapping | ||||||||||||
|
|
||||||||||||
|
Comment on lines
+4
to
7
|
||||||||||||
| from typing import TypeVar | |
| from collections.abc import Mapping | |
| from collections.abc import Mapping | |
| from typing import TypeVar |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,7 @@ | ||
| import os | ||
| from collections import defaultdict | ||
| from collections.abc import Iterator, Mapping | ||
| from typing import Iterable | ||
| from collections.abc import Iterable | ||
|
|
||
|
Comment on lines
3
to
5
|
||
| import torch | ||
| from compressed_tensors.quantization import QuantizationScheme | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -12,7 +12,9 @@ | |||||||||||
| import os | ||||||||||||
| from datetime import datetime | ||||||||||||
| from pathlib import Path | ||||||||||||
| from typing import TYPE_CHECKING, Callable | ||||||||||||
| from typing import TYPE_CHECKING | ||||||||||||
|
|
||||||||||||
| from collections.abc import Callable | ||||||||||||
|
|
||||||||||||
|
Comment on lines
+15
to
18
|
||||||||||||
| from typing import TYPE_CHECKING | |
| from collections.abc import Callable | |
| from collections.abc import Callable | |
| from typing import TYPE_CHECKING |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -54,13 +54,13 @@ | |
| class LoggerConfig: | ||
| disabled: bool = False | ||
| clear_loggers: bool = True | ||
| console_log_level: Optional[str] = "INFO" | ||
| log_file: Optional[str] = None | ||
| log_file_level: Optional[str] = None | ||
| console_log_level: str | None = "INFO" | ||
| log_file: str | None = None | ||
| log_file_level: str | None = None | ||
| metrics_disabled: bool = False | ||
|
|
||
|
|
||
| def configure_logger(config: Optional[LoggerConfig] = None) -> None: | ||
| def configure_logger(config: LoggerConfig | None = None) -> None: | ||
| """ | ||
|
Comment on lines
+57
to
64
|
||
| Configure the logger for LLM Compressor. | ||
|
|
||
|
|
@@ -122,7 +122,7 @@ def configure_logger(config: Optional[LoggerConfig] = None) -> None: | |
| logger.level("METRIC", no=38, color="<yellow>", icon="📈") | ||
|
|
||
|
|
||
| def support_log_once(record: Dict[str, Any]) -> bool: | ||
| def support_log_once(record: dict[str, Any]) -> bool: | ||
| """ | ||
| Support logging only once using `.bind(log_once=True)` | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,4 @@ | ||
| from typing import Iterable | ||
| from collections.abc import Iterable | ||
|
|
||
| import torch | ||
| from compressed_tensors import ( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -109,10 +109,10 @@ def copy_from_fused_weights( | |
| def forward( | ||
| self, | ||
| hidden_states: torch.Tensor, # [B, T, H] | ||
| router_indices: Optional[ | ||
| router_indices: None | ( | ||
| torch.Tensor | ||
| ] = None, # [B, T, top_k] or [tokens, top_k] | ||
| routing_weights: Optional[torch.Tensor] = None, # [B, T, E] or [tokens, E] | ||
| ) = None, # [B, T, top_k] or [tokens, top_k] | ||
| routing_weights: torch.Tensor | None = None, # [B, T, E] or [tokens, E] | ||
| ) -> torch.Tensor: | ||
|
Comment on lines
109
to
116
|
||
| """ | ||
| Implements the MoE computation using the router outputs. | ||
|
|
@@ -192,11 +192,11 @@ def set_module_by_path(root: nn.Module, dotpath: str, new_module: nn.Module) -> | |
| setattr(parent, parts[-1], new_module) | ||
|
|
||
|
|
||
| def find_experts(model: nn.Module) -> List[ExpertMeta]: | ||
| def find_experts(model: nn.Module) -> list[ExpertMeta]: | ||
| """ | ||
| Locate GPT-OSS MoE expert modules under model.model.layers[*].mlp.experts. | ||
| """ | ||
| metas: List[ExpertMeta] = [] | ||
| metas: list[ExpertMeta] = [] | ||
| for li, layer in enumerate(model.model.layers): | ||
| experts = layer.mlp.experts | ||
| device = next(experts.parameters(), torch.zeros(())).device | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -147,17 +147,17 @@ class AutoRoundModifier(Modifier, QuantizationMixin): | |
| Defaults to None. | ||
| """ | ||
|
|
||
| sequential_targets: Union[str, List[str], None] = None | ||
| sequential_targets: str | list[str] | None = None | ||
| # AutoRound modifier arguments | ||
| iters: int = 200 | ||
| enable_torch_compile: bool = True | ||
| batch_size: int = 8 | ||
| lr: Optional[float] = None | ||
| device_ids: Optional[str] = None | ||
| lr: float | None = None | ||
| device_ids: str | None = None | ||
|
|
||
| # private variables | ||
| _all_module_input: Dict[str, List[Tuple]] = PrivateAttr(default_factory=dict) | ||
| _q_input: Optional[torch.Tensor] = PrivateAttr(default=None) | ||
| _all_module_input: dict[str, list[tuple]] = PrivateAttr(default_factory=dict) | ||
| _q_input: torch.Tensor | None = PrivateAttr(default=None) | ||
|
Comment on lines
+150
to
+160
|
||
|
|
||
| def on_initialize(self, state: State, **kwargs) -> bool: | ||
| """ | ||
|
|
@@ -338,7 +338,7 @@ def on_finalize(self, state: State, **kwargs) -> bool: | |
|
|
||
| return True | ||
|
|
||
| def get_unquantized_layer_names(self, wrapped_model: torch.nn.Module) -> List[str]: | ||
| def get_unquantized_layer_names(self, wrapped_model: torch.nn.Module) -> list[str]: | ||
| unquantized_layers = [] | ||
|
|
||
| for name, module in wrapped_model.named_modules(): | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,6 +1,8 @@ | ||||||||||||
| import inspect | ||||||||||||
| from itertools import product | ||||||||||||
| from typing import Iterator, Literal | ||||||||||||
| from typing import Literal | ||||||||||||
|
|
||||||||||||
| from collections.abc import Iterator | ||||||||||||
|
|
||||||||||||
|
Comment on lines
+3
to
6
|
||||||||||||
| from typing import Literal | |
| from collections.abc import Iterator | |
| from collections.abc import Iterator | |
| from typing import Literal |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -118,17 +118,17 @@ class GPTQModifier(Modifier, QuantizationMixin): | |
| """ | ||
|
|
||
| # gptq modifier arguments | ||
| sequential_targets: Union[str, List[str], None] = None | ||
| sequential_targets: str | list[str] | None = None | ||
| block_size: int = 128 | ||
| dampening_frac: Optional[float] = 0.01 | ||
| dampening_frac: float | None = 0.01 | ||
| # TODO: this does not serialize / will be incorrectly written | ||
| actorder: Optional[Union[ActivationOrdering, Sentinel]] = Sentinel("static") | ||
| actorder: ActivationOrdering | Sentinel | None = Sentinel("static") | ||
| offload_hessians: bool = False | ||
|
Comment on lines
+121
to
126
|
||
|
|
||
| # private variables | ||
| _module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict) | ||
| _hessians: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr(default_factory=dict) | ||
| _num_samples: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr( | ||
| _module_names: dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict) | ||
| _hessians: dict[torch.nn.Module, torch.Tensor] = PrivateAttr(default_factory=dict) | ||
| _num_samples: dict[torch.nn.Module, torch.Tensor] = PrivateAttr( | ||
| default_factory=dict | ||
| ) | ||
|
|
||
|
|
@@ -235,7 +235,7 @@ def on_event(self, state: State, event: Event, **kwargs): | |
| def calibrate_module( | ||
| self, | ||
| module: torch.nn.Module, | ||
| args: Tuple[torch.Tensor, ...], | ||
| args: tuple[torch.Tensor, ...], | ||
| _output: torch.Tensor, | ||
| ): | ||
| """ | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -9,7 +9,9 @@ | |||||||||||
| import math | ||||||||||||
| import re | ||||||||||||
| from dataclasses import dataclass | ||||||||||||
| from typing import Any, Callable, Dict | ||||||||||||
| from typing import Any, Dict | ||||||||||||
|
|
||||||||||||
| from collections.abc import Callable | ||||||||||||
|
|
||||||||||||
|
Comment on lines
+12
to
15
|
||||||||||||
| from typing import Any, Dict | |
| from collections.abc import Callable | |
| from collections.abc import Callable | |
| from typing import Any, Dict |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,16 +22,16 @@ | |
|
|
||
|
|
||
| class MagnitudePruningModifier(Modifier, LayerParamMasking): | ||
| targets: Union[str, List[str]] | ||
| targets: str | list[str] | ||
| init_sparsity: float | ||
| final_sparsity: float | ||
| update_scheduler: str = "cubic" | ||
| scheduler_args: Dict[str, Any] = {} | ||
| scheduler_args: dict[str, Any] = {} | ||
| mask_structure: str = "unstructured" | ||
|
Comment on lines
24
to
30
|
||
| leave_enabled: bool = False | ||
| apply_globally: bool = False | ||
|
|
||
| parameterized_layers_: Dict[str, ModelParameterizedLayer] = None | ||
| parameterized_layers_: dict[str, ModelParameterizedLayer] = None | ||
| _save_masks: bool = False | ||
| _use_hooks: bool = False | ||
| scheduler_function_: SchedulerCalculationType = None | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`from typing import Optional` appears unused after converting annotations to `X | None` (it's only referenced in docstrings now). This will fail `ruff check` with F401; please remove the unused `Optional` import.