diff --git a/.actions/assistant.py b/.actions/assistant.py index dc1fa05..15a20e6 100644 --- a/.actions/assistant.py +++ b/.actions/assistant.py @@ -18,11 +18,10 @@ import shutil import tempfile import urllib.request -from collections.abc import Iterable, Iterator, Sequence from itertools import chain from os.path import dirname, isfile from pathlib import Path -from typing import Any, Optional +from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple from packaging.requirements import Requirement from packaging.version import Version @@ -128,7 +127,7 @@ def _parse_requirements(lines: Iterable[str]) -> Iterator[_RequirementWithCommen pip_argument = None -def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: str = "all") -> list[str]: +def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: str = "all") -> List[str]: """Loading requirements from a file. >>> path_req = os.path.join(_PROJECT_ROOT, "requirements") @@ -223,7 +222,7 @@ def _load_aggregate_requirements(req_dir: str = "requirements", freeze_requireme fp.writelines([ln + os.linesep for ln in requires] + [os.linesep]) -def _retrieve_files(directory: str, *ext: str) -> list[str]: +def _retrieve_files(directory: str, *ext: str) -> List[str]: all_files = [] for root, _, files in os.walk(directory): for fname in files: @@ -233,7 +232,7 @@ def _retrieve_files(directory: str, *ext: str) -> list[str]: return all_files -def _replace_imports(lines: list[str], mapping: list[tuple[str, str]], lightning_by: str = "") -> list[str]: +def _replace_imports(lines: List[str], mapping: List[Tuple[str, str]], lightning_by: str = "") -> List[str]: """Replace imports of standalone package to lightning. >>> lns = [ @@ -321,7 +320,7 @@ def copy_replace_imports( fo.writelines(lines) -def create_mirror_package(source_dir: str, package_mapping: dict[str, str]) -> None: +def create_mirror_package(source_dir: str, package_mapping: Dict[str, str]) -> None: """Create a mirror package with adjusted imports.""" # replace imports and copy the code mapping = package_mapping.copy() diff --git a/.github/workflows/_legacy-checkpoints.yml b/.github/workflows/_legacy-checkpoints.yml index b6af39d..15d226e 100644 --- a/.github/workflows/_legacy-checkpoints.yml +++ b/.github/workflows/_legacy-checkpoints.yml @@ -60,7 +60,7 @@ jobs: - uses: actions/setup-python@v5 with: # Python version here needs to be supported by all PL versions listed in back-compatible-versions.txt. - python-version: "3.9" + python-version: 3.8 - name: Install PL from source env: diff --git a/.github/workflows/call-clear-cache.yml b/.github/workflows/call-clear-cache.yml index 1dddbe8..4c18987 100644 --- a/.github/workflows/call-clear-cache.yml +++ b/.github/workflows/call-clear-cache.yml @@ -23,7 +23,7 @@ on: jobs: cron-clear: if: github.event_name == 'schedule' || github.event_name == 'pull_request' - uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.11.9 + uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.11.8 with: scripts-ref: v0.11.8 dry-run: ${{ github.event_name == 'pull_request' }} @@ -32,7 +32,7 @@ jobs: direct-clear: if: github.event_name == 'workflow_dispatch' || github.event_name == 'pull_request' - uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.11.9 + uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.11.8 with: scripts-ref: v0.11.8 dry-run: ${{ github.event_name == 'pull_request' }} diff --git a/.github/workflows/ci-check-md-links.yml b/.github/workflows/ci-check-md-links.yml index d0dc889..af5378c 100644 --- a/.github/workflows/ci-check-md-links.yml +++ b/.github/workflows/ci-check-md-links.yml @@ -14,7 +14,7 @@ on: jobs: check-md-links: - uses: Lightning-AI/utilities/.github/workflows/check-md-links.yml@v0.11.9 + uses: Lightning-AI/utilities/.github/workflows/check-md-links.yml@v0.11.8 with: config-file: ".github/markdown-links-config.json" base-branch: "master" diff --git a/.github/workflows/ci-schema.yml b/.github/workflows/ci-schema.yml index 32cd82f..2ccaadd 100644 --- a/.github/workflows/ci-schema.yml +++ b/.github/workflows/ci-schema.yml @@ -8,7 +8,7 @@ on: jobs: check: - uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.11.9 + uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.11.8 with: # skip azure due to the wrong schema file by MSFT # https://github.com/Lightning-AI/lightning-flash/pull/1455#issuecomment-1244793607 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c5e65de..24fc405 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -74,7 +74,7 @@ repos: hooks: # try to fix what is possible - id: ruff - args: ["--fix", "--unsafe-fixes"] + args: ["--fix"] # perform formatting updates - id: ruff-format # validate if all is fine with preview mode diff --git a/pyproject.toml b/pyproject.toml index 48439be..da4cd7f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ ignore-words-list = "te, compiletime" [tool.ruff] line-length = 120 -target-version = "py39" +target-version = "py38" # Exclude a variety of commonly ignored directories. exclude = [ ".git", diff --git a/setup.py b/setup.py index 92f0265..bfc329b 100755 --- a/setup.py +++ b/setup.py @@ -45,10 +45,9 @@ import logging import os import tempfile -from collections.abc import Generator, Mapping from importlib.util import module_from_spec, spec_from_file_location from types import ModuleType -from typing import Optional +from typing import Generator, Mapping, Optional import setuptools import setuptools.command.egg_info diff --git a/src/lightning/__setup__.py b/src/lightning/__setup__.py index 2d3bb0e..09eab56 100644 --- a/src/lightning/__setup__.py +++ b/src/lightning/__setup__.py @@ -3,7 +3,7 @@ from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path from types import ModuleType -from typing import Any +from typing import Any, Dict from setuptools import find_namespace_packages @@ -26,7 +26,7 @@ def _load_py_module(name: str, location: str) -> ModuleType: _ASSISTANT = _load_py_module(name="assistant", location=os.path.join(_PROJECT_ROOT, ".actions", "assistant.py")) -def _prepare_extras() -> dict[str, Any]: +def _prepare_extras() -> Dict[str, Any]: # https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-extras # Define package extras. These are only installed if you specify them. # From remote, use like `pip install "lightning[dev, docs]"` @@ -63,7 +63,7 @@ def _prepare_extras() -> dict[str, Any]: return extras -def _setup_args() -> dict[str, Any]: +def _setup_args() -> Dict[str, Any]: about = _load_py_module("about", os.path.join(_PACKAGE_ROOT, "__about__.py")) version = _load_py_module("version", os.path.join(_PACKAGE_ROOT, "__version__.py")) long_description = _ASSISTANT.load_readme_description( diff --git a/src/lightning/fabric/accelerators/cpu.py b/src/lightning/fabric/accelerators/cpu.py index 2997d1a..0334210 100644 --- a/src/lightning/fabric/accelerators/cpu.py +++ b/src/lightning/fabric/accelerators/cpu.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +from typing import List, Union import torch from typing_extensions import override @@ -45,7 +45,7 @@ def parse_devices(devices: Union[int, str]) -> int: @staticmethod @override - def get_parallel_devices(devices: Union[int, str]) -> list[torch.device]: + def get_parallel_devices(devices: Union[int, str]) -> List[torch.device]: """Gets parallel devices for the Accelerator.""" devices = _parse_cpu_cores(devices) return [torch.device("cpu")] * devices diff --git a/src/lightning/fabric/accelerators/cuda.py b/src/lightning/fabric/accelerators/cuda.py index 5b8a4c2..4afc9be 100644 --- a/src/lightning/fabric/accelerators/cuda.py +++ b/src/lightning/fabric/accelerators/cuda.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import lru_cache -from typing import Optional, Union +from typing import List, Optional, Union import torch from typing_extensions import override @@ -43,7 +43,7 @@ def teardown(self) -> None: @staticmethod @override - def parse_devices(devices: Union[int, str, list[int]]) -> Optional[list[int]]: + def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]: """Accelerator device parsing logic.""" from lightning.fabric.utilities.device_parser import _parse_gpu_ids @@ -51,7 +51,7 @@ def parse_devices(devices: Union[int, str, list[int]]) -> Optional[list[int]]: @staticmethod @override - def get_parallel_devices(devices: list[int]) -> list[torch.device]: + def get_parallel_devices(devices: List[int]) -> List[torch.device]: """Gets parallel devices for the Accelerator.""" return [torch.device("cuda", i) for i in devices] @@ -76,7 +76,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No ) -def find_usable_cuda_devices(num_devices: int = -1) -> list[int]: +def find_usable_cuda_devices(num_devices: int = -1) -> List[int]: """Returns a list of all available and usable CUDA GPU devices. A GPU is considered usable if we can successfully move a tensor to the device, and this is what this function @@ -129,7 +129,7 @@ def find_usable_cuda_devices(num_devices: int = -1) -> list[int]: return available_devices -def _get_all_visible_cuda_devices() -> list[int]: +def _get_all_visible_cuda_devices() -> List[int]: """Returns a list of all visible CUDA GPU devices. Devices masked by the environment variabale ``CUDA_VISIBLE_DEVICES`` won't be returned here. For example, assume you diff --git a/src/lightning/fabric/accelerators/mps.py b/src/lightning/fabric/accelerators/mps.py index b535ba5..7549716 100644 --- a/src/lightning/fabric/accelerators/mps.py +++ b/src/lightning/fabric/accelerators/mps.py @@ -14,7 +14,7 @@ import os import platform from functools import lru_cache -from typing import Optional, Union +from typing import List, Optional, Union import torch from typing_extensions import override @@ -46,7 +46,7 @@ def teardown(self) -> None: @staticmethod @override - def parse_devices(devices: Union[int, str, list[int]]) -> Optional[list[int]]: + def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]: """Accelerator device parsing logic.""" from lightning.fabric.utilities.device_parser import _parse_gpu_ids @@ -54,7 +54,7 @@ def parse_devices(devices: Union[int, str, list[int]]) -> Optional[list[int]]: @staticmethod @override - def get_parallel_devices(devices: Union[int, str, list[int]]) -> list[torch.device]: + def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]: """Gets parallel devices for the Accelerator.""" parsed_devices = MPSAccelerator.parse_devices(devices) assert parsed_devices is not None @@ -84,7 +84,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No ) -def _get_all_available_mps_gpus() -> list[int]: +def _get_all_available_mps_gpus() -> List[int]: """ Returns: A list of all available MPS GPUs diff --git a/src/lightning/fabric/accelerators/registry.py b/src/lightning/fabric/accelerators/registry.py index 17d5233..1299b1e 100644 --- a/src/lightning/fabric/accelerators/registry.py +++ b/src/lightning/fabric/accelerators/registry.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional +from typing import Any, Callable, Dict, List, Optional from typing_extensions import override @@ -68,7 +68,7 @@ def register( if name in self and not override: raise MisconfigurationException(f"'{name}' is already present in the registry. HINT: Use `override=True`.") - data: dict[str, Any] = {} + data: Dict[str, Any] = {} data["description"] = description data["init_params"] = init_params @@ -107,7 +107,7 @@ def remove(self, name: str) -> None: """Removes the registered accelerator by name.""" self.pop(name) - def available_accelerators(self) -> list[str]: + def available_accelerators(self) -> List[str]: """Returns a list of registered accelerators.""" return list(self.keys()) diff --git a/src/lightning/fabric/accelerators/xla.py b/src/lightning/fabric/accelerators/xla.py index d438197..38d7380 100644 --- a/src/lightning/fabric/accelerators/xla.py +++ b/src/lightning/fabric/accelerators/xla.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools -from typing import Any, Union +from typing import Any, List, Union import torch from lightning_utilities.core.imports import RequirementCache @@ -47,13 +47,13 @@ def teardown(self) -> None: @staticmethod @override - def parse_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]: + def parse_devices(devices: Union[int, str, List[int]]) -> Union[int, List[int]]: """Accelerator device parsing logic.""" return _parse_tpu_devices(devices) @staticmethod @override - def get_parallel_devices(devices: Union[int, list[int]]) -> list[torch.device]: + def get_parallel_devices(devices: Union[int, List[int]]) -> List[torch.device]: """Gets parallel devices for the Accelerator.""" devices = _parse_tpu_devices(devices) if isinstance(devices, int): @@ -102,27 +102,20 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No # PJRT support requires this minimum version _XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla") _XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1") -_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5") def _using_pjrt() -> bool: - # `using_pjrt` is removed in torch_xla 2.5 - if _XLA_GREATER_EQUAL_2_5: - from torch_xla import runtime as xr - - return xr.device_type() is not None # delete me when torch_xla 2.2 is the min supported version, where XRT support has been dropped. if _XLA_GREATER_EQUAL_2_1: from torch_xla import runtime as xr return xr.using_pjrt() - from torch_xla.experimental import pjrt return pjrt.using_pjrt() -def _parse_tpu_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]: +def _parse_tpu_devices(devices: Union[int, str, List[int]]) -> Union[int, List[int]]: """Parses the TPU devices given in the format as accepted by the :class:`~lightning.pytorch.trainer.trainer.Trainer` and :class:`~lightning.fabric.Fabric`. @@ -159,7 +152,7 @@ def _check_tpu_devices_valid(devices: object) -> None: ) -def _parse_tpu_devices_str(devices: str) -> Union[int, list[int]]: +def _parse_tpu_devices_str(devices: str) -> Union[int, List[int]]: devices = devices.strip() try: return int(devices) diff --git a/src/lightning/fabric/cli.py b/src/lightning/fabric/cli.py index 5f18884..7c81afa 100644 --- a/src/lightning/fabric/cli.py +++ b/src/lightning/fabric/cli.py @@ -17,7 +17,7 @@ import subprocess import sys from argparse import Namespace -from typing import Any, Optional +from typing import Any, List, Optional import torch from lightning_utilities.core.imports import RequirementCache @@ -39,7 +39,7 @@ _SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu") -def _get_supported_strategies() -> list[str]: +def _get_supported_strategies() -> List[str]: """Returns strategy choices from the registry, with the ones removed that are incompatible to be launched from the CLI or ones that require further configuration by the user.""" available_strategies = STRATEGY_REGISTRY.available_strategies() @@ -221,7 +221,7 @@ def _get_num_processes(accelerator: str, devices: str) -> int: return len(parsed_devices) if parsed_devices is not None else 0 -def _torchrun_launch(args: Namespace, script_args: list[str]) -> None: +def _torchrun_launch(args: Namespace, script_args: List[str]) -> None: """This will invoke `torchrun` programmatically to launch the given script in new processes.""" import torch.distributed.run as torchrun @@ -242,7 +242,7 @@ def _torchrun_launch(args: Namespace, script_args: list[str]) -> None: torchrun.main(torchrun_args) -def main(args: Namespace, script_args: Optional[list[str]] = None) -> None: +def main(args: Namespace, script_args: Optional[List[str]] = None) -> None: _set_env_variables(args) _torchrun_launch(args, script_args or []) diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py index 9161d5f..9fb6625 100644 --- a/src/lightning/fabric/connector.py +++ b/src/lightning/fabric/connector.py @@ -13,7 +13,7 @@ # limitations under the License. import os from collections import Counter -from typing import Any, Optional, Union, cast +from typing import Any, Dict, List, Optional, Union, cast import torch from typing_extensions import get_args @@ -99,10 +99,10 @@ def __init__( self, accelerator: Union[str, Accelerator] = "auto", strategy: Union[str, Strategy] = "auto", - devices: Union[list[int], str, int] = "auto", + devices: Union[List[int], str, int] = "auto", num_nodes: int = 1, precision: Optional[_PRECISION_INPUT] = None, - plugins: Optional[Union[_PLUGIN_INPUT, list[_PLUGIN_INPUT]]] = None, + plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None, ) -> None: # These arguments can be set through environment variables set by the CLI accelerator = self._argument_from_env("accelerator", accelerator, default="auto") @@ -124,7 +124,7 @@ def __init__( self._precision_input: _PRECISION_INPUT_STR = "32-true" self._precision_instance: Optional[Precision] = None self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None - self._parallel_devices: list[Union[int, torch.device, str]] = [] + self._parallel_devices: List[Union[int, torch.device, str]] = [] self.checkpoint_io: Optional[CheckpointIO] = None self._check_config_and_set_final_flags( @@ -165,7 +165,7 @@ def _check_config_and_set_final_flags( strategy: Union[str, Strategy], accelerator: Union[str, Accelerator], precision: Optional[_PRECISION_INPUT], - plugins: Optional[Union[_PLUGIN_INPUT, list[_PLUGIN_INPUT]]], + plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]], ) -> None: """This method checks: @@ -224,7 +224,7 @@ def _check_config_and_set_final_flags( precision_input = _convert_precision_to_unified_args(precision) if plugins: - plugins_flags_types: dict[str, int] = Counter() + plugins_flags_types: Dict[str, int] = Counter() for plugin in plugins: if isinstance(plugin, Precision): self._precision_instance = plugin @@ -295,7 +295,7 @@ def _check_config_and_set_final_flags( self._accelerator_flag = "cuda" self._parallel_devices = self._strategy_flag.parallel_devices - def _check_device_config_and_set_final_flags(self, devices: Union[list[int], str, int], num_nodes: int) -> None: + def _check_device_config_and_set_final_flags(self, devices: Union[List[int], str, int], num_nodes: int) -> None: if not isinstance(num_nodes, int) or num_nodes < 1: raise ValueError(f"`num_nodes` must be a positive integer, but got {num_nodes}.") diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py index 058e5e7..0ff5b04 100644 --- a/src/lightning/fabric/fabric.py +++ b/src/lightning/fabric/fabric.py @@ -13,14 +13,20 @@ # limitations under the License. import inspect import os -from collections.abc import Generator, Mapping, Sequence -from contextlib import AbstractContextManager, contextmanager, nullcontext +from contextlib import contextmanager, nullcontext from functools import partial from pathlib import Path from typing import ( Any, Callable, + ContextManager, + Dict, + Generator, + List, + Mapping, Optional, + Sequence, + Tuple, Union, cast, overload, @@ -112,12 +118,12 @@ def __init__( *, accelerator: Union[str, Accelerator] = "auto", strategy: Union[str, Strategy] = "auto", - devices: Union[list[int], str, int] = "auto", + devices: Union[List[int], str, int] = "auto", num_nodes: int = 1, precision: Optional[_PRECISION_INPUT] = None, - plugins: Optional[Union[_PLUGIN_INPUT, list[_PLUGIN_INPUT]]] = None, - callbacks: Optional[Union[list[Any], Any]] = None, - loggers: Optional[Union[Logger, list[Logger]]] = None, + plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None, + callbacks: Optional[Union[List[Any], Any]] = None, + loggers: Optional[Union[Logger, List[Logger]]] = None, ) -> None: self._connector = _Connector( accelerator=accelerator, @@ -186,7 +192,7 @@ def is_global_zero(self) -> bool: return self._strategy.is_global_zero @property - def loggers(self) -> list[Logger]: + def loggers(self) -> List[Logger]: """Returns all loggers passed to Fabric.""" return self._loggers @@ -320,7 +326,7 @@ def setup_module( self._models_setup += 1 return module - def setup_optimizers(self, *optimizers: Optimizer) -> Union[_FabricOptimizer, tuple[_FabricOptimizer, ...]]: + def setup_optimizers(self, *optimizers: Optimizer) -> Union[_FabricOptimizer, Tuple[_FabricOptimizer, ...]]: r"""Set up one or more optimizers for accelerated training. Some strategies do not allow setting up model and optimizer independently. For them, you should call @@ -343,7 +349,7 @@ def setup_optimizers(self, *optimizers: Optimizer) -> Union[_FabricOptimizer, tu def setup_dataloaders( self, *dataloaders: DataLoader, use_distributed_sampler: bool = True, move_to_device: bool = True - ) -> Union[DataLoader, list[DataLoader]]: + ) -> Union[DataLoader, List[DataLoader]]: r"""Set up one or multiple dataloaders for accelerated training. If you need different settings for each dataloader, call this method individually for each one. @@ -483,7 +489,7 @@ def clip_gradients( ) raise ValueError("You have to specify either `clip_val` or `max_norm` to do gradient clipping!") - def autocast(self) -> AbstractContextManager: + def autocast(self) -> ContextManager: """A context manager to automatically convert operations for the chosen precision. Use this only if the `forward` method of your model does not cover all operations you wish to run with the @@ -558,8 +564,8 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: return self._strategy.broadcast(obj, src=src) def all_gather( - self, data: Union[Tensor, dict, list, tuple], group: Optional[Any] = None, sync_grads: bool = False - ) -> Union[Tensor, dict, list, tuple]: + self, data: Union[Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False + ) -> Union[Tensor, Dict, List, Tuple]: """Gather tensors or collections of tensors from multiple processes. This method needs to be called on all processes and the tensors need to have the same shape across all @@ -583,10 +589,10 @@ def all_gather( def all_reduce( self, - data: Union[Tensor, dict, list, tuple], + data: Union[Tensor, Dict, List, Tuple], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean", - ) -> Union[Tensor, dict, list, tuple]: + ) -> Union[Tensor, Dict, List, Tuple]: """Reduce tensors or collections of tensors from multiple processes. The reduction on tensors is applied in-place, meaning the result will be placed back into the input tensor. @@ -633,7 +639,7 @@ def rank_zero_first(self, local: bool = False) -> Generator: if rank == 0: barrier() - def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> AbstractContextManager: + def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> ContextManager: r"""Skip gradient synchronization during backward to avoid redundant communication overhead. Use this context manager when performing gradient accumulation to speed up training with multiple devices. @@ -675,7 +681,7 @@ def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Abstr forward_module, _ = _unwrap_compiled(module._forward_module) return self._strategy._backward_sync_control.no_backward_sync(forward_module, enabled) - def sharded_model(self) -> AbstractContextManager: + def sharded_model(self) -> ContextManager: r"""Instantiate a model under this context manager to prepare it for model-parallel sharding. .. deprecated:: This context manager is deprecated in favor of :meth:`init_module`, use it instead. @@ -687,12 +693,12 @@ def sharded_model(self) -> AbstractContextManager: return self.strategy.module_sharded_context() return nullcontext() - def init_tensor(self) -> AbstractContextManager: + def init_tensor(self) -> ContextManager: """Tensors that you instantiate under this context manager will be created on the device right away and have the right data type depending on the precision setting in Fabric.""" return self._strategy.tensor_init_context() - def init_module(self, empty_init: Optional[bool] = None) -> AbstractContextManager: + def init_module(self, empty_init: Optional[bool] = None) -> ContextManager: """Instantiate the model and its parameters under this context manager to reduce peak memory usage. The parameters get created on the device and with the right data type right away without wasting memory being @@ -710,8 +716,8 @@ def init_module(self, empty_init: Optional[bool] = None) -> AbstractContextManag def save( self, path: Union[str, Path], - state: dict[str, Union[nn.Module, Optimizer, Any]], - filter: Optional[dict[str, Callable[[str, Any], bool]]] = None, + state: Dict[str, Union[nn.Module, Optimizer, Any]], + filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None, ) -> None: r"""Save checkpoint contents to a file. @@ -744,9 +750,9 @@ def save( def load( self, path: Union[str, Path], - state: Optional[dict[str, Union[nn.Module, Optimizer, Any]]] = None, + state: Optional[Dict[str, Union[nn.Module, Optimizer, Any]]] = None, strict: bool = True, - ) -> dict[str, Any]: + ) -> Dict[str, Any]: """Load a checkpoint from a file and restore the state of objects (modules, optimizers, etc.) How and which processes load gets determined by the `strategy`. @@ -927,7 +933,7 @@ def _wrap_with_setup(self, to_run: Callable, *args: Any, **kwargs: Any) -> Any: with _replace_dunder_methods(DataLoader, "dataset"), _replace_dunder_methods(BatchSampler): return to_run(*args, **kwargs) - def _move_model_to_device(self, model: nn.Module, optimizers: list[Optimizer]) -> nn.Module: + def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module: try: initial_name, initial_param = next(model.named_parameters()) except StopIteration: @@ -1055,7 +1061,7 @@ def _validate_setup_dataloaders(self, dataloaders: Sequence[DataLoader]) -> None raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.") @staticmethod - def _configure_callbacks(callbacks: Optional[Union[list[Any], Any]]) -> list[Any]: + def _configure_callbacks(callbacks: Optional[Union[List[Any], Any]]) -> List[Any]: callbacks = callbacks if callbacks is not None else [] callbacks = callbacks if isinstance(callbacks, list) else [callbacks] callbacks.extend(_load_external_callbacks("lightning.fabric.callbacks_factory")) diff --git a/src/lightning/fabric/loggers/csv_logs.py b/src/lightning/fabric/loggers/csv_logs.py index dd7dfc6..4dbb56f 100644 --- a/src/lightning/fabric/loggers/csv_logs.py +++ b/src/lightning/fabric/loggers/csv_logs.py @@ -16,7 +16,7 @@ import logging import os from argparse import Namespace -from typing import Any, Optional, Union +from typing import Any, Dict, List, Optional, Set, Union from torch import Tensor from typing_extensions import override @@ -138,13 +138,13 @@ def experiment(self) -> "_ExperimentWriter": @override @rank_zero_only - def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: + def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: raise NotImplementedError("The `CSVLogger` does not yet support logging hyperparameters.") @override @rank_zero_only def log_metrics( # type: ignore[override] - self, metrics: dict[str, Union[Tensor, float]], step: Optional[int] = None + self, metrics: Dict[str, Union[Tensor, float]], step: Optional[int] = None ) -> None: metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) if step is None: @@ -200,8 +200,8 @@ class _ExperimentWriter: NAME_METRICS_FILE = "metrics.csv" def __init__(self, log_dir: str) -> None: - self.metrics: list[dict[str, float]] = [] - self.metrics_keys: list[str] = [] + self.metrics: List[Dict[str, float]] = [] + self.metrics_keys: List[str] = [] self._fs = get_filesystem(log_dir) self.log_dir = log_dir @@ -210,7 +210,7 @@ def __init__(self, log_dir: str) -> None: self._check_log_dir_exists() self._fs.makedirs(self.log_dir, exist_ok=True) - def log_metrics(self, metrics_dict: dict[str, float], step: Optional[int] = None) -> None: + def log_metrics(self, metrics_dict: Dict[str, float], step: Optional[int] = None) -> None: """Record metrics.""" def _handle_value(value: Union[Tensor, Any]) -> Any: @@ -246,7 +246,7 @@ def save(self) -> None: self.metrics = [] # reset - def _record_new_keys(self) -> set[str]: + def _record_new_keys(self) -> Set[str]: """Records new keys that have not been logged before.""" current_keys = set().union(*self.metrics) new_keys = current_keys - set(self.metrics_keys) @@ -254,7 +254,7 @@ def _record_new_keys(self) -> set[str]: self.metrics_keys.sort() return new_keys - def _rewrite_with_new_header(self, fieldnames: list[str]) -> None: + def _rewrite_with_new_header(self, fieldnames: List[str]) -> None: with self._fs.open(self.metrics_file_path, "r", newline="") as file: metrics = list(csv.DictReader(file)) diff --git a/src/lightning/fabric/loggers/logger.py b/src/lightning/fabric/loggers/logger.py index 39a9fa0..5647ab9 100644 --- a/src/lightning/fabric/loggers/logger.py +++ b/src/lightning/fabric/loggers/logger.py @@ -16,7 +16,7 @@ from abc import ABC, abstractmethod from argparse import Namespace from functools import wraps -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Dict, Optional, Union from torch import Tensor from torch.nn import Module @@ -55,7 +55,7 @@ def group_separator(self) -> str: return "/" @abstractmethod - def log_metrics(self, metrics: dict[str, float], step: Optional[int] = None) -> None: + def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: """Records metrics. This method logs metrics as soon as it received them. Args: @@ -66,7 +66,7 @@ def log_metrics(self, metrics: dict[str, float], step: Optional[int] = None) -> pass @abstractmethod - def log_hyperparams(self, params: Union[dict[str, Any], Namespace], *args: Any, **kwargs: Any) -> None: + def log_hyperparams(self, params: Union[Dict[str, Any], Namespace], *args: Any, **kwargs: Any) -> None: """Record hyperparameters. Args: diff --git a/src/lightning/fabric/loggers/tensorboard.py b/src/lightning/fabric/loggers/tensorboard.py index 14bc3d6..685c832 100644 --- a/src/lightning/fabric/loggers/tensorboard.py +++ b/src/lightning/fabric/loggers/tensorboard.py @@ -14,8 +14,7 @@ import os from argparse import Namespace -from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union from lightning_utilities.core.imports import RequirementCache from torch import Tensor @@ -220,7 +219,7 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) @override @rank_zero_only def log_hyperparams( - self, params: Union[dict[str, Any], Namespace], metrics: Optional[dict[str, Any]] = None + self, params: Union[Dict[str, Any], Namespace], metrics: Optional[Dict[str, Any]] = None ) -> None: """Record hyperparameters. TensorBoard logs with and without saved hyperparameters are incompatible, the hyperparameters are then not displayed in the TensorBoard. Please delete or move the previously saved logs to @@ -319,12 +318,12 @@ def _get_next_version(self) -> int: return max(existing_versions) + 1 @staticmethod - def _sanitize_params(params: dict[str, Any]) -> dict[str, Any]: + def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]: params = _utils_sanitize_params(params) # logging of arrays with dimension > 1 is not supported, sanitize as string return {k: str(v) if hasattr(v, "ndim") and v.ndim > 1 else v for k, v in params.items()} - def __getstate__(self) -> dict[str, Any]: + def __getstate__(self) -> Dict[str, Any]: state = self.__dict__.copy() state["_experiment"] = None return state diff --git a/src/lightning/fabric/plugins/collectives/collective.py b/src/lightning/fabric/plugins/collectives/collective.py index 9408fd8..3b336b5 100644 --- a/src/lightning/fabric/plugins/collectives/collective.py +++ b/src/lightning/fabric/plugins/collectives/collective.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any, List, Optional from torch import Tensor from typing_extensions import Self @@ -47,19 +47,19 @@ def all_reduce(self, tensor: Tensor, op: str) -> Tensor: ... def reduce(self, tensor: Tensor, dst: int, op: str) -> Tensor: ... @abstractmethod - def all_gather(self, tensor_list: list[Tensor], tensor: Tensor) -> list[Tensor]: ... + def all_gather(self, tensor_list: List[Tensor], tensor: Tensor) -> List[Tensor]: ... @abstractmethod - def gather(self, tensor: Tensor, gather_list: list[Tensor], dst: int = 0) -> list[Tensor]: ... + def gather(self, tensor: Tensor, gather_list: List[Tensor], dst: int = 0) -> List[Tensor]: ... @abstractmethod - def scatter(self, tensor: Tensor, scatter_list: list[Tensor], src: int = 0) -> Tensor: ... + def scatter(self, tensor: Tensor, scatter_list: List[Tensor], src: int = 0) -> Tensor: ... @abstractmethod - def reduce_scatter(self, output: Tensor, input_list: list[Tensor], op: str) -> Tensor: ... + def reduce_scatter(self, output: Tensor, input_list: List[Tensor], op: str) -> Tensor: ... @abstractmethod - def all_to_all(self, output_tensor_list: list[Tensor], input_tensor_list: list[Tensor]) -> list[Tensor]: ... + def all_to_all(self, output_tensor_list: List[Tensor], input_tensor_list: List[Tensor]) -> List[Tensor]: ... @abstractmethod def send(self, tensor: Tensor, dst: int, tag: int = 0) -> None: ... @@ -68,7 +68,7 @@ def send(self, tensor: Tensor, dst: int, tag: int = 0) -> None: ... def recv(self, tensor: Tensor, src: Optional[int] = None, tag: int = 0) -> Tensor: ... @abstractmethod - def barrier(self, device_ids: Optional[list[int]] = None) -> None: ... + def barrier(self, device_ids: Optional[List[int]] = None) -> None: ... @classmethod @abstractmethod diff --git a/src/lightning/fabric/plugins/collectives/single_device.py b/src/lightning/fabric/plugins/collectives/single_device.py index 7337871..9b635f6 100644 --- a/src/lightning/fabric/plugins/collectives/single_device.py +++ b/src/lightning/fabric/plugins/collectives/single_device.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, List from torch import Tensor from typing_extensions import override @@ -37,31 +37,31 @@ def reduce(self, tensor: Tensor, *_: Any, **__: Any) -> Tensor: return tensor @override - def all_gather(self, tensor_list: list[Tensor], tensor: Tensor, **__: Any) -> list[Tensor]: + def all_gather(self, tensor_list: List[Tensor], tensor: Tensor, **__: Any) -> List[Tensor]: return [tensor] @override - def gather(self, tensor: Tensor, *_: Any, **__: Any) -> list[Tensor]: + def gather(self, tensor: Tensor, *_: Any, **__: Any) -> List[Tensor]: return [tensor] @override def scatter( self, tensor: Tensor, - scatter_list: list[Tensor], + scatter_list: List[Tensor], *_: Any, **__: Any, ) -> Tensor: return scatter_list[0] @override - def reduce_scatter(self, output: Tensor, input_list: list[Tensor], *_: Any, **__: Any) -> Tensor: + def reduce_scatter(self, output: Tensor, input_list: List[Tensor], *_: Any, **__: Any) -> Tensor: return input_list[0] @override def all_to_all( - self, output_tensor_list: list[Tensor], input_tensor_list: list[Tensor], *_: Any, **__: Any - ) -> list[Tensor]: + self, output_tensor_list: List[Tensor], input_tensor_list: List[Tensor], *_: Any, **__: Any + ) -> List[Tensor]: return input_tensor_list @override diff --git a/src/lightning/fabric/plugins/collectives/torch_collective.py b/src/lightning/fabric/plugins/collectives/torch_collective.py index 81e15a3..0dea303 100644 --- a/src/lightning/fabric/plugins/collectives/torch_collective.py +++ b/src/lightning/fabric/plugins/collectives/torch_collective.py @@ -1,6 +1,6 @@ import datetime import os -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union import torch import torch.distributed as dist @@ -66,30 +66,30 @@ def reduce(self, tensor: Tensor, dst: int, op: Union[str, ReduceOp, RedOpType] = return tensor @override - def all_gather(self, tensor_list: list[Tensor], tensor: Tensor) -> list[Tensor]: + def all_gather(self, tensor_list: List[Tensor], tensor: Tensor) -> List[Tensor]: dist.all_gather(tensor_list, tensor, group=self.group) return tensor_list @override - def gather(self, tensor: Tensor, gather_list: list[Tensor], dst: int = 0) -> list[Tensor]: + def gather(self, tensor: Tensor, gather_list: List[Tensor], dst: int = 0) -> List[Tensor]: dist.gather(tensor, gather_list, dst, group=self.group) return gather_list @override - def scatter(self, tensor: Tensor, scatter_list: list[Tensor], src: int = 0) -> Tensor: + def scatter(self, tensor: Tensor, scatter_list: List[Tensor], src: int = 0) -> Tensor: dist.scatter(tensor, scatter_list, src, group=self.group) return tensor @override def reduce_scatter( - self, output: Tensor, input_list: list[Tensor], op: Union[str, ReduceOp, RedOpType] = "sum" + self, output: Tensor, input_list: List[Tensor], op: Union[str, ReduceOp, RedOpType] = "sum" ) -> Tensor: op = self._convert_to_native_op(op) dist.reduce_scatter(output, input_list, op=op, group=self.group) return output @override - def all_to_all(self, output_tensor_list: list[Tensor], input_tensor_list: list[Tensor]) -> list[Tensor]: + def all_to_all(self, output_tensor_list: List[Tensor], input_tensor_list: List[Tensor]) -> List[Tensor]: dist.all_to_all(output_tensor_list, input_tensor_list, group=self.group) return output_tensor_list @@ -102,28 +102,28 @@ def recv(self, tensor: Tensor, src: Optional[int] = None, tag: int = 0) -> Tenso dist.recv(tensor, src, tag=tag, group=self.group) # type: ignore[arg-type] return tensor - def all_gather_object(self, object_list: list[Any], obj: Any) -> list[Any]: + def all_gather_object(self, object_list: List[Any], obj: Any) -> List[Any]: dist.all_gather_object(object_list, obj, group=self.group) return object_list def broadcast_object_list( - self, object_list: list[Any], src: int, device: Optional[torch.device] = None - ) -> list[Any]: + self, object_list: List[Any], src: int, device: Optional[torch.device] = None + ) -> List[Any]: dist.broadcast_object_list(object_list, src, group=self.group, device=device) return object_list - def gather_object(self, obj: Any, object_gather_list: list[Any], dst: int = 0) -> list[Any]: + def gather_object(self, obj: Any, object_gather_list: List[Any], dst: int = 0) -> List[Any]: dist.gather_object(obj, object_gather_list, dst, group=self.group) return object_gather_list def scatter_object_list( - self, scatter_object_output_list: list[Any], scatter_object_input_list: list[Any], src: int = 0 - ) -> list[Any]: + self, scatter_object_output_list: List[Any], scatter_object_input_list: List[Any], src: int = 0 + ) -> List[Any]: dist.scatter_object_list(scatter_object_output_list, scatter_object_input_list, src, group=self.group) return scatter_object_output_list @override - def barrier(self, device_ids: Optional[list[int]] = None) -> None: + def barrier(self, device_ids: Optional[List[int]] = None) -> None: if self.group == dist.GroupMember.NON_GROUP_MEMBER: return dist.barrier(group=self.group, device_ids=device_ids) diff --git a/src/lightning/fabric/plugins/environments/lsf.py b/src/lightning/fabric/plugins/environments/lsf.py index f0a07d6..6a23006 100644 --- a/src/lightning/fabric/plugins/environments/lsf.py +++ b/src/lightning/fabric/plugins/environments/lsf.py @@ -14,6 +14,7 @@ import logging import os import socket +from typing import Dict, List from typing_extensions import override @@ -143,14 +144,14 @@ def _get_node_rank(self) -> int: """ hosts = self._read_hosts() - count: dict[str, int] = {} + count: Dict[str, int] = {} for host in hosts: if host not in count: count[host] = len(count) return count[socket.gethostname()] @staticmethod - def _read_hosts() -> list[str]: + def _read_hosts() -> List[str]: """Read compute hosts that are a part of the compute job. LSF uses the Job Step Manager (JSM) to manage job steps. Job steps are executed by the JSM from "launch" nodes. diff --git a/src/lightning/fabric/plugins/io/checkpoint_io.py b/src/lightning/fabric/plugins/io/checkpoint_io.py index 3a33dac..79fc9e8 100644 --- a/src/lightning/fabric/plugins/io/checkpoint_io.py +++ b/src/lightning/fabric/plugins/io/checkpoint_io.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any, Dict, Optional from lightning.fabric.utilities.types import _PATH @@ -36,7 +36,7 @@ class CheckpointIO(ABC): """ @abstractmethod - def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: + def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. Args: @@ -47,7 +47,7 @@ def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_optio """ @abstractmethod - def load_checkpoint(self, path: _PATH, map_location: Optional[Any] = None) -> dict[str, Any]: + def load_checkpoint(self, path: _PATH, map_location: Optional[Any] = None) -> Dict[str, Any]: """Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages. Args: diff --git a/src/lightning/fabric/plugins/io/torch_io.py b/src/lightning/fabric/plugins/io/torch_io.py index 90a5f62..02de1aa 100644 --- a/src/lightning/fabric/plugins/io/torch_io.py +++ b/src/lightning/fabric/plugins/io/torch_io.py @@ -13,7 +13,7 @@ # limitations under the License. import logging import os -from typing import Any, Callable, Optional +from typing import Any, Callable, Dict, Optional from typing_extensions import override @@ -34,7 +34,7 @@ class TorchCheckpointIO(CheckpointIO): """ @override - def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: + def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. Args: @@ -60,7 +60,7 @@ def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_optio @override def load_checkpoint( self, path: _PATH, map_location: Optional[Callable] = lambda storage, loc: storage - ) -> dict[str, Any]: + ) -> Dict[str, Any]: """Loads checkpoint using :func:`torch.load`, with additional handling for ``fsspec`` remote loading of files. Args: diff --git a/src/lightning/fabric/plugins/io/xla.py b/src/lightning/fabric/plugins/io/xla.py index 146fa2f..5c154d8 100644 --- a/src/lightning/fabric/plugins/io/xla.py +++ b/src/lightning/fabric/plugins/io/xla.py @@ -13,7 +13,7 @@ # limitations under the License. import logging import os -from typing import Any, Optional +from typing import Any, Dict, Optional import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -41,7 +41,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @override - def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: + def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. Args: diff --git a/src/lightning/fabric/plugins/precision/amp.py b/src/lightning/fabric/plugins/precision/amp.py index d5fc1f0..c624e82 100644 --- a/src/lightning/fabric/plugins/precision/amp.py +++ b/src/lightning/fabric/plugins/precision/amp.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import AbstractContextManager -from typing import Any, Literal, Optional +from typing import Any, ContextManager, Dict, Literal, Optional import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -60,7 +59,7 @@ def __init__( self._desired_input_dtype = torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16 @override - def forward_context(self) -> AbstractContextManager: + def forward_context(self) -> ContextManager: return torch.autocast(self.device, dtype=self._desired_input_dtype) @override @@ -94,13 +93,13 @@ def optimizer_step( return step_output @override - def state_dict(self) -> dict[str, Any]: + def state_dict(self) -> Dict[str, Any]: if self.scaler is not None: return self.scaler.state_dict() return {} @override - def load_state_dict(self, state_dict: dict[str, Any]) -> None: + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: if self.scaler is not None: self.scaler.load_state_dict(state_dict) diff --git a/src/lightning/fabric/plugins/precision/bitsandbytes.py b/src/lightning/fabric/plugins/precision/bitsandbytes.py index ecb1d8a..3944154 100644 --- a/src/lightning/fabric/plugins/precision/bitsandbytes.py +++ b/src/lightning/fabric/plugins/precision/bitsandbytes.py @@ -16,11 +16,10 @@ import math import os import warnings -from collections import OrderedDict -from contextlib import AbstractContextManager, ExitStack +from contextlib import ExitStack from functools import partial from types import ModuleType -from typing import Any, Callable, Literal, Optional, cast +from typing import Any, Callable, ContextManager, Literal, Optional, OrderedDict, Set, Tuple, Type, cast import torch from lightning_utilities import apply_to_collection @@ -71,7 +70,7 @@ def __init__( self, mode: Literal["nf4", "nf4-dq", "fp4", "fp4-dq", "int8", "int8-training"], dtype: Optional[torch.dtype] = None, - ignore_modules: Optional[set[str]] = None, + ignore_modules: Optional[Set[str]] = None, ) -> None: _import_bitsandbytes() @@ -123,11 +122,11 @@ def convert_module(self, module: torch.nn.Module) -> torch.nn.Module: return module @override - def tensor_init_context(self) -> AbstractContextManager: + def tensor_init_context(self) -> ContextManager: return _DtypeContextManager(self.dtype) @override - def module_init_context(self) -> AbstractContextManager: + def module_init_context(self) -> ContextManager: if self.ignore_modules: # cannot patch the Linear class if the user wants to skip some submodules raise RuntimeError( @@ -145,7 +144,7 @@ def module_init_context(self) -> AbstractContextManager: return stack @override - def forward_context(self) -> AbstractContextManager: + def forward_context(self) -> ContextManager: return _DtypeContextManager(self.dtype) @override @@ -176,7 +175,7 @@ def _ignore_missing_weights_hook(module: torch.nn.Module, incompatible_keys: _In def _replace_param( - param: torch.nn.Parameter, data: torch.Tensor, quant_state: Optional[tuple] = None + param: torch.nn.Parameter, data: torch.Tensor, quant_state: Optional[Tuple] = None ) -> torch.nn.Parameter: bnb = _import_bitsandbytes() @@ -419,7 +418,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: return bnb -def _convert_layers(module: torch.nn.Module, linear_cls: type, ignore_modules: set[str], prefix: str = "") -> None: +def _convert_layers(module: torch.nn.Module, linear_cls: Type, ignore_modules: Set[str], prefix: str = "") -> None: for name, child in module.named_children(): fullname = f"{prefix}.{name}" if prefix else name if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules): diff --git a/src/lightning/fabric/plugins/precision/deepspeed.py b/src/lightning/fabric/plugins/precision/deepspeed.py index 5260950..2fcaa38 100644 --- a/src/lightning/fabric/plugins/precision/deepspeed.py +++ b/src/lightning/fabric/plugins/precision/deepspeed.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import AbstractContextManager, nullcontext -from typing import TYPE_CHECKING, Any, Literal +from contextlib import nullcontext +from typing import TYPE_CHECKING, Any, ContextManager, Literal import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -68,13 +68,13 @@ def convert_module(self, module: Module) -> Module: return module @override - def tensor_init_context(self) -> AbstractContextManager: + def tensor_init_context(self) -> ContextManager: if "true" not in self.precision: return nullcontext() return _DtypeContextManager(self._desired_dtype) @override - def module_init_context(self) -> AbstractContextManager: + def module_init_context(self) -> ContextManager: return self.tensor_init_context() @override diff --git a/src/lightning/fabric/plugins/precision/double.py b/src/lightning/fabric/plugins/precision/double.py index 9aa0365..0a85749 100644 --- a/src/lightning/fabric/plugins/precision/double.py +++ b/src/lightning/fabric/plugins/precision/double.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import AbstractContextManager -from typing import Any, Literal +from typing import Any, ContextManager, Literal import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -34,15 +33,15 @@ def convert_module(self, module: Module) -> Module: return module.double() @override - def tensor_init_context(self) -> AbstractContextManager: + def tensor_init_context(self) -> ContextManager: return _DtypeContextManager(torch.double) @override - def module_init_context(self) -> AbstractContextManager: + def module_init_context(self) -> ContextManager: return self.tensor_init_context() @override - def forward_context(self) -> AbstractContextManager: + def forward_context(self) -> ContextManager: return self.tensor_init_context() @override diff --git a/src/lightning/fabric/plugins/precision/fsdp.py b/src/lightning/fabric/plugins/precision/fsdp.py index 0b78ad7..179fc21 100644 --- a/src/lightning/fabric/plugins/precision/fsdp.py +++ b/src/lightning/fabric/plugins/precision/fsdp.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import AbstractContextManager -from typing import TYPE_CHECKING, Any, Literal, Optional +from typing import TYPE_CHECKING, Any, ContextManager, Dict, Literal, Optional import torch from lightning_utilities import apply_to_collection @@ -101,15 +100,15 @@ def mixed_precision_config(self) -> "TorchMixedPrecision": ) @override - def tensor_init_context(self) -> AbstractContextManager: + def tensor_init_context(self) -> ContextManager: return _DtypeContextManager(self._desired_input_dtype) @override - def module_init_context(self) -> AbstractContextManager: + def module_init_context(self) -> ContextManager: return _DtypeContextManager(self.mixed_precision_config.param_dtype or torch.float32) @override - def forward_context(self) -> AbstractContextManager: + def forward_context(self) -> ContextManager: if "mixed" in self.precision: return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16)) return self.tensor_init_context() @@ -151,12 +150,12 @@ def unscale_gradients(self, optimizer: Optimizer) -> None: scaler.unscale_(optimizer) @override - def state_dict(self) -> dict[str, Any]: + def state_dict(self) -> Dict[str, Any]: if self.scaler is not None: return self.scaler.state_dict() return {} @override - def load_state_dict(self, state_dict: dict[str, Any]) -> None: + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: if self.scaler is not None: self.scaler.load_state_dict(state_dict) diff --git a/src/lightning/fabric/plugins/precision/half.py b/src/lightning/fabric/plugins/precision/half.py index fcb28ad..32ca7da 100644 --- a/src/lightning/fabric/plugins/precision/half.py +++ b/src/lightning/fabric/plugins/precision/half.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import AbstractContextManager -from typing import Any, Literal +from typing import Any, ContextManager, Literal import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -43,15 +42,15 @@ def convert_module(self, module: Module) -> Module: return module.to(dtype=self._desired_input_dtype) @override - def tensor_init_context(self) -> AbstractContextManager: + def tensor_init_context(self) -> ContextManager: return _DtypeContextManager(self._desired_input_dtype) @override - def module_init_context(self) -> AbstractContextManager: + def module_init_context(self) -> ContextManager: return self.tensor_init_context() @override - def forward_context(self) -> AbstractContextManager: + def forward_context(self) -> ContextManager: return self.tensor_init_context() @override diff --git a/src/lightning/fabric/plugins/precision/precision.py b/src/lightning/fabric/plugins/precision/precision.py index 1dfab2a..fbff54f 100644 --- a/src/lightning/fabric/plugins/precision/precision.py +++ b/src/lightning/fabric/plugins/precision/precision.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import AbstractContextManager, nullcontext -from typing import Any, Literal, Optional, Union +from contextlib import nullcontext +from typing import Any, ContextManager, Dict, Literal, Optional, Union from torch import Tensor from torch.nn import Module @@ -53,11 +53,11 @@ def convert_module(self, module: Module) -> Module: """ return module - def tensor_init_context(self) -> AbstractContextManager: + def tensor_init_context(self) -> ContextManager: """Controls how tensors get created (device, dtype).""" return nullcontext() - def module_init_context(self) -> AbstractContextManager: + def module_init_context(self) -> ContextManager: """Instantiate module parameters or tensors in the precision type this plugin handles. This is optional and depends on the precision limitations during optimization. @@ -65,7 +65,7 @@ def module_init_context(self) -> AbstractContextManager: """ return nullcontext() - def forward_context(self) -> AbstractContextManager: + def forward_context(self) -> ContextManager: """A contextmanager for managing model forward/training_step/evaluation_step/predict_step.""" return nullcontext() @@ -135,7 +135,7 @@ def main_params(self, optimizer: Optimizer) -> _PARAMETERS: def unscale_gradients(self, optimizer: Optimizer) -> None: return - def state_dict(self) -> dict[str, Any]: + def state_dict(self) -> Dict[str, Any]: """Called when saving a checkpoint, implement to generate precision plugin state_dict. Returns: @@ -144,7 +144,7 @@ def state_dict(self) -> dict[str, Any]: """ return {} - def load_state_dict(self, state_dict: dict[str, Any]) -> None: + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: """Called when loading a checkpoint, implement to reload precision plugin state given precision plugin state_dict. diff --git a/src/lightning/fabric/plugins/precision/transformer_engine.py b/src/lightning/fabric/plugins/precision/transformer_engine.py index c3ef84a..cb5296b 100644 --- a/src/lightning/fabric/plugins/precision/transformer_engine.py +++ b/src/lightning/fabric/plugins/precision/transformer_engine.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from collections.abc import Mapping -from contextlib import AbstractContextManager, ExitStack -from typing import TYPE_CHECKING, Any, Literal, Optional, Union +from contextlib import ExitStack +from typing import TYPE_CHECKING, Any, ContextManager, Literal, Mapping, Optional, Union import torch from lightning_utilities import apply_to_collection @@ -107,11 +106,11 @@ def convert_module(self, module: torch.nn.Module) -> torch.nn.Module: return module @override - def tensor_init_context(self) -> AbstractContextManager: + def tensor_init_context(self) -> ContextManager: return _DtypeContextManager(self.weights_dtype) @override - def module_init_context(self) -> AbstractContextManager: + def module_init_context(self) -> ContextManager: dtype_ctx = self.tensor_init_context() stack = ExitStack() if self.replace_layers: @@ -126,7 +125,7 @@ def module_init_context(self) -> AbstractContextManager: return stack @override - def forward_context(self) -> AbstractContextManager: + def forward_context(self) -> ContextManager: dtype_ctx = _DtypeContextManager(self.weights_dtype) fallback_autocast_ctx = torch.autocast(device_type="cuda", dtype=self.fallback_compute_dtype) import transformer_engine.pytorch as te diff --git a/src/lightning/fabric/plugins/precision/utils.py b/src/lightning/fabric/plugins/precision/utils.py index 8362384..887dbc9 100644 --- a/src/lightning/fabric/plugins/precision/utils.py +++ b/src/lightning/fabric/plugins/precision/utils.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Mapping -from typing import Any, Union +from typing import Any, Mapping, Type, Union import torch from torch import Tensor @@ -44,7 +43,7 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: class _ClassReplacementContextManager: """A context manager to monkeypatch classes.""" - def __init__(self, mapping: Mapping[str, type]) -> None: + def __init__(self, mapping: Mapping[str, Type]) -> None: self._mapping = mapping self._originals = {} self._modules = {} diff --git a/src/lightning/fabric/strategies/ddp.py b/src/lightning/fabric/strategies/ddp.py index ce47e4e..c387806 100644 --- a/src/lightning/fabric/strategies/ddp.py +++ b/src/lightning/fabric/strategies/ddp.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import AbstractContextManager, nullcontext +from contextlib import nullcontext from datetime import timedelta -from typing import Any, Literal, Optional, Union +from typing import Any, ContextManager, Dict, List, Literal, Optional, Union import torch import torch.distributed @@ -55,7 +55,7 @@ class DDPStrategy(ParallelStrategy): def __init__( self, accelerator: Optional[Accelerator] = None, - parallel_devices: Optional[list[torch.device]] = None, + parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, precision: Optional[Precision] = None, @@ -99,7 +99,7 @@ def num_processes(self) -> int: @property @override - def distributed_sampler_kwargs(self) -> dict[str, Any]: + def distributed_sampler_kwargs(self) -> Dict[str, Any]: return {"num_replicas": (self.num_nodes * self.num_processes), "rank": self.global_rank} @property @@ -171,14 +171,14 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: return obj[0] @override - def get_module_state_dict(self, module: Module) -> dict[str, Union[Any, Tensor]]: + def get_module_state_dict(self, module: Module) -> Dict[str, Union[Any, Tensor]]: if isinstance(module, DistributedDataParallel): module = module.module return super().get_module_state_dict(module) @override def load_module_state_dict( - self, module: Module, state_dict: dict[str, Union[Any, Tensor]], strict: bool = True + self, module: Module, state_dict: Dict[str, Union[Any, Tensor]], strict: bool = True ) -> None: if isinstance(module, DistributedDataParallel): module = module.module @@ -225,13 +225,13 @@ def _set_world_ranks(self) -> None: # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank - def _determine_ddp_device_ids(self) -> Optional[list[int]]: + def _determine_ddp_device_ids(self) -> Optional[List[int]]: return None if self.root_device.type == "cpu" else [self.root_device.index] class _DDPBackwardSyncControl(_BackwardSyncControl): @override - def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager: + def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager: """Blocks gradient synchronization inside the :class:`~torch.nn.parallel.distributed.DistributedDataParallel` wrapper.""" if not enabled: diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index 03d90cd..e71b8e2 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -16,11 +16,10 @@ import logging import os import platform -from collections.abc import Mapping -from contextlib import AbstractContextManager, ExitStack +from contextlib import ExitStack from itertools import chain from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, List, Mapping, Optional, Tuple, Union import torch from lightning_utilities.core.imports import RequirementCache @@ -81,9 +80,9 @@ def __init__( reduce_bucket_size: int = 200_000_000, zero_allow_untested_optimizer: bool = True, logging_batch_size_per_gpu: Optional[int] = None, - config: Optional[Union[_PATH, dict[str, Any]]] = None, + config: Optional[Union[_PATH, Dict[str, Any]]] = None, logging_level: int = logging.WARN, - parallel_devices: Optional[list[torch.device]] = None, + parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, loss_scale: float = 0, initial_scale_power: int = 16, @@ -303,7 +302,7 @@ def zero_stage_3(self) -> bool: @property @override - def distributed_sampler_kwargs(self) -> dict[str, int]: + def distributed_sampler_kwargs(self) -> Dict[str, int]: return {"num_replicas": self.world_size, "rank": self.global_rank} @property @@ -312,8 +311,8 @@ def model(self) -> "DeepSpeedEngine": @override def setup_module_and_optimizers( - self, module: Module, optimizers: list[Optimizer] - ) -> tuple["DeepSpeedEngine", list[Optimizer]]: + self, module: Module, optimizers: List[Optimizer] + ) -> Tuple["DeepSpeedEngine", List[Optimizer]]: """Set up a model and multiple optimizers together. Currently, only a single optimizer is supported. @@ -353,7 +352,7 @@ def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: raise NotImplementedError(self._err_msg_joint_setup_required()) @override - def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager: + def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager: if self.zero_stage_3 and empty_init is False: raise NotImplementedError( f"`{empty_init=}` is not a valid choice with `DeepSpeedStrategy` when ZeRO stage 3 is enabled." @@ -366,7 +365,7 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractCont return stack @override - def module_sharded_context(self) -> AbstractContextManager: + def module_sharded_context(self) -> ContextManager: # Current limitation in Fabric: The config needs to be fully determined at the time of calling the context # manager. Later modifications through e.g. `Fabric.setup()` won't have an effect here. @@ -383,9 +382,9 @@ def module_sharded_context(self) -> AbstractContextManager: def save_checkpoint( self, path: _PATH, - state: dict[str, Union[Module, Optimizer, Any]], + state: Dict[str, Union[Module, Optimizer, Any]], storage_options: Optional[Any] = None, - filter: Optional[dict[str, Callable[[str, Any], bool]]] = None, + filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None, ) -> None: """Save model, optimizer, and other state in a checkpoint directory. @@ -448,9 +447,9 @@ def save_checkpoint( def load_checkpoint( self, path: _PATH, - state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None, + state: Optional[Union[Module, Optimizer, Dict[str, Union[Module, Optimizer, Any]]]] = None, strict: bool = True, - ) -> dict[str, Any]: + ) -> Dict[str, Any]: """Load the contents from a checkpoint and restore the state of the given objects. Args: @@ -596,7 +595,7 @@ def _initialize_engine( self, model: Module, optimizer: Optional[Optimizer] = None, - ) -> tuple["DeepSpeedEngine", Optimizer]: + ) -> Tuple["DeepSpeedEngine", Optimizer]: """Initialize one model and one optimizer with an optional learning rate scheduler. This calls ``deepspeed.initialize`` internally. @@ -715,7 +714,7 @@ def _create_default_config( overlap_events: bool, thread_count: int, **zero_kwargs: Any, - ) -> dict: + ) -> Dict: cfg = { "activation_checkpointing": { "partition_activations": partition_activations, @@ -770,9 +769,9 @@ def _restore_zero_state(self, module: Module, ckpt: Mapping[str, Any]) -> None: import deepspeed def load(module: torch.nn.Module, prefix: str = "") -> None: - missing_keys: list[str] = [] - unexpected_keys: list[str] = [] - error_msgs: list[str] = [] + missing_keys: List[str] = [] + unexpected_keys: List[str] = [] + error_msgs: List[str] = [] state_dict = ckpt["state_dict"] # copy state_dict so _load_from_state_dict can modify it @@ -803,7 +802,7 @@ def load(module: torch.nn.Module, prefix: str = "") -> None: load(module, prefix="") - def _load_config(self, config: Optional[Union[_PATH, dict[str, Any]]]) -> Optional[dict[str, Any]]: + def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Optional[Dict[str, Any]]: if config is None and self.DEEPSPEED_ENV_VAR in os.environ: rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable") config = os.environ[self.DEEPSPEED_ENV_VAR] @@ -818,14 +817,14 @@ def _load_config(self, config: Optional[Union[_PATH, dict[str, Any]]]) -> Option return config -def _get_deepspeed_engines_from_state(state: dict[str, Any]) -> list["DeepSpeedEngine"]: +def _get_deepspeed_engines_from_state(state: Dict[str, Any]) -> List["DeepSpeedEngine"]: from deepspeed import DeepSpeedEngine modules = chain(*(module.modules() for module in state.values() if isinstance(module, Module))) return [engine for engine in modules if isinstance(engine, DeepSpeedEngine)] -def _validate_state_keys(state: dict[str, Any]) -> None: +def _validate_state_keys(state: Dict[str, Any]) -> None: # DeepSpeed merges the client state into its internal engine state when saving, but it does not check for # colliding keys from the user. We explicitly check it here: deepspeed_internal_keys = { @@ -852,7 +851,7 @@ def _validate_state_keys(state: dict[str, Any]) -> None: ) -def _validate_device_index_selection(parallel_devices: list[torch.device]) -> None: +def _validate_device_index_selection(parallel_devices: List[torch.device]) -> None: selected_device_indices = [device.index for device in parallel_devices] expected_device_indices = list(range(len(parallel_devices))) if selected_device_indices != expected_device_indices: @@ -904,7 +903,7 @@ def _validate_checkpoint_directory(path: _PATH) -> None: def _format_precision_config( - config: dict[str, Any], + config: Dict[str, Any], precision: str, loss_scale: float, loss_scale_window: int, diff --git a/src/lightning/fabric/strategies/dp.py b/src/lightning/fabric/strategies/dp.py index f407040..2fed307 100644 --- a/src/lightning/fabric/strategies/dp.py +++ b/src/lightning/fabric/strategies/dp.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Union +from typing import Any, Dict, List, Optional, Union import torch from torch import Tensor @@ -35,7 +35,7 @@ class DataParallelStrategy(ParallelStrategy): def __init__( self, accelerator: Optional[Accelerator] = None, - parallel_devices: Optional[list[torch.device]] = None, + parallel_devices: Optional[List[torch.device]] = None, checkpoint_io: Optional[CheckpointIO] = None, precision: Optional[Precision] = None, ): @@ -95,14 +95,14 @@ def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool: return decision @override - def get_module_state_dict(self, module: Module) -> dict[str, Union[Any, Tensor]]: + def get_module_state_dict(self, module: Module) -> Dict[str, Union[Any, Tensor]]: if isinstance(module, DataParallel): module = module.module return super().get_module_state_dict(module) @override def load_module_state_dict( - self, module: Module, state_dict: dict[str, Union[Any, Tensor]], strict: bool = True + self, module: Module, state_dict: Dict[str, Union[Any, Tensor]], strict: bool = True ) -> None: if isinstance(module, DataParallel): module = module.module diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index 9dd5b2c..e7fdd29 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ b/src/lightning/fabric/strategies/fsdp.py @@ -13,8 +13,7 @@ # limitations under the License. import shutil import warnings -from collections.abc import Generator -from contextlib import AbstractContextManager, ExitStack, nullcontext +from contextlib import ExitStack, nullcontext from datetime import timedelta from functools import partial from pathlib import Path @@ -22,8 +21,15 @@ TYPE_CHECKING, Any, Callable, + ContextManager, + Dict, + Generator, + List, Literal, Optional, + Set, + Tuple, + Type, Union, ) @@ -72,7 +78,7 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision, ShardingStrategy from torch.distributed.fsdp.wrap import ModuleWrapPolicy - _POLICY = Union[set[type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy] + _POLICY = Union[Set[Type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy] _SHARDING_STRATEGY = Union[ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]] @@ -137,7 +143,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded): def __init__( self, accelerator: Optional[Accelerator] = None, - parallel_devices: Optional[list[torch.device]] = None, + parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, precision: Optional[Precision] = None, process_group_backend: Optional[str] = None, @@ -145,11 +151,11 @@ def __init__( cpu_offload: Union[bool, "CPUOffload", None] = None, mixed_precision: Optional["MixedPrecision"] = None, auto_wrap_policy: Optional["_POLICY"] = None, - activation_checkpointing: Optional[Union[type[Module], list[type[Module]]]] = None, + activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None, activation_checkpointing_policy: Optional["_POLICY"] = None, sharding_strategy: "_SHARDING_STRATEGY" = "FULL_SHARD", state_dict_type: Literal["full", "sharded"] = "sharded", - device_mesh: Optional[Union[tuple[int], "DeviceMesh"]] = None, + device_mesh: Optional[Union[Tuple[int], "DeviceMesh"]] = None, **kwargs: Any, ) -> None: super().__init__( @@ -210,7 +216,7 @@ def num_processes(self) -> int: @property @override - def distributed_sampler_kwargs(self) -> dict[str, Any]: + def distributed_sampler_kwargs(self) -> Dict[str, Any]: return {"num_replicas": (self.num_nodes * self.num_processes), "rank": self.global_rank} @property @@ -261,8 +267,8 @@ def setup_environment(self) -> None: @override def setup_module_and_optimizers( - self, module: Module, optimizers: list[Optimizer] - ) -> tuple[Module, list[Optimizer]]: + self, module: Module, optimizers: List[Optimizer] + ) -> Tuple[Module, List[Optimizer]]: """Wraps the model into a :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module and sets `use_orig_params=True` to keep the reference to the original parameters in the optimizer.""" use_orig_params = self._fsdp_kwargs.get("use_orig_params") @@ -334,7 +340,7 @@ def module_to_device(self, module: Module) -> None: pass @override - def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager: + def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager: precision_init_ctx = self.precision.module_init_context() module_sharded_ctx = self.module_sharded_context() stack = ExitStack() @@ -348,7 +354,7 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractCont return stack @override - def module_sharded_context(self) -> AbstractContextManager: + def module_sharded_context(self) -> ContextManager: from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel from torch.distributed.fsdp.wrap import enable_wrap @@ -413,9 +419,9 @@ def clip_gradients_norm( def save_checkpoint( self, path: _PATH, - state: dict[str, Union[Module, Optimizer, Any]], + state: Dict[str, Union[Module, Optimizer, Any]], storage_options: Optional[Any] = None, - filter: Optional[dict[str, Callable[[str, Any], bool]]] = None, + filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None, ) -> None: """Save model, optimizer, and other state to a checkpoint on disk. @@ -467,8 +473,8 @@ def save_checkpoint( # replace the modules and optimizer objects in the state with their local state dict # and separate the user's metadata - converted_state: dict[str, Any] = {} - metadata: dict[str, Any] = {} + converted_state: Dict[str, Any] = {} + metadata: Dict[str, Any] = {} with state_dict_ctx: for key, obj in state.items(): converted: Any @@ -493,7 +499,7 @@ def save_checkpoint( shutil.rmtree(path) state_dict_ctx = _get_full_state_dict_context(module, world_size=self.world_size) - full_state: dict[str, Any] = {} + full_state: Dict[str, Any] = {} with state_dict_ctx: for key, obj in state.items(): if isinstance(obj, Module): @@ -513,9 +519,9 @@ def save_checkpoint( def load_checkpoint( self, path: _PATH, - state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None, + state: Optional[Union[Module, Optimizer, Dict[str, Union[Module, Optimizer, Any]]]] = None, strict: bool = True, - ) -> dict[str, Any]: + ) -> Dict[str, Any]: """Load the contents from a checkpoint and restore the state of the given objects.""" if not state: raise ValueError( @@ -677,9 +683,9 @@ def _set_world_ranks(self) -> None: def _activation_checkpointing_kwargs( - activation_checkpointing: Optional[Union[type[Module], list[type[Module]]]], + activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]], activation_checkpointing_policy: Optional["_POLICY"], -) -> dict: +) -> Dict: if activation_checkpointing is None and activation_checkpointing_policy is None: return {} if activation_checkpointing is not None and activation_checkpointing_policy is not None: @@ -701,7 +707,7 @@ def _activation_checkpointing_kwargs( return {"auto_wrap_policy": activation_checkpointing_policy} -def _auto_wrap_policy_kwargs(policy: Optional["_POLICY"], kwargs: dict) -> dict: +def _auto_wrap_policy_kwargs(policy: Optional["_POLICY"], kwargs: Dict) -> Dict: if policy is None: return kwargs if isinstance(policy, set): @@ -713,7 +719,7 @@ def _auto_wrap_policy_kwargs(policy: Optional["_POLICY"], kwargs: dict) -> dict: return kwargs -def _setup_activation_checkpointing(module: Module, activation_checkpointing_kwargs: dict) -> None: +def _setup_activation_checkpointing(module: Module, activation_checkpointing_kwargs: Dict) -> None: if not activation_checkpointing_kwargs: return @@ -739,7 +745,7 @@ def _setup_activation_checkpointing(module: Module, activation_checkpointing_kwa class _FSDPBackwardSyncControl(_BackwardSyncControl): @override - def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager: + def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager: """Blocks gradient synchronization inside the :class:`~torch.distributed.fsdp.FullyShardedDataParallel` wrapper.""" if not enabled: @@ -762,7 +768,7 @@ def _init_cpu_offload(cpu_offload: Optional[Union[bool, "CPUOffload"]]) -> "CPUO return cpu_offload if isinstance(cpu_offload, CPUOffload) else CPUOffload(offload_params=bool(cpu_offload)) -def _init_sharding_strategy(sharding_strategy: "_SHARDING_STRATEGY", kwargs: dict) -> "ShardingStrategy": +def _init_sharding_strategy(sharding_strategy: "_SHARDING_STRATEGY", kwargs: Dict) -> "ShardingStrategy": from torch.distributed.fsdp import ShardingStrategy if kwargs.get("process_group") is not None and kwargs.get("device_mesh") is not None: @@ -852,7 +858,7 @@ def _move_torchmetrics_to_device(module: torch.nn.Module, device: torch.device) metric.to(device) # `.to()` is in-place -def _distributed_checkpoint_save(converted_state: dict[str, Any], path: Path) -> None: +def _distributed_checkpoint_save(converted_state: Dict[str, Any], path: Path) -> None: if _TORCH_GREATER_EQUAL_2_3: from torch.distributed.checkpoint import save @@ -871,7 +877,7 @@ def _distributed_checkpoint_save(converted_state: dict[str, Any], path: Path) -> save(converted_state, writer) -def _distributed_checkpoint_load(module_state: dict[str, Any], path: Path) -> None: +def _distributed_checkpoint_load(module_state: Dict[str, Any], path: Path) -> None: if _TORCH_GREATER_EQUAL_2_3: from torch.distributed.checkpoint import load diff --git a/src/lightning/fabric/strategies/launchers/multiprocessing.py b/src/lightning/fabric/strategies/launchers/multiprocessing.py index d9b96dc..14a063f 100644 --- a/src/lightning/fabric/strategies/launchers/multiprocessing.py +++ b/src/lightning/fabric/strategies/launchers/multiprocessing.py @@ -16,7 +16,7 @@ from dataclasses import dataclass from multiprocessing.queues import SimpleQueue from textwrap import dedent -from typing import TYPE_CHECKING, Any, Callable, Literal, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, Literal, Optional import torch import torch.backends.cudnn @@ -167,7 +167,7 @@ class _GlobalStateSnapshot: use_deterministic_algorithms: bool use_deterministic_algorithms_warn_only: bool cudnn_benchmark: bool - rng_states: dict[str, Any] + rng_states: Dict[str, Any] @classmethod def capture(cls) -> "_GlobalStateSnapshot": diff --git a/src/lightning/fabric/strategies/launchers/subprocess_script.py b/src/lightning/fabric/strategies/launchers/subprocess_script.py index a28fe97..63ae8b0 100644 --- a/src/lightning/fabric/strategies/launchers/subprocess_script.py +++ b/src/lightning/fabric/strategies/launchers/subprocess_script.py @@ -18,8 +18,7 @@ import sys import threading import time -from collections.abc import Sequence -from typing import Any, Callable, Optional +from typing import Any, Callable, List, Optional, Sequence, Tuple from lightning_utilities.core.imports import RequirementCache from typing_extensions import override @@ -81,7 +80,7 @@ def __init__( self.cluster_environment = cluster_environment self.num_processes = num_processes self.num_nodes = num_nodes - self.procs: list[subprocess.Popen] = [] # launched child subprocesses, does not include the launcher + self.procs: List[subprocess.Popen] = [] # launched child subprocesses, does not include the launcher @property @override @@ -163,7 +162,7 @@ def _basic_subprocess_cmd() -> Sequence[str]: return [sys.executable, "-m", __main__.__spec__.name] + sys.argv[1:] -def _hydra_subprocess_cmd(local_rank: int) -> tuple[Sequence[str], str]: +def _hydra_subprocess_cmd(local_rank: int) -> Tuple[Sequence[str], str]: from hydra.core.hydra_config import HydraConfig from hydra.utils import get_original_cwd, to_absolute_path @@ -184,13 +183,13 @@ def _hydra_subprocess_cmd(local_rank: int) -> tuple[Sequence[str], str]: return command, cwd -def _launch_process_observer(child_processes: list[subprocess.Popen]) -> None: +def _launch_process_observer(child_processes: List[subprocess.Popen]) -> None: """Launches a thread that runs along the main process and monitors the health of all processes.""" _ChildProcessObserver(child_processes=child_processes, main_pid=os.getpid()).start() class _ChildProcessObserver(threading.Thread): - def __init__(self, main_pid: int, child_processes: list[subprocess.Popen], sleep_period: int = 5) -> None: + def __init__(self, main_pid: int, child_processes: List[subprocess.Popen], sleep_period: int = 5) -> None: super().__init__(daemon=True, name="child-process-observer") # thread stops if the main process exits self._main_pid = main_pid self._child_processes = child_processes diff --git a/src/lightning/fabric/strategies/model_parallel.py b/src/lightning/fabric/strategies/model_parallel.py index ad1fc19..86b93d3 100644 --- a/src/lightning/fabric/strategies/model_parallel.py +++ b/src/lightning/fabric/strategies/model_parallel.py @@ -13,11 +13,10 @@ # limitations under the License. import itertools import shutil -from collections.abc import Generator -from contextlib import AbstractContextManager, ExitStack +from contextlib import ExitStack from datetime import timedelta from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generator, Literal, Optional, TypeVar, Union import torch from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only @@ -145,7 +144,7 @@ def num_processes(self) -> int: @property @override - def distributed_sampler_kwargs(self) -> dict[str, Any]: + def distributed_sampler_kwargs(self) -> Dict[str, Any]: assert self.device_mesh is not None data_parallel_mesh = self.device_mesh["data_parallel"] return {"num_replicas": data_parallel_mesh.size(), "rank": data_parallel_mesh.get_local_rank()} @@ -195,7 +194,7 @@ def module_to_device(self, module: Module) -> None: pass @override - def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager: + def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager: precision_init_ctx = self.precision.module_init_context() stack = ExitStack() if empty_init: @@ -235,9 +234,9 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: def save_checkpoint( self, path: _PATH, - state: dict[str, Union[Module, Optimizer, Any]], + state: Dict[str, Union[Module, Optimizer, Any]], storage_options: Optional[Any] = None, - filter: Optional[dict[str, Callable[[str, Any], bool]]] = None, + filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None, ) -> None: """Save model, optimizer, and other state to a checkpoint on disk. @@ -273,9 +272,9 @@ def save_checkpoint( def load_checkpoint( self, path: _PATH, - state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None, + state: Optional[Union[Module, Optimizer, Dict[str, Union[Module, Optimizer, Any]]]] = None, strict: bool = True, - ) -> dict[str, Any]: + ) -> Dict[str, Any]: """Load the contents from a checkpoint and restore the state of the given objects.""" if not state: raise ValueError( @@ -319,12 +318,12 @@ def _set_world_ranks(self) -> None: class _ParallelBackwardSyncControl(_BackwardSyncControl): @override - def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager: + def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager: """Blocks gradient synchronization inside the FSDP2 modules.""" return _FSDPNoSync(module=module, enabled=enabled) -class _FSDPNoSync(AbstractContextManager): +class _FSDPNoSync(ContextManager): def __init__(self, module: Module, enabled: bool) -> None: self._module = module self._enabled = enabled @@ -345,10 +344,10 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: def _save_checkpoint( path: Path, - state: dict[str, Union[Module, Optimizer, Any]], + state: Dict[str, Union[Module, Optimizer, Any]], full_state_dict: bool, rank: int, - filter: Optional[dict[str, Callable[[str, Any], bool]]] = None, + filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None, ) -> None: if path.is_dir() and full_state_dict and not _is_sharded_checkpoint(path): raise IsADirectoryError(f"The checkpoint path exists and is a directory: {path}") @@ -374,8 +373,8 @@ def _save_checkpoint( # replace the modules and optimizer objects in the state with their local state dict # and separate the user's metadata - converted_state: dict[str, Any] = {} - metadata: dict[str, Any] = {} + converted_state: Dict[str, Any] = {} + metadata: Dict[str, Any] = {} for key, obj in state.items(): converted: Any if isinstance(obj, Module): @@ -406,10 +405,10 @@ def _save_checkpoint( def _load_checkpoint( path: Path, - state: dict[str, Union[Module, Optimizer, Any]], + state: Dict[str, Union[Module, Optimizer, Any]], strict: bool = True, optimizer_states_from_list: bool = False, -) -> dict[str, Any]: +) -> Dict[str, Any]: from torch.distributed.checkpoint.state_dict import ( StateDictOptions, get_model_state_dict, @@ -538,7 +537,7 @@ def _load_raw_module_state_from_path(path: Path, module: Module, world_size: int def _load_raw_module_state( - state_dict: dict[str, Any], module: Module, world_size: int = 1, strict: bool = True + state_dict: Dict[str, Any], module: Module, world_size: int = 1, strict: bool = True ) -> None: """Loads the state dict into the module by gathering all weights first and then and writing back to each shard.""" from torch.distributed.fsdp import FullyShardedDataParallel as FSDP @@ -584,7 +583,7 @@ def _named_parameters_and_buffers_to_load(module: Module) -> Generator: yield param_name, param -def _rekey_optimizer_state_if_needed(optimizer_state_dict: dict[str, Any], module: Module) -> dict[str, Any]: +def _rekey_optimizer_state_if_needed(optimizer_state_dict: Dict[str, Any], module: Module) -> Dict[str, Any]: """Handles the case where the optimizer state is saved from a normal optimizer and converts the keys to parameter names.""" from torch.distributed.fsdp import FullyShardedDataParallel as FSDP diff --git a/src/lightning/fabric/strategies/parallel.py b/src/lightning/fabric/strategies/parallel.py index d9bc1a0..a12a061 100644 --- a/src/lightning/fabric/strategies/parallel.py +++ b/src/lightning/fabric/strategies/parallel.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC -from typing import Any, Optional +from typing import Any, Dict, List, Optional import torch from torch import Tensor @@ -33,7 +33,7 @@ class ParallelStrategy(Strategy, ABC): def __init__( self, accelerator: Optional[Accelerator] = None, - parallel_devices: Optional[list[torch.device]] = None, + parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, precision: Optional[Precision] = None, @@ -64,15 +64,15 @@ def is_global_zero(self) -> bool: return self.global_rank == 0 @property - def parallel_devices(self) -> Optional[list[torch.device]]: + def parallel_devices(self) -> Optional[List[torch.device]]: return self._parallel_devices @parallel_devices.setter - def parallel_devices(self, parallel_devices: Optional[list[torch.device]]) -> None: + def parallel_devices(self, parallel_devices: Optional[List[torch.device]]) -> None: self._parallel_devices = parallel_devices @property - def distributed_sampler_kwargs(self) -> Optional[dict[str, Any]]: + def distributed_sampler_kwargs(self) -> Optional[Dict[str, Any]]: """Arguments for the ``DistributedSampler``. If this method is not defined, or it returns ``None``, then the ``DistributedSampler`` will not be used. diff --git a/src/lightning/fabric/strategies/registry.py b/src/lightning/fabric/strategies/registry.py index d237646..d789958 100644 --- a/src/lightning/fabric/strategies/registry.py +++ b/src/lightning/fabric/strategies/registry.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional +from typing import Any, Callable, Dict, List, Optional from typing_extensions import override @@ -65,7 +65,7 @@ def register( if name in self and not override: raise ValueError(f"'{name}' is already present in the registry. HINT: Use `override=True`.") - data: dict[str, Any] = {} + data: Dict[str, Any] = {} data["description"] = description if description is not None else "" data["init_params"] = init_params @@ -104,7 +104,7 @@ def remove(self, name: str) -> None: """Removes the registered strategy by name.""" self.pop(name) - def available_strategies(self) -> list: + def available_strategies(self) -> List: """Returns a list of registered strategies.""" return list(self.keys()) diff --git a/src/lightning/fabric/strategies/strategy.py b/src/lightning/fabric/strategies/strategy.py index 4daad9b..6bfed6a 100644 --- a/src/lightning/fabric/strategies/strategy.py +++ b/src/lightning/fabric/strategies/strategy.py @@ -13,9 +13,8 @@ # limitations under the License. import logging from abc import ABC, abstractmethod -from collections.abc import Iterable -from contextlib import AbstractContextManager, ExitStack -from typing import Any, Callable, Optional, TypeVar, Union +from contextlib import ExitStack +from typing import Any, Callable, ContextManager, Dict, Iterable, List, Optional, Tuple, TypeVar, Union import torch from torch import Tensor @@ -118,7 +117,7 @@ def process_dataloader(self, dataloader: DataLoader) -> DataLoader: """ return dataloader - def tensor_init_context(self) -> AbstractContextManager: + def tensor_init_context(self) -> ContextManager: """Controls how tensors get created (device, dtype).""" precision_init_ctx = self.precision.tensor_init_context() stack = ExitStack() @@ -126,7 +125,7 @@ def tensor_init_context(self) -> AbstractContextManager: stack.enter_context(precision_init_ctx) return stack - def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager: + def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager: """A context manager wrapping the model instantiation. Here, the strategy can control how the parameters of the model get created (device, dtype) and or apply other @@ -145,8 +144,8 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractCont return stack def setup_module_and_optimizers( - self, module: Module, optimizers: list[Optimizer] - ) -> tuple[Module, list[Optimizer]]: + self, module: Module, optimizers: List[Optimizer] + ) -> Tuple[Module, List[Optimizer]]: """Set up a model and multiple optimizers together. The returned objects are expected to be in the same order they were passed in. The default implementation will @@ -257,9 +256,9 @@ def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool: def save_checkpoint( self, path: _PATH, - state: dict[str, Union[Module, Optimizer, Any]], + state: Dict[str, Union[Module, Optimizer, Any]], storage_options: Optional[Any] = None, - filter: Optional[dict[str, Callable[[str, Any], bool]]] = None, + filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None, ) -> None: """Save model, optimizer, and other state as a checkpoint file. @@ -277,17 +276,17 @@ def save_checkpoint( if self.is_global_zero: self.checkpoint_io.save_checkpoint(checkpoint=state, path=path, storage_options=storage_options) - def get_module_state_dict(self, module: Module) -> dict[str, Union[Any, Tensor]]: + def get_module_state_dict(self, module: Module) -> Dict[str, Union[Any, Tensor]]: """Returns model state.""" return module.state_dict() def load_module_state_dict( - self, module: Module, state_dict: dict[str, Union[Any, Tensor]], strict: bool = True + self, module: Module, state_dict: Dict[str, Union[Any, Tensor]], strict: bool = True ) -> None: """Loads the given state into the model.""" module.load_state_dict(state_dict, strict=strict) - def get_optimizer_state(self, optimizer: Optimizer) -> dict[str, Tensor]: + def get_optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]: """Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom plugins. @@ -305,9 +304,9 @@ def get_optimizer_state(self, optimizer: Optimizer) -> dict[str, Tensor]: def load_checkpoint( self, path: _PATH, - state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None, + state: Optional[Union[Module, Optimizer, Dict[str, Union[Module, Optimizer, Any]]]] = None, strict: bool = True, - ) -> dict[str, Any]: + ) -> Dict[str, Any]: """Load the contents from a checkpoint and restore the state of the given objects. Args: @@ -395,9 +394,9 @@ def _err_msg_joint_setup_required(self) -> str: ) def _convert_stateful_objects_in_state( - self, state: dict[str, Union[Module, Optimizer, Any]], filter: dict[str, Callable[[str, Any], bool]] - ) -> dict[str, Any]: - converted_state: dict[str, Any] = {} + self, state: Dict[str, Union[Module, Optimizer, Any]], filter: Dict[str, Callable[[str, Any], bool]] + ) -> Dict[str, Any]: + converted_state: Dict[str, Any] = {} for key, obj in state.items(): # convert the state if isinstance(obj, Module): @@ -422,7 +421,7 @@ class _BackwardSyncControl(ABC): """ @abstractmethod - def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager: + def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager: """Blocks the synchronization of gradients during the backward pass. This is a context manager. It is only effective if it wraps a call to `.backward()`. @@ -434,7 +433,7 @@ class _Sharded(ABC): """Mixin-interface for any :class:`Strategy` that wants to expose functionality for sharding model parameters.""" @abstractmethod - def module_sharded_context(self) -> AbstractContextManager: + def module_sharded_context(self) -> ContextManager: """A context manager that goes over the instantiation of an :class:`torch.nn.Module` and handles sharding of parameters on creation. @@ -455,7 +454,7 @@ def _validate_keys_for_strict_loading( def _apply_filter( - key: str, filter: dict[str, Callable[[str, Any], bool]], source_dict: object, target_dict: dict[str, Any] + key: str, filter: Dict[str, Callable[[str, Any], bool]], source_dict: object, target_dict: Dict[str, Any] ) -> None: # filter out if necessary if key in filter and isinstance(source_dict, dict): diff --git a/src/lightning/fabric/strategies/xla.py b/src/lightning/fabric/strategies/xla.py index 3b2e10e..28d6555 100644 --- a/src/lightning/fabric/strategies/xla.py +++ b/src/lightning/fabric/strategies/xla.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import io -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union import torch from torch import Tensor @@ -43,7 +43,7 @@ class XLAStrategy(ParallelStrategy): def __init__( self, accelerator: Optional[Accelerator] = None, - parallel_devices: Optional[list[torch.device]] = None, + parallel_devices: Optional[List[torch.device]] = None, checkpoint_io: Optional[XLACheckpointIO] = None, precision: Optional[XLAPrecision] = None, sync_module_states: bool = True, @@ -276,9 +276,9 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: def save_checkpoint( self, path: _PATH, - state: dict[str, Union[Module, Optimizer, Any]], + state: Dict[str, Union[Module, Optimizer, Any]], storage_options: Optional[Any] = None, - filter: Optional[dict[str, Callable[[str, Any], bool]]] = None, + filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None, ) -> None: """Save model, optimizer, and other state as a checkpoint file. diff --git a/src/lightning/fabric/strategies/xla_fsdp.py b/src/lightning/fabric/strategies/xla_fsdp.py index 935ef72..e4c080d 100644 --- a/src/lightning/fabric/strategies/xla_fsdp.py +++ b/src/lightning/fabric/strategies/xla_fsdp.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import io -from contextlib import AbstractContextManager, ExitStack, nullcontext +from contextlib import ExitStack, nullcontext from functools import partial from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, List, Literal, Optional, Set, Tuple, Type, Union import torch from torch import Tensor @@ -46,7 +46,7 @@ if TYPE_CHECKING: from torch_xla.distributed.parallel_loader import MpDeviceLoader -_POLICY_SET = set[type[Module]] +_POLICY_SET = Set[Type[Module]] _POLICY = Union[_POLICY_SET, Callable[[Module, bool, int], bool]] @@ -83,7 +83,7 @@ class XLAFSDPStrategy(ParallelStrategy, _Sharded): def __init__( self, accelerator: Optional[Accelerator] = None, - parallel_devices: Optional[list[torch.device]] = None, + parallel_devices: Optional[List[torch.device]] = None, checkpoint_io: Optional[XLACheckpointIO] = None, precision: Optional[XLAPrecision] = None, auto_wrap_policy: Optional[_POLICY] = None, @@ -196,8 +196,8 @@ def setup_environment(self) -> None: @override def setup_module_and_optimizers( - self, module: Module, optimizers: list[Optimizer] - ) -> tuple[Module, list[Optimizer]]: + self, module: Module, optimizers: List[Optimizer] + ) -> Tuple[Module, List[Optimizer]]: """Returns NotImplementedError since for XLAFSDP optimizer setup must happen after module setup.""" raise NotImplementedError( f"The `{type(self).__name__}` does not support the joint setup of module and optimizer(s)." @@ -225,7 +225,7 @@ def setup_module(self, module: Module) -> Module: def module_to_device(self, module: Module) -> None: pass - def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager: + def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager: precision_init_ctx = self.precision.module_init_context() module_sharded_ctx = self.module_sharded_context() stack = ExitStack() @@ -235,7 +235,7 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractCont return stack @override - def module_sharded_context(self) -> AbstractContextManager: + def module_sharded_context(self) -> ContextManager: return nullcontext() @override @@ -408,9 +408,9 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: def save_checkpoint( self, path: _PATH, - state: dict[str, Union[Module, Optimizer, Any]], + state: Dict[str, Union[Module, Optimizer, Any]], storage_options: Optional[Any] = None, - filter: Optional[dict[str, Callable[[str, Any], bool]]] = None, + filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None, ) -> None: """Save model, optimizer, and other state in the provided checkpoint directory. @@ -483,13 +483,13 @@ def save_checkpoint( def _save_checkpoint_shard( self, path: Path, - state: dict[str, Union[Module, Optimizer, Any]], + state: Dict[str, Union[Module, Optimizer, Any]], storage_options: Optional[Any], - filter: Optional[dict[str, Callable[[str, Any], bool]]], + filter: Optional[Dict[str, Callable[[str, Any], bool]]], ) -> None: from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP - converted_state: dict[str, Any] = {} + converted_state: Dict[str, Any] = {} for key, obj in state.items(): # convert the state if isinstance(obj, Module) and isinstance(obj, XLAFSDP): @@ -512,9 +512,9 @@ def _save_checkpoint_shard( def load_checkpoint( self, path: _PATH, - state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None, + state: Optional[Union[Module, Optimizer, Dict[str, Union[Module, Optimizer, Any]]]] = None, strict: bool = True, - ) -> dict[str, Any]: + ) -> Dict[str, Any]: """Given a folder, load the contents from a checkpoint and restore the state of the given objects. The strategy currently only supports saving and loading sharded checkpoints which are stored in form of a @@ -617,7 +617,7 @@ def load_checkpoint( def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None: strategy_registry.register("xla_fsdp", cls, description=cls.__name__) - def _parse_fsdp_kwargs(self) -> dict: + def _parse_fsdp_kwargs(self) -> Dict: # this needs to be delayed because `self.precision` isn't available at init kwargs = self._fsdp_kwargs.copy() precision = self.precision @@ -629,7 +629,7 @@ def _parse_fsdp_kwargs(self) -> dict: return _activation_checkpointing_kwargs(self._activation_checkpointing_policy, kwargs) -def _auto_wrap_policy_kwargs(policy: Optional["_POLICY"], kwargs: dict) -> dict: +def _auto_wrap_policy_kwargs(policy: Optional["_POLICY"], kwargs: Dict) -> Dict: if policy is None: return kwargs if isinstance(policy, set): @@ -649,7 +649,7 @@ def _activation_checkpointing_auto_wrapper(policy: _POLICY_SET, module: Module, return XLAFSDP(module, *args, **kwargs) -def _activation_checkpointing_kwargs(policy: Optional[_POLICY_SET], kwargs: dict) -> dict: +def _activation_checkpointing_kwargs(policy: Optional[_POLICY_SET], kwargs: Dict) -> Dict: if not policy: return kwargs if "auto_wrapper_callable" in kwargs: @@ -668,7 +668,7 @@ def _activation_checkpointing_kwargs(policy: Optional[_POLICY_SET], kwargs: dict class _XLAFSDPBackwardSyncControl(_BackwardSyncControl): @override - def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager: + def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager: """Blocks gradient synchronization inside the :class:`~torch_xla.distributed.fsdp.XlaFullyShardedDataParallel` wrapper.""" if not enabled: diff --git a/src/lightning/fabric/utilities/apply_func.py b/src/lightning/fabric/utilities/apply_func.py index 35693a5..d43565f 100644 --- a/src/lightning/fabric/utilities/apply_func.py +++ b/src/lightning/fabric/utilities/apply_func.py @@ -15,7 +15,7 @@ from abc import ABC from functools import partial -from typing import TYPE_CHECKING, Any, Callable, Union +from typing import TYPE_CHECKING, Any, Callable, List, Tuple, Union import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -34,7 +34,7 @@ def _from_numpy(value: "np.ndarray", device: _DEVICE) -> Tensor: return torch.from_numpy(value).to(device) -CONVERSION_DTYPES: list[tuple[Any, Callable[[Any, Any], Tensor]]] = [ +CONVERSION_DTYPES: List[Tuple[Any, Callable[[Any, Any], Tensor]]] = [ # bool -> uint8 as bool -> torch.bool triggers RuntimeError: Unsupported data type for NCCL process group (bool, partial(torch.tensor, dtype=torch.uint8)), (int, partial(torch.tensor, dtype=torch.int)), diff --git a/src/lightning/fabric/utilities/cloud_io.py b/src/lightning/fabric/utilities/cloud_io.py index 9d0a33a..7ecc9ee 100644 --- a/src/lightning/fabric/utilities/cloud_io.py +++ b/src/lightning/fabric/utilities/cloud_io.py @@ -16,7 +16,7 @@ import io import logging from pathlib import Path -from typing import IO, Any, Union +from typing import IO, Any, Dict, Union import fsspec import fsspec.utils @@ -69,7 +69,7 @@ def get_filesystem(path: _PATH, **kwargs: Any) -> AbstractFileSystem: return fs -def _atomic_save(checkpoint: dict[str, Any], filepath: Union[str, Path]) -> None: +def _atomic_save(checkpoint: Dict[str, Any], filepath: Union[str, Path]) -> None: """Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints. Args: diff --git a/src/lightning/fabric/utilities/data.py b/src/lightning/fabric/utilities/data.py index ea35d8c..1ec0edc 100644 --- a/src/lightning/fabric/utilities/data.py +++ b/src/lightning/fabric/utilities/data.py @@ -16,10 +16,9 @@ import inspect import os from collections import OrderedDict -from collections.abc import Generator, Iterable, Sized from contextlib import contextmanager from functools import partial -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Dict, Generator, Iterable, Optional, Sized, Tuple, Type, Union from lightning_utilities.core.inheritance import get_all_subclasses from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler @@ -80,7 +79,7 @@ def _update_dataloader(dataloader: DataLoader, sampler: Union[Sampler, Iterable] def _get_dataloader_init_args_and_kwargs( dataloader: DataLoader, sampler: Union[Sampler, Iterable], -) -> tuple[tuple[Any], dict[str, Any]]: +) -> Tuple[Tuple[Any], Dict[str, Any]]: if not isinstance(dataloader, DataLoader): raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`") @@ -173,7 +172,7 @@ def _get_dataloader_init_args_and_kwargs( def _dataloader_init_kwargs_resolve_sampler( dataloader: DataLoader, sampler: Union[Sampler, Iterable], -) -> dict[str, Any]: +) -> Dict[str, Any]: """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its re- instantiation.""" batch_sampler = getattr(dataloader, "batch_sampler") @@ -250,7 +249,7 @@ def _auto_add_worker_init_fn(dataloader: object, rank: int) -> None: dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank) -def _reinstantiate_wrapped_cls(orig_object: Any, *args: Any, explicit_cls: Optional[type] = None, **kwargs: Any) -> Any: +def _reinstantiate_wrapped_cls(orig_object: Any, *args: Any, explicit_cls: Optional[Type] = None, **kwargs: Any) -> Any: constructor = type(orig_object) if explicit_cls is None else explicit_cls try: @@ -356,7 +355,7 @@ def wrapper(obj: Any, *args: Any) -> None: @contextmanager -def _replace_dunder_methods(base_cls: type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]: +def _replace_dunder_methods(base_cls: Type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]: """This context manager is used to add support for re-instantiation of custom (subclasses) of `base_cls`. It patches the ``__init__``, ``__setattr__`` and ``__delattr__`` methods. @@ -367,8 +366,8 @@ def _replace_dunder_methods(base_cls: type, store_explicit_arg: Optional[str] = # Check that __init__ belongs to the class # https://stackoverflow.com/a/5253424 if "__init__" in cls.__dict__: - cls.__old__init__ = cls.__init__ # type: ignore[misc] - cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_arg) # type: ignore[misc] + cls.__old__init__ = cls.__init__ + cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_arg) # we want at least one setattr/delattr in the chain to be patched and it can happen, that none of the subclasses # implement `__setattr__`/`__delattr__`. Therefore, we are always patching the `base_cls` @@ -390,11 +389,11 @@ def _replace_dunder_methods(base_cls: type, store_explicit_arg: Optional[str] = def _replace_value_in_saved_args( replace_key: str, replace_value: Any, - args: tuple[Any, ...], - kwargs: dict[str, Any], - default_kwargs: dict[str, Any], - arg_names: tuple[str, ...], -) -> tuple[bool, tuple[Any, ...], dict[str, Any]]: + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + default_kwargs: Dict[str, Any], + arg_names: Tuple[str, ...], +) -> Tuple[bool, Tuple[Any, ...], Dict[str, Any]]: """Tries to replace an argument value in a saved list of args and kwargs. Returns a tuple indicating success of the operation and modified saved args and kwargs @@ -421,7 +420,7 @@ def _set_sampler_epoch(dataloader: object, epoch: int) -> None: """ # cannot use a set because samplers might be unhashable: use a dict based on the id to drop duplicates - objects: dict[int, Any] = {} + objects: Dict[int, Any] = {} # check dataloader.sampler if (sampler := getattr(dataloader, "sampler", None)) is not None: objects[id(sampler)] = sampler @@ -459,7 +458,7 @@ def _num_cpus_available() -> int: return 1 if cpu_count is None else cpu_count -class AttributeDict(dict): +class AttributeDict(Dict): """A container to store state variables of your program. This is a drop-in replacement for a Python dictionary, with the additional functionality to access and modify keys diff --git a/src/lightning/fabric/utilities/device_dtype_mixin.py b/src/lightning/fabric/utilities/device_dtype_mixin.py index ff5a094..9f06dc5 100644 --- a/src/lightning/fabric/utilities/device_dtype_mixin.py +++ b/src/lightning/fabric/utilities/device_dtype_mixin.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union import torch from torch.nn import Module @@ -20,7 +20,7 @@ class _DeviceDtypeModuleMixin(Module): - __jit_unused_properties__: list[str] = ["device", "dtype"] + __jit_unused_properties__: List[str] = ["device", "dtype"] def __init__(self) -> None: super().__init__() diff --git a/src/lightning/fabric/utilities/device_parser.py b/src/lightning/fabric/utilities/device_parser.py index ff5bebd..16965d9 100644 --- a/src/lightning/fabric/utilities/device_parser.py +++ b/src/lightning/fabric/utilities/device_parser.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import MutableSequence -from typing import Optional, Union +from typing import List, MutableSequence, Optional, Tuple, Union import torch @@ -20,7 +19,7 @@ from lightning.fabric.utilities.types import _DEVICE -def _determine_root_gpu_device(gpus: list[_DEVICE]) -> Optional[_DEVICE]: +def _determine_root_gpu_device(gpus: List[_DEVICE]) -> Optional[_DEVICE]: """ Args: gpus: Non-empty list of ints representing which GPUs to use @@ -47,10 +46,10 @@ def _determine_root_gpu_device(gpus: list[_DEVICE]) -> Optional[_DEVICE]: def _parse_gpu_ids( - gpus: Optional[Union[int, str, list[int]]], + gpus: Optional[Union[int, str, List[int]]], include_cuda: bool = False, include_mps: bool = False, -) -> Optional[list[int]]: +) -> Optional[List[int]]: """Parses the GPU IDs given in the format as accepted by the :class:`~lightning.pytorch.trainer.trainer.Trainer`. Args: @@ -103,7 +102,7 @@ def _parse_gpu_ids( return _sanitize_gpu_ids(gpus, include_cuda=include_cuda, include_mps=include_mps) -def _normalize_parse_gpu_string_input(s: Union[int, str, list[int]]) -> Union[int, list[int]]: +def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[int, List[int]]: if not isinstance(s, str): return s if s == "-1": @@ -113,7 +112,7 @@ def _normalize_parse_gpu_string_input(s: Union[int, str, list[int]]) -> Union[in return int(s.strip()) -def _sanitize_gpu_ids(gpus: list[int], include_cuda: bool = False, include_mps: bool = False) -> list[int]: +def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: bool = False) -> List[int]: """Checks that each of the GPUs in the list is actually available. Raises a MisconfigurationException if any of the GPUs is not available. @@ -140,8 +139,8 @@ def _sanitize_gpu_ids(gpus: list[int], include_cuda: bool = False, include_mps: def _normalize_parse_gpu_input_to_list( - gpus: Union[int, list[int], tuple[int, ...]], include_cuda: bool, include_mps: bool -) -> Optional[list[int]]: + gpus: Union[int, List[int], Tuple[int, ...]], include_cuda: bool, include_mps: bool +) -> Optional[List[int]]: assert gpus is not None if isinstance(gpus, (MutableSequence, tuple)): return list(gpus) @@ -155,7 +154,7 @@ def _normalize_parse_gpu_input_to_list( return list(range(gpus)) -def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = False) -> list[int]: +def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = False) -> List[int]: """ Returns: A list of all available GPUs @@ -168,7 +167,7 @@ def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = Fals return cuda_gpus + mps_gpus -def _check_unique(device_ids: list[int]) -> None: +def _check_unique(device_ids: List[int]) -> None: """Checks that the device_ids are unique. Args: diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py index ec4eb26..0e6c52d 100644 --- a/src/lightning/fabric/utilities/distributed.py +++ b/src/lightning/fabric/utilities/distributed.py @@ -4,11 +4,10 @@ import os import signal import time -from collections.abc import Iterable, Iterator, Sized from contextlib import nullcontext from datetime import timedelta from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sized, Union import torch import torch.nn.functional as F @@ -100,7 +99,7 @@ def is_shared_filesystem(strategy: "Strategy", path: Optional[_PATH] = None, tim return all_found -def _gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> list[Tensor]: +def _gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tensor]: """Function to gather all tensors from several DDP processes onto a list that is broadcasted to all processes. Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case @@ -154,7 +153,7 @@ def _gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> list[Ten return gathered_result -def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> list[Tensor]: +def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]: gathered_result = [torch.zeros_like(result) for _ in range(world_size)] torch.distributed.all_gather(gathered_result, result, group) return gathered_result @@ -346,7 +345,7 @@ def __init__(self, sampler: Union[Sampler, Iterable]) -> None: ) self._sampler = sampler # defer materializing an iterator until it is necessary - self._sampler_list: Optional[list[Any]] = None + self._sampler_list: Optional[List[Any]] = None @override def __getitem__(self, index: int) -> Any: diff --git a/src/lightning/fabric/utilities/init.py b/src/lightning/fabric/utilities/init.py index 2760c6b..c92dfd8 100644 --- a/src/lightning/fabric/utilities/init.py +++ b/src/lightning/fabric/utilities/init.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import itertools -from collections.abc import Sequence -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Dict, Optional, Sequence, Union import torch from torch.nn import Module, Parameter @@ -47,7 +46,7 @@ def __torch_function__( func: Callable, types: Sequence, args: Sequence[Any] = (), - kwargs: Optional[dict] = None, + kwargs: Optional[Dict] = None, ) -> Any: kwargs = kwargs or {} if not self.enabled: diff --git a/src/lightning/fabric/utilities/load.py b/src/lightning/fabric/utilities/load.py index 9e158c1..a1c3b69 100644 --- a/src/lightning/fabric/utilities/load.py +++ b/src/lightning/fabric/utilities/load.py @@ -13,12 +13,10 @@ import os import pickle import warnings -from collections import OrderedDict -from collections.abc import Sequence from functools import partial from io import BytesIO from pathlib import Path -from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Union +from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, OrderedDict, Sequence, Set, Union import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -136,7 +134,7 @@ def __torch_function__( func: Callable, types: Sequence, args: Sequence[Any] = (), - kwargs: Optional[dict] = None, + kwargs: Optional[Dict] = None, ) -> Any: kwargs = kwargs or {} loaded_args = [(arg._load_tensor() if isinstance(arg, _NotYetLoadedTensor) else arg) for arg in args] @@ -221,7 +219,7 @@ def _load_tensor(t: _NotYetLoadedTensor) -> Tensor: def _move_state_into( - source: dict[str, Any], destination: dict[str, Union[Any, _Stateful]], keys: Optional[set[str]] = None + source: Dict[str, Any], destination: Dict[str, Union[Any, _Stateful]], keys: Optional[Set[str]] = None ) -> None: """Takes the state from the source destination and moves it into the destination dictionary. @@ -237,7 +235,7 @@ def _move_state_into( destination[key] = state -def _load_distributed_checkpoint(checkpoint_folder: Path) -> dict[str, Any]: +def _load_distributed_checkpoint(checkpoint_folder: Path) -> Dict[str, Any]: """Loads a sharded checkpoint saved with the `torch.distributed.checkpoint` into a full state dict. The current implementation assumes that the entire checkpoint fits in CPU memory. @@ -250,7 +248,7 @@ def _load_distributed_checkpoint(checkpoint_folder: Path) -> dict[str, Any]: from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner from torch.distributed.checkpoint.state_dict_loader import _load_state_dict - checkpoint: dict[str, Any] = {} + checkpoint: Dict[str, Any] = {} _load_state_dict( checkpoint, storage_reader=FileSystemReader(checkpoint_folder), diff --git a/src/lightning/fabric/utilities/logger.py b/src/lightning/fabric/utilities/logger.py index dd2b0a3..07b76ad 100644 --- a/src/lightning/fabric/utilities/logger.py +++ b/src/lightning/fabric/utilities/logger.py @@ -15,16 +15,15 @@ import inspect import json from argparse import Namespace -from collections.abc import Mapping, MutableMapping from dataclasses import asdict, is_dataclass -from typing import Any, Optional, Union +from typing import Any, Dict, Mapping, MutableMapping, Optional, Union from torch import Tensor from lightning.fabric.utilities.imports import _NUMPY_AVAILABLE -def _convert_params(params: Optional[Union[dict[str, Any], Namespace]]) -> dict[str, Any]: +def _convert_params(params: Optional[Union[Dict[str, Any], Namespace]]) -> Dict[str, Any]: """Ensure parameters are a dict or convert to dict if necessary. Args: @@ -44,7 +43,7 @@ def _convert_params(params: Optional[Union[dict[str, Any], Namespace]]) -> dict[ return params -def _sanitize_callable_params(params: dict[str, Any]) -> dict[str, Any]: +def _sanitize_callable_params(params: Dict[str, Any]) -> Dict[str, Any]: """Sanitize callable params dict, e.g. ``{'a': } -> {'a': 'function_****'}``. Args: @@ -74,7 +73,7 @@ def _sanitize_callable(val: Any) -> Any: return {key: _sanitize_callable(val) for key, val in params.items()} -def _flatten_dict(params: MutableMapping[Any, Any], delimiter: str = "/", parent_key: str = "") -> dict[str, Any]: +def _flatten_dict(params: MutableMapping[Any, Any], delimiter: str = "/", parent_key: str = "") -> Dict[str, Any]: """Flatten hierarchical dict, e.g. ``{'a': {'b': 'c'}} -> {'a/b': 'c'}``. Args: @@ -93,7 +92,7 @@ def _flatten_dict(params: MutableMapping[Any, Any], delimiter: str = "/", parent {'5/a': 123} """ - result: dict[str, Any] = {} + result: Dict[str, Any] = {} for k, v in params.items(): new_key = parent_key + delimiter + str(k) if parent_key else str(k) if is_dataclass(v) and not isinstance(v, type): @@ -108,7 +107,7 @@ def _flatten_dict(params: MutableMapping[Any, Any], delimiter: str = "/", parent return result -def _sanitize_params(params: dict[str, Any]) -> dict[str, Any]: +def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]: """Returns params with non-primitvies converted to strings for logging. >>> import torch @@ -141,7 +140,7 @@ def _sanitize_params(params: dict[str, Any]) -> dict[str, Any]: return params -def _convert_json_serializable(params: dict[str, Any]) -> dict[str, Any]: +def _convert_json_serializable(params: Dict[str, Any]) -> Dict[str, Any]: """Convert non-serializable objects in params to string.""" return {k: str(v) if not _is_json_serializable(v) else v for k, v in params.items()} diff --git a/src/lightning/fabric/utilities/optimizer.py b/src/lightning/fabric/utilities/optimizer.py index df83f9b..2c57ec9 100644 --- a/src/lightning/fabric/utilities/optimizer.py +++ b/src/lightning/fabric/utilities/optimizer.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Iterable, MutableMapping +from collections.abc import MutableMapping +from typing import Iterable from torch import Tensor from torch.optim import Optimizer diff --git a/src/lightning/fabric/utilities/registry.py b/src/lightning/fabric/utilities/registry.py index 7d8f6ca..9ad0f90 100644 --- a/src/lightning/fabric/utilities/registry.py +++ b/src/lightning/fabric/utilities/registry.py @@ -15,7 +15,7 @@ from importlib.metadata import entry_points from inspect import getmembers, isclass from types import ModuleType -from typing import Any, Union +from typing import Any, List, Type, Union from lightning_utilities import is_overridden @@ -24,7 +24,7 @@ _log = logging.getLogger(__name__) -def _load_external_callbacks(group: str) -> list[Any]: +def _load_external_callbacks(group: str) -> List[Any]: """Collect external callbacks registered through entry points. The entry points are expected to be functions returning a list of callbacks. @@ -40,10 +40,10 @@ def _load_external_callbacks(group: str) -> list[Any]: entry_points(group=group) if _PYTHON_GREATER_EQUAL_3_10_0 else entry_points().get(group, {}) # type: ignore[arg-type] ) - external_callbacks: list[Any] = [] + external_callbacks: List[Any] = [] for factory in factories: callback_factory = factory.load() - callbacks_list: Union[list[Any], Any] = callback_factory() + callbacks_list: Union[List[Any], Any] = callback_factory() callbacks_list = [callbacks_list] if not isinstance(callbacks_list, list) else callbacks_list if callbacks_list: _log.info( @@ -54,7 +54,7 @@ def _load_external_callbacks(group: str) -> list[Any]: return external_callbacks -def _register_classes(registry: Any, method: str, module: ModuleType, parent: type[object]) -> None: +def _register_classes(registry: Any, method: str, module: ModuleType, parent: Type[object]) -> None: for _, member in getmembers(module, isclass): if issubclass(member, parent) and is_overridden(method, member, parent): register_fn = getattr(member, method) diff --git a/src/lightning/fabric/utilities/seed.py b/src/lightning/fabric/utilities/seed.py index f9c0dde..a2d6278 100644 --- a/src/lightning/fabric/utilities/seed.py +++ b/src/lightning/fabric/utilities/seed.py @@ -3,7 +3,7 @@ import random from random import getstate as python_get_rng_state from random import setstate as python_set_rng_state -from typing import Any, Optional +from typing import Any, Dict, List, Optional import torch @@ -104,13 +104,10 @@ def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: if _NUMPY_AVAILABLE: import numpy as np - ss = np.random.SeedSequence([base_seed, worker_id, global_rank]) - np_rng_seed = ss.generate_state(4) + np.random.seed(seed_sequence[3] & 0xFFFFFFFF) # numpy takes 32-bit seed only - np.random.seed(np_rng_seed) - -def _generate_seed_sequence(base_seed: int, worker_id: int, global_rank: int, count: int) -> list[int]: +def _generate_seed_sequence(base_seed: int, worker_id: int, global_rank: int, count: int) -> List[int]: """Generates a sequence of seeds from a base seed, worker id and rank using the linear congruential generator (LCG) algorithm.""" # Combine base seed, worker id and rank into a unique 64-bit number @@ -123,7 +120,7 @@ def _generate_seed_sequence(base_seed: int, worker_id: int, global_rank: int, co return seeds -def _collect_rng_states(include_cuda: bool = True) -> dict[str, Any]: +def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]: r"""Collect the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python.""" states = { "torch": torch.get_rng_state(), @@ -138,7 +135,7 @@ def _collect_rng_states(include_cuda: bool = True) -> dict[str, Any]: return states -def _set_rng_states(rng_state_dict: dict[str, Any]) -> None: +def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None: r"""Set the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python in the current process.""" torch.set_rng_state(rng_state_dict["torch"]) diff --git a/src/lightning/fabric/utilities/spike.py b/src/lightning/fabric/utilities/spike.py index 04c5544..5dca599 100644 --- a/src/lightning/fabric/utilities/spike.py +++ b/src/lightning/fabric/utilities/spike.py @@ -2,7 +2,7 @@ import operator import os import warnings -from typing import TYPE_CHECKING, Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union import torch from lightning_utilities.core.imports import compare_version @@ -66,7 +66,7 @@ def __init__( self.warmup = warmup self.atol = atol self.rtol = rtol - self.bad_batches: list[int] = [] + self.bad_batches: List[int] = [] self.exclude_batches_path = exclude_batches_path self.finite_only = finite_only @@ -147,7 +147,7 @@ def _update_stats(self, val: torch.Tensor) -> None: self.running_mean.update(val) self.last_val = val - def state_dict(self) -> dict[str, Any]: + def state_dict(self) -> Dict[str, Any]: return { "last_val": self.last_val.item() if isinstance(self.last_val, torch.Tensor) else self.last_val, "mode": self.mode, @@ -160,7 +160,7 @@ def state_dict(self) -> dict[str, Any]: "mean": self.running_mean.base_metric.state_dict(), } - def load_state_dict(self, state_dict: dict[str, Any]) -> None: + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self.last_val = state_dict.pop("last_val") self.mode = state_dict.pop("mode") self.warmup = state_dict.pop("warmup") diff --git a/src/lightning/fabric/utilities/testing/_runif.py b/src/lightning/fabric/utilities/testing/_runif.py index 6f5d933..6f05134 100644 --- a/src/lightning/fabric/utilities/testing/_runif.py +++ b/src/lightning/fabric/utilities/testing/_runif.py @@ -14,7 +14,7 @@ import operator import os import sys -from typing import Optional +from typing import Dict, List, Optional, Tuple import torch from lightning_utilities.core.imports import RequirementCache, compare_version @@ -40,7 +40,7 @@ def _runif_reasons( standalone: bool = False, deepspeed: bool = False, dynamo: bool = False, -) -> tuple[list[str], dict[str, bool]]: +) -> Tuple[List[str], Dict[str, bool]]: """Construct reasons for pytest skipif. Args: diff --git a/src/lightning/fabric/utilities/throughput.py b/src/lightning/fabric/utilities/throughput.py index 72b33a4..598a322 100644 --- a/src/lightning/fabric/utilities/throughput.py +++ b/src/lightning/fabric/utilities/throughput.py @@ -13,7 +13,7 @@ # limitations under the License. # Adapted from https://github.com/mosaicml/composer/blob/f2a2dc820/composer/callbacks/speed_monitor.py from collections import deque -from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, TypeVar, Union import torch from typing_extensions import override @@ -24,7 +24,7 @@ from lightning.fabric import Fabric from lightning.fabric.plugins import Precision -_THROUGHPUT_METRICS = dict[str, Union[int, float]] +_THROUGHPUT_METRICS = Dict[str, Union[int, float]] # The API design of this class follows `torchmetrics.Metric` but it doesn't need to be an actual Metric because there's @@ -108,7 +108,7 @@ def __init__( self._batches: _MonotonicWindow[int] = _MonotonicWindow(maxlen=window_size) self._samples: _MonotonicWindow[int] = _MonotonicWindow(maxlen=window_size) self._lengths: _MonotonicWindow[int] = _MonotonicWindow(maxlen=window_size) - self._flops: deque[int] = deque(maxlen=window_size) + self._flops: Deque[int] = deque(maxlen=window_size) def update( self, @@ -302,7 +302,7 @@ def measure_flops( return flop_counter.get_total_flops() -_CUDA_FLOPS: dict[str, dict[Union[str, torch.dtype], float]] = { +_CUDA_FLOPS: Dict[str, Dict[Union[str, torch.dtype], float]] = { # Hopper # source: https://resources.nvidia.com/en-us-tensor-core "h100 nvl": { @@ -648,7 +648,7 @@ def _plugin_to_compute_dtype(plugin: "Precision") -> torch.dtype: T = TypeVar("T", bound=float) -class _MonotonicWindow(list[T]): +class _MonotonicWindow(List[T]): """Custom fixed size list that only supports right-append and ensures that all values increase monotonically.""" def __init__(self, maxlen: int) -> None: diff --git a/src/lightning/fabric/utilities/types.py b/src/lightning/fabric/utilities/types.py index 1d7235f..2e18dc8 100644 --- a/src/lightning/fabric/utilities/types.py +++ b/src/lightning/fabric/utilities/types.py @@ -11,12 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import defaultdict -from collections.abc import Iterator from pathlib import Path from typing import ( Any, Callable, + DefaultDict, + Dict, + Iterator, + List, Optional, Protocol, TypeVar, @@ -36,7 +38,7 @@ _PATH = Union[str, Path] _DEVICE = Union[torch.device, str, int] _MAP_LOCATION_TYPE = Optional[ - Union[_DEVICE, Callable[[UntypedStorage, str], Optional[UntypedStorage]], dict[_DEVICE, _DEVICE]] + Union[_DEVICE, Callable[[UntypedStorage, str], Optional[UntypedStorage]], Dict[_DEVICE, _DEVICE]] ] _PARAMETERS = Iterator[torch.nn.Parameter] @@ -55,9 +57,9 @@ class _Stateful(Protocol[_DictKey]): """This class is used to detect if an object is stateful using `isinstance(obj, _Stateful)`.""" - def state_dict(self) -> dict[_DictKey, Any]: ... + def state_dict(self) -> Dict[_DictKey, Any]: ... - def load_state_dict(self, state_dict: dict[_DictKey, Any]) -> None: ... + def load_state_dict(self, state_dict: Dict[_DictKey, Any]) -> None: ... @runtime_checkable @@ -84,10 +86,10 @@ def step(self, closure: Optional[Callable[[], float]] = ...) -> Optional[float]: class Optimizable(Steppable, Protocol): """To structurally type ``optimizer``""" - param_groups: list[dict[Any, Any]] - defaults: dict[Any, Any] - state: defaultdict[Tensor, Any] + param_groups: List[Dict[Any, Any]] + defaults: Dict[Any, Any] + state: DefaultDict[Tensor, Any] - def state_dict(self) -> dict[str, dict[Any, Any]]: ... + def state_dict(self) -> Dict[str, Dict[Any, Any]]: ... - def load_state_dict(self, state_dict: dict[str, dict[Any, Any]]) -> None: ... + def load_state_dict(self, state_dict: Dict[str, Dict[Any, Any]]) -> None: ... diff --git a/src/lightning/fabric/utilities/warnings.py b/src/lightning/fabric/utilities/warnings.py index b62bece..62e5f5f 100644 --- a/src/lightning/fabric/utilities/warnings.py +++ b/src/lightning/fabric/utilities/warnings.py @@ -15,7 +15,7 @@ import warnings from pathlib import Path -from typing import Optional, Union +from typing import Optional, Type, Union from lightning.fabric.utilities.rank_zero import LightningDeprecationWarning @@ -38,7 +38,7 @@ def disable_possible_user_warnings(module: str = "") -> None: def _custom_format_warning( - message: Union[Warning, str], category: type[Warning], filename: str, lineno: int, line: Optional[str] = None + message: Union[Warning, str], category: Type[Warning], filename: str, lineno: int, line: Optional[str] = None ) -> str: """Custom formatting that avoids an extra line in case warnings are emitted from the `rank_zero`-functions.""" if _is_path_in_lightning(Path(filename)): diff --git a/src/lightning/fabric/wrappers.py b/src/lightning/fabric/wrappers.py index b593c9f..c57f197 100644 --- a/src/lightning/fabric/wrappers.py +++ b/src/lightning/fabric/wrappers.py @@ -12,14 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect -from collections.abc import Generator, Iterator, Mapping from copy import deepcopy from functools import partial, wraps from types import MethodType from typing import ( Any, Callable, + Dict, + Generator, + Iterator, + List, + Mapping, Optional, + Tuple, TypeVar, Union, overload, @@ -43,14 +48,14 @@ from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin from lightning.fabric.utilities.types import Optimizable -T_destination = TypeVar("T_destination", bound=dict[str, Any]) +T_destination = TypeVar("T_destination", bound=Dict[str, Any]) _LIGHTNING_MODULE_STEP_METHODS = ("training_step", "validation_step", "test_step", "predict_step") _in_fabric_backward: bool = False class _FabricOptimizer: - def __init__(self, optimizer: Optimizer, strategy: Strategy, callbacks: Optional[list[Callable]] = None) -> None: + def __init__(self, optimizer: Optimizer, strategy: Strategy, callbacks: Optional[List[Callable]] = None) -> None: """FabricOptimizer is a thin wrapper around the :class:`~torch.optim.Optimizer` that delegates the optimizer step calls to the strategy. @@ -71,10 +76,10 @@ def __init__(self, optimizer: Optimizer, strategy: Strategy, callbacks: Optional def optimizer(self) -> Optimizer: return self._optimizer - def state_dict(self) -> dict[str, Tensor]: + def state_dict(self) -> Dict[str, Tensor]: return self._strategy.get_optimizer_state(self.optimizer) - def load_state_dict(self, state_dict: dict[str, Tensor]) -> None: + def load_state_dict(self, state_dict: Dict[str, Tensor]) -> None: self.optimizer.load_state_dict(state_dict) def step(self, closure: Optional[Callable] = None) -> Any: @@ -144,12 +149,12 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination: ... @overload - def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> dict[str, Any]: ... + def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]: ... @override def state_dict( self, destination: Optional[T_destination] = None, prefix: str = "", keep_vars: bool = False - ) -> Optional[dict[str, Any]]: + ) -> Optional[Dict[str, Any]]: return self._original_module.state_dict( destination=destination, # type: ignore[type-var] prefix=prefix, @@ -345,7 +350,7 @@ def _unwrap( return apply_to_collection(collection, dtype=tuple(types), function=_unwrap) -def _unwrap_compiled(obj: Union[Any, OptimizedModule]) -> tuple[Union[Any, nn.Module], Optional[dict[str, Any]]]: +def _unwrap_compiled(obj: Union[Any, OptimizedModule]) -> Tuple[Union[Any, nn.Module], Optional[Dict[str, Any]]]: """Removes the :class:`torch._dynamo.OptimizedModule` around the object if it is wrapped. Use this function before instance checks against e.g. :class:`_FabricModule`. @@ -361,7 +366,7 @@ def _unwrap_compiled(obj: Union[Any, OptimizedModule]) -> tuple[Union[Any, nn.Mo return obj, None -def _to_compiled(module: nn.Module, compile_kwargs: dict[str, Any]) -> OptimizedModule: +def _to_compiled(module: nn.Module, compile_kwargs: Dict[str, Any]) -> OptimizedModule: return torch.compile(module, **compile_kwargs) # type: ignore[return-value] diff --git a/src/lightning/pytorch/accelerators/accelerator.py b/src/lightning/pytorch/accelerators/accelerator.py index 9238071..0490c2d 100644 --- a/src/lightning/pytorch/accelerators/accelerator.py +++ b/src/lightning/pytorch/accelerators/accelerator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC -from typing import Any +from typing import Any, Dict import lightning.pytorch as pl from lightning.fabric.accelerators.accelerator import Accelerator as _Accelerator @@ -34,7 +34,7 @@ def setup(self, trainer: "pl.Trainer") -> None: """ - def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: + def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: """Get stats for a given device. Args: diff --git a/src/lightning/pytorch/accelerators/cpu.py b/src/lightning/pytorch/accelerators/cpu.py index 525071c..a85a959 100644 --- a/src/lightning/pytorch/accelerators/cpu.py +++ b/src/lightning/pytorch/accelerators/cpu.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Union +from typing import Any, Dict, List, Union import torch from lightning_utilities.core.imports import RequirementCache @@ -38,7 +38,7 @@ def setup_device(self, device: torch.device) -> None: raise MisconfigurationException(f"Device should be CPU, got {device} instead.") @override - def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: + def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: """Get CPU stats from ``psutil`` package.""" return get_cpu_stats() @@ -54,7 +54,7 @@ def parse_devices(devices: Union[int, str]) -> int: @staticmethod @override - def get_parallel_devices(devices: Union[int, str]) -> list[torch.device]: + def get_parallel_devices(devices: Union[int, str]) -> List[torch.device]: """Gets parallel devices for the Accelerator.""" devices = _parse_cpu_cores(devices) return [torch.device("cpu")] * devices @@ -89,7 +89,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No _PSUTIL_AVAILABLE = RequirementCache("psutil") -def get_cpu_stats() -> dict[str, float]: +def get_cpu_stats() -> Dict[str, float]: if not _PSUTIL_AVAILABLE: raise ModuleNotFoundError( f"Fetching CPU device stats requires `psutil` to be installed. {str(_PSUTIL_AVAILABLE)}" diff --git a/src/lightning/pytorch/accelerators/cuda.py b/src/lightning/pytorch/accelerators/cuda.py index a00b12a..6df3bc6 100644 --- a/src/lightning/pytorch/accelerators/cuda.py +++ b/src/lightning/pytorch/accelerators/cuda.py @@ -15,7 +15,7 @@ import os import shutil import subprocess -from typing import Any, Optional, Union +from typing import Any, Dict, List, Optional, Union import torch from typing_extensions import override @@ -61,7 +61,7 @@ def set_nvidia_flags(local_rank: int) -> None: _log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]") @override - def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: + def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: """Gets stats for the given GPU device. Args: @@ -83,13 +83,13 @@ def teardown(self) -> None: @staticmethod @override - def parse_devices(devices: Union[int, str, list[int]]) -> Optional[list[int]]: + def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]: """Accelerator device parsing logic.""" return _parse_gpu_ids(devices, include_cuda=True) @staticmethod @override - def get_parallel_devices(devices: list[int]) -> list[torch.device]: + def get_parallel_devices(devices: List[int]) -> List[torch.device]: """Gets parallel devices for the Accelerator.""" return [torch.device("cuda", i) for i in devices] @@ -114,7 +114,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No ) -def get_nvidia_gpu_stats(device: _DEVICE) -> dict[str, float]: # pragma: no-cover +def get_nvidia_gpu_stats(device: _DEVICE) -> Dict[str, float]: # pragma: no-cover """Get GPU stats including memory, fan speed, and temperature from nvidia-smi. Args: diff --git a/src/lightning/pytorch/accelerators/mps.py b/src/lightning/pytorch/accelerators/mps.py index f767498..6efe629 100644 --- a/src/lightning/pytorch/accelerators/mps.py +++ b/src/lightning/pytorch/accelerators/mps.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Union +from typing import Any, Dict, List, Optional, Union import torch from typing_extensions import override @@ -43,7 +43,7 @@ def setup_device(self, device: torch.device) -> None: raise MisconfigurationException(f"Device should be MPS, got {device} instead.") @override - def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: + def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: """Get M1 (cpu + gpu) stats from ``psutil`` package.""" return get_device_stats() @@ -53,13 +53,13 @@ def teardown(self) -> None: @staticmethod @override - def parse_devices(devices: Union[int, str, list[int]]) -> Optional[list[int]]: + def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]: """Accelerator device parsing logic.""" return _parse_gpu_ids(devices, include_mps=True) @staticmethod @override - def get_parallel_devices(devices: Union[int, str, list[int]]) -> list[torch.device]: + def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]: """Gets parallel devices for the Accelerator.""" parsed_devices = MPSAccelerator.parse_devices(devices) assert parsed_devices is not None @@ -94,7 +94,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No _SWAP_PERCENT = "M1_swap_percent" -def get_device_stats() -> dict[str, float]: +def get_device_stats() -> Dict[str, float]: if not _PSUTIL_AVAILABLE: raise ModuleNotFoundError( f"Fetching MPS device stats requires `psutil` to be installed. {str(_PSUTIL_AVAILABLE)}" diff --git a/src/lightning/pytorch/accelerators/xla.py b/src/lightning/pytorch/accelerators/xla.py index 10726b5..01ef722 100644 --- a/src/lightning/pytorch/accelerators/xla.py +++ b/src/lightning/pytorch/accelerators/xla.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Dict from typing_extensions import override @@ -29,7 +29,7 @@ class XLAAccelerator(Accelerator, FabricXLAAccelerator): """ @override - def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: + def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: """Gets stats for the given XLA device. Args: diff --git a/src/lightning/pytorch/callbacks/callback.py b/src/lightning/pytorch/callbacks/callback.py index 3bfb609..9311d49 100644 --- a/src/lightning/pytorch/callbacks/callback.py +++ b/src/lightning/pytorch/callbacks/callback.py @@ -13,7 +13,7 @@ # limitations under the License. r"""Base class used to build new callbacks.""" -from typing import Any +from typing import Any, Dict, Type from torch import Tensor from torch.optim import Optimizer @@ -41,7 +41,7 @@ def state_key(self) -> str: return self.__class__.__qualname__ @property - def _legacy_state_key(self) -> type["Callback"]: + def _legacy_state_key(self) -> Type["Callback"]: """State key for checkpoints saved prior to version 1.5.0.""" return type(self) @@ -229,7 +229,7 @@ def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None: """Called when any trainer execution is interrupted by an exception.""" - def state_dict(self) -> dict[str, Any]: + def state_dict(self) -> Dict[str, Any]: """Called when saving a checkpoint, implement to generate callback's ``state_dict``. Returns: @@ -238,7 +238,7 @@ def state_dict(self) -> dict[str, Any]: """ return {} - def load_state_dict(self, state_dict: dict[str, Any]) -> None: + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: """Called when loading a checkpoint, implement to reload callback state given callback's ``state_dict``. Args: @@ -248,7 +248,7 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: pass def on_save_checkpoint( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: dict[str, Any] + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] ) -> None: r"""Called when saving a checkpoint to give you a chance to store anything else you might want to save. @@ -260,7 +260,7 @@ def on_save_checkpoint( """ def on_load_checkpoint( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: dict[str, Any] + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] ) -> None: r"""Called when loading a model checkpoint, use to reload state. diff --git a/src/lightning/pytorch/callbacks/device_stats_monitor.py b/src/lightning/pytorch/callbacks/device_stats_monitor.py index 6279dd1..64ea47d 100644 --- a/src/lightning/pytorch/callbacks/device_stats_monitor.py +++ b/src/lightning/pytorch/callbacks/device_stats_monitor.py @@ -19,7 +19,7 @@ """ -from typing import Any, Optional +from typing import Any, Dict, Optional from typing_extensions import override @@ -158,5 +158,5 @@ def on_test_batch_end( self._get_and_log_device_stats(trainer, "on_test_batch_end") -def _prefix_metric_keys(metrics_dict: dict[str, float], prefix: str, separator: str) -> dict[str, float]: +def _prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]: return {prefix + separator + k: v for k, v in metrics_dict.items()} diff --git a/src/lightning/pytorch/callbacks/early_stopping.py b/src/lightning/pytorch/callbacks/early_stopping.py index 78c4215..d1212fe 100644 --- a/src/lightning/pytorch/callbacks/early_stopping.py +++ b/src/lightning/pytorch/callbacks/early_stopping.py @@ -20,7 +20,7 @@ """ import logging -from typing import Any, Callable, Optional +from typing import Any, Callable, Dict, Optional, Tuple import torch from torch import Tensor @@ -139,7 +139,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s # validation, then we run after validation instead of on train epoch end self._check_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1 - def _validate_condition_metric(self, logs: dict[str, Tensor]) -> bool: + def _validate_condition_metric(self, logs: Dict[str, Tensor]) -> bool: monitor_val = logs.get(self.monitor) error_msg = ( @@ -163,7 +163,7 @@ def monitor_op(self) -> Callable: return self.mode_dict[self.mode] @override - def state_dict(self) -> dict[str, Any]: + def state_dict(self) -> Dict[str, Any]: return { "wait_count": self.wait_count, "stopped_epoch": self.stopped_epoch, @@ -172,7 +172,7 @@ def state_dict(self) -> dict[str, Any]: } @override - def load_state_dict(self, state_dict: dict[str, Any]) -> None: + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self.wait_count = state_dict["wait_count"] self.stopped_epoch = state_dict["stopped_epoch"] self.best_score = state_dict["best_score"] @@ -215,7 +215,7 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None: if reason and self.verbose: self._log_info(trainer, reason, self.log_rank_zero_only) - def _evaluate_stopping_criteria(self, current: Tensor) -> tuple[bool, Optional[str]]: + def _evaluate_stopping_criteria(self, current: Tensor) -> Tuple[bool, Optional[str]]: should_stop = False reason = None if self.check_finite and not torch.isfinite(current): diff --git a/src/lightning/pytorch/callbacks/finetuning.py b/src/lightning/pytorch/callbacks/finetuning.py index 356ab22..46a9098 100644 --- a/src/lightning/pytorch/callbacks/finetuning.py +++ b/src/lightning/pytorch/callbacks/finetuning.py @@ -19,8 +19,7 @@ """ import logging -from collections.abc import Generator, Iterable -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union import torch from torch.nn import Module, ModuleDict @@ -86,17 +85,17 @@ class BaseFinetuning(Callback): """ def __init__(self) -> None: - self._internal_optimizer_metadata: dict[int, list[dict[str, Any]]] = {} + self._internal_optimizer_metadata: Dict[int, List[Dict[str, Any]]] = {} self._restarting = False @override - def state_dict(self) -> dict[str, Any]: + def state_dict(self) -> Dict[str, Any]: return { "internal_optimizer_metadata": self._internal_optimizer_metadata, } @override - def load_state_dict(self, state_dict: dict[str, Any]) -> None: + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self._restarting = True if "internal_optimizer_metadata" in state_dict: # noqa: SIM401 self._internal_optimizer_metadata = state_dict["internal_optimizer_metadata"] @@ -117,7 +116,7 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - self._restarting = False @staticmethod - def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> list[Module]: + def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]: """This function is used to flatten a module or an iterable of modules into a list of its leaf modules (modules with no children) and parent modules that have parameters directly themselves. @@ -216,7 +215,7 @@ def freeze(modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn: BaseFinetuning.freeze_module(mod) @staticmethod - def filter_on_optimizer(optimizer: Optimizer, params: Iterable) -> list: + def filter_on_optimizer(optimizer: Optimizer, params: Iterable) -> List: """This function is used to exclude any parameter which already exists in this optimizer. Args: @@ -286,7 +285,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s ) @staticmethod - def _apply_mapping_to_param_groups(param_groups: list[dict[str, Any]], mapping: dict) -> list[dict[str, Any]]: + def _apply_mapping_to_param_groups(param_groups: List[Dict[str, Any]], mapping: dict) -> List[Dict[str, Any]]: output = [] for g in param_groups: # skip params to save memory @@ -300,7 +299,7 @@ def _store( pl_module: "pl.LightningModule", opt_idx: int, num_param_groups: int, - current_param_groups: list[dict[str, Any]], + current_param_groups: List[Dict[str, Any]], ) -> None: mapping = {p: n for n, p in pl_module.named_parameters()} if opt_idx not in self._internal_optimizer_metadata: @@ -388,14 +387,14 @@ def __init__( self.previous_backbone_lr: Optional[float] = None @override - def state_dict(self) -> dict[str, Any]: + def state_dict(self) -> Dict[str, Any]: return { "internal_optimizer_metadata": self._internal_optimizer_metadata, "previous_backbone_lr": self.previous_backbone_lr, } @override - def load_state_dict(self, state_dict: dict[str, Any]) -> None: + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self.previous_backbone_lr = state_dict["previous_backbone_lr"] super().load_state_dict(state_dict) diff --git a/src/lightning/pytorch/callbacks/gradient_accumulation_scheduler.py b/src/lightning/pytorch/callbacks/gradient_accumulation_scheduler.py index 50ddc10..20b1df2 100644 --- a/src/lightning/pytorch/callbacks/gradient_accumulation_scheduler.py +++ b/src/lightning/pytorch/callbacks/gradient_accumulation_scheduler.py @@ -20,7 +20,7 @@ """ -from typing import Any +from typing import Any, Dict from typing_extensions import override @@ -64,7 +64,7 @@ class GradientAccumulationScheduler(Callback): """ - def __init__(self, scheduling: dict[int, int]): + def __init__(self, scheduling: Dict[int, int]): super().__init__() if not scheduling: # empty dict error diff --git a/src/lightning/pytorch/callbacks/lr_monitor.py b/src/lightning/pytorch/callbacks/lr_monitor.py index ca2b4a8..6a94c7e 100644 --- a/src/lightning/pytorch/callbacks/lr_monitor.py +++ b/src/lightning/pytorch/callbacks/lr_monitor.py @@ -22,7 +22,7 @@ import itertools from collections import defaultdict -from typing import Any, Literal, Optional +from typing import Any, DefaultDict, Dict, List, Literal, Optional, Set, Tuple, Type import torch from torch.optim.optimizer import Optimizer @@ -104,9 +104,9 @@ def __init__( self.log_momentum = log_momentum self.log_weight_decay = log_weight_decay - self.lrs: dict[str, list[float]] = {} - self.last_momentum_values: dict[str, Optional[list[float]]] = {} - self.last_weight_decay_values: dict[str, Optional[list[float]]] = {} + self.lrs: Dict[str, List[float]] = {} + self.last_momentum_values: Dict[str, Optional[List[float]]] = {} + self.last_weight_decay_values: Dict[str, Optional[List[float]]] = {} @override def on_train_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: @@ -141,7 +141,7 @@ def _check_no_key(key: str) -> bool: ) # Find names for schedulers - names: list[list[str]] = [] + names: List[List[str]] = [] ( sched_hparam_keys, optimizers_with_scheduler, @@ -186,7 +186,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) for logger in trainer.loggers: logger.log_metrics(latest_stat, step=trainer.fit_loop.epoch_loop._batches_that_stepped) - def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> dict[str, float]: + def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, float]: latest_stat = {} ( @@ -219,7 +219,7 @@ def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> dict[str, floa return latest_stat - def _get_optimizer_stats(self, optimizer: Optimizer, names: list[str]) -> dict[str, float]: + def _get_optimizer_stats(self, optimizer: Optimizer, names: List[str]) -> Dict[str, float]: stats = {} param_groups = optimizer.param_groups use_betas = "betas" in optimizer.defaults @@ -236,12 +236,12 @@ def _get_optimizer_stats(self, optimizer: Optimizer, names: list[str]) -> dict[s return stats - def _extract_lr(self, param_group: dict[str, Any], name: str) -> dict[str, Any]: + def _extract_lr(self, param_group: Dict[str, Any], name: str) -> Dict[str, Any]: lr = param_group["lr"] self.lrs[name].append(lr) return {name: lr} - def _remap_keys(self, names: list[list[str]], token: str = "/pg1") -> None: + def _remap_keys(self, names: List[List[str]], token: str = "/pg1") -> None: """This function is used the remap the keys if param groups for a given optimizer increased.""" for group_new_names in names: for new_name in group_new_names: @@ -251,7 +251,7 @@ def _remap_keys(self, names: list[list[str]], token: str = "/pg1") -> None: elif new_name not in self.lrs: self.lrs[new_name] = [] - def _extract_momentum(self, param_group: dict[str, list], name: str, use_betas: bool) -> dict[str, float]: + def _extract_momentum(self, param_group: Dict[str, List], name: str, use_betas: bool) -> Dict[str, float]: if not self.log_momentum: return {} @@ -259,7 +259,7 @@ def _extract_momentum(self, param_group: dict[str, list], name: str, use_betas: self.last_momentum_values[name] = momentum return {name: momentum} - def _extract_weight_decay(self, param_group: dict[str, Any], name: str) -> dict[str, Any]: + def _extract_weight_decay(self, param_group: Dict[str, Any], name: str) -> Dict[str, Any]: """Extracts the weight decay statistics from a parameter group.""" if not self.log_weight_decay: return {} @@ -269,14 +269,14 @@ def _extract_weight_decay(self, param_group: dict[str, Any], name: str) -> dict[ return {name: weight_decay} def _add_prefix( - self, name: str, optimizer_cls: type[Optimizer], seen_optimizer_types: defaultdict[type[Optimizer], int] + self, name: str, optimizer_cls: Type[Optimizer], seen_optimizer_types: DefaultDict[Type[Optimizer], int] ) -> str: if optimizer_cls not in seen_optimizer_types: return name count = seen_optimizer_types[optimizer_cls] return name + f"-{count - 1}" if count > 1 else name - def _add_suffix(self, name: str, param_groups: list[dict], param_group_index: int, use_names: bool = True) -> str: + def _add_suffix(self, name: str, param_groups: List[Dict], param_group_index: int, use_names: bool = True) -> str: if len(param_groups) > 1: if not use_names: return f"{name}/pg{param_group_index + 1}" @@ -287,7 +287,7 @@ def _add_suffix(self, name: str, param_groups: list[dict], param_group_index: in return f"{name}/{pg_name}" if pg_name else name return name - def _duplicate_param_group_names(self, param_groups: list[dict]) -> set[str]: + def _duplicate_param_group_names(self, param_groups: List[Dict]) -> Set[str]: names = [pg.get("name", f"pg{i}") for i, pg in enumerate(param_groups, start=1)] unique = set(names) if len(names) == len(unique): @@ -296,13 +296,13 @@ def _duplicate_param_group_names(self, param_groups: list[dict]) -> set[str]: def _find_names_from_schedulers( self, - lr_scheduler_configs: list[LRSchedulerConfig], - ) -> tuple[list[list[str]], list[Optimizer], defaultdict[type[Optimizer], int]]: + lr_scheduler_configs: List[LRSchedulerConfig], + ) -> Tuple[List[List[str]], List[Optimizer], DefaultDict[Type[Optimizer], int]]: # Create unique names in the case we have multiple of the same learning # rate scheduler + multiple parameter groups names = [] - seen_optimizers: list[Optimizer] = [] - seen_optimizer_types: defaultdict[type[Optimizer], int] = defaultdict(int) + seen_optimizers: List[Optimizer] = [] + seen_optimizer_types: DefaultDict[Type[Optimizer], int] = defaultdict(int) for config in lr_scheduler_configs: sch = config.scheduler name = config.name if config.name is not None else "lr-" + sch.optimizer.__class__.__name__ @@ -316,10 +316,10 @@ def _find_names_from_schedulers( def _find_names_from_optimizers( self, - optimizers: list[Any], - seen_optimizers: list[Optimizer], - seen_optimizer_types: defaultdict[type[Optimizer], int], - ) -> tuple[list[list[str]], list[Optimizer]]: + optimizers: List[Any], + seen_optimizers: List[Optimizer], + seen_optimizer_types: DefaultDict[Type[Optimizer], int], + ) -> Tuple[List[List[str]], List[Optimizer]]: names = [] optimizers_without_scheduler = [] @@ -342,10 +342,10 @@ def _check_duplicates_and_update_name( self, optimizer: Optimizer, name: str, - seen_optimizers: list[Optimizer], - seen_optimizer_types: defaultdict[type[Optimizer], int], + seen_optimizers: List[Optimizer], + seen_optimizer_types: DefaultDict[Type[Optimizer], int], lr_scheduler_config: Optional[LRSchedulerConfig], - ) -> list[str]: + ) -> List[str]: seen_optimizers.append(optimizer) optimizer_cls = type(optimizer) if lr_scheduler_config is None or lr_scheduler_config.name is None: diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 85bfb65..9587da0 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -27,7 +27,7 @@ from copy import deepcopy from datetime import timedelta from pathlib import Path -from typing import Any, Literal, Optional, Union +from typing import Any, Dict, Literal, Optional, Set, Union from weakref import proxy import torch @@ -241,7 +241,7 @@ def __init__( self._last_global_step_saved = 0 # no need to save when no steps were taken self._last_time_checked: Optional[float] = None self.current_score: Optional[Tensor] = None - self.best_k_models: dict[str, Tensor] = {} + self.best_k_models: Dict[str, Tensor] = {} self.kth_best_model_path = "" self.best_model_score: Optional[Tensor] = None self.best_model_path = "" @@ -335,7 +335,7 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul self._save_last_checkpoint(trainer, monitor_candidates) @override - def state_dict(self) -> dict[str, Any]: + def state_dict(self) -> Dict[str, Any]: return { "monitor": self.monitor, "best_model_score": self.best_model_score, @@ -349,7 +349,7 @@ def state_dict(self) -> dict[str, Any]: } @override - def load_state_dict(self, state_dict: dict[str, Any]) -> None: + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: dirpath_from_ckpt = state_dict.get("dirpath", self.dirpath) if self.dirpath == dirpath_from_ckpt: @@ -367,7 +367,7 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: self.best_model_path = state_dict["best_model_path"] - def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: dict[str, Tensor]) -> None: + def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None: if self.save_top_k == 0: return @@ -533,7 +533,7 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[Tensor] = def _format_checkpoint_name( self, filename: Optional[str], - metrics: dict[str, Tensor], + metrics: Dict[str, Tensor], prefix: str = "", auto_insert_metric_name: bool = True, ) -> str: @@ -567,7 +567,7 @@ def _format_checkpoint_name( return filename def format_checkpoint_name( - self, metrics: dict[str, Tensor], filename: Optional[str] = None, ver: Optional[int] = None + self, metrics: Dict[str, Tensor], filename: Optional[str] = None, ver: Optional[int] = None ) -> str: """Generate a filename according to the defined template. @@ -637,7 +637,7 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> _PATH: return ckpt_path - def _find_last_checkpoints(self, trainer: "pl.Trainer") -> set[str]: + def _find_last_checkpoints(self, trainer: "pl.Trainer") -> Set[str]: # find all checkpoints in the folder ckpt_path = self.__resolve_ckpt_dir(trainer) last_pattern = rf"^{self.CHECKPOINT_NAME_LAST}(-(\d+))?" @@ -654,7 +654,7 @@ def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None: rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.") def _get_metric_interpolated_filepath_name( - self, monitor_candidates: dict[str, Tensor], trainer: "pl.Trainer", del_filepath: Optional[str] = None + self, monitor_candidates: Dict[str, Tensor], trainer: "pl.Trainer", del_filepath: Optional[str] = None ) -> str: filepath = self.format_checkpoint_name(monitor_candidates) @@ -666,7 +666,7 @@ def _get_metric_interpolated_filepath_name( return filepath - def _monitor_candidates(self, trainer: "pl.Trainer") -> dict[str, Tensor]: + def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, Tensor]: monitor_candidates = deepcopy(trainer.callback_metrics) # cast to int if necessary because `self.log("epoch", 123)` will convert it to float. if it's not a tensor # or does not exist we overwrite it as it's likely an error @@ -676,7 +676,7 @@ def _monitor_candidates(self, trainer: "pl.Trainer") -> dict[str, Tensor]: monitor_candidates["step"] = step.int() if isinstance(step, Tensor) else torch.tensor(trainer.global_step) return monitor_candidates - def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: dict[str, Tensor]) -> None: + def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None: if not self.save_last: return @@ -697,7 +697,7 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: dict[ if previous and self._should_remove_checkpoint(trainer, previous, filepath): self._remove_checkpoint(trainer, previous) - def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: dict[str, Tensor]) -> None: + def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None: assert self.monitor current = monitor_candidates.get(self.monitor) if self.check_monitor_top_k(trainer, current): @@ -708,7 +708,7 @@ def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: di step = monitor_candidates["step"] rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} was not in top {self.save_top_k}") - def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: dict[str, Tensor]) -> None: + def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None: filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer, self.best_model_path) # set the best model path before saving because it will be part of the state. previous, self.best_model_path = self.best_model_path, filepath @@ -718,7 +718,7 @@ def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidate self._remove_checkpoint(trainer, previous) def _update_best_and_save( - self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: dict[str, Tensor] + self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor] ) -> None: k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k diff --git a/src/lightning/pytorch/callbacks/model_summary.py b/src/lightning/pytorch/callbacks/model_summary.py index 03f50d6..89c31b2 100644 --- a/src/lightning/pytorch/callbacks/model_summary.py +++ b/src/lightning/pytorch/callbacks/model_summary.py @@ -23,7 +23,7 @@ """ import logging -from typing import Any, Union +from typing import Any, Dict, List, Tuple, Union from typing_extensions import override @@ -54,7 +54,7 @@ class ModelSummary(Callback): def __init__(self, max_depth: int = 1, **summarize_kwargs: Any) -> None: self._max_depth: int = max_depth - self._summarize_kwargs: dict[str, Any] = summarize_kwargs + self._summarize_kwargs: Dict[str, Any] = summarize_kwargs @override def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: @@ -87,11 +87,11 @@ def _summary(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Un @staticmethod def summarize( - summary_data: list[tuple[str, list[str]]], + summary_data: List[Tuple[str, List[str]]], total_parameters: int, trainable_parameters: int, model_size: float, - total_training_modes: dict[str, int], + total_training_modes: Dict[str, int], **summarize_kwargs: Any, ) -> None: summary_table = _format_summary_table( diff --git a/src/lightning/pytorch/callbacks/prediction_writer.py b/src/lightning/pytorch/callbacks/prediction_writer.py index ce6342c..7f782fb 100644 --- a/src/lightning/pytorch/callbacks/prediction_writer.py +++ b/src/lightning/pytorch/callbacks/prediction_writer.py @@ -18,8 +18,7 @@ Aids in saving predictions """ -from collections.abc import Sequence -from typing import Any, Literal, Optional +from typing import Any, Literal, Optional, Sequence from typing_extensions import override diff --git a/src/lightning/pytorch/callbacks/progress/progress_bar.py b/src/lightning/pytorch/callbacks/progress/progress_bar.py index 7cf6993..785bf65 100644 --- a/src/lightning/pytorch/callbacks/progress/progress_bar.py +++ b/src/lightning/pytorch/callbacks/progress/progress_bar.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Union +from typing import Any, Dict, Optional, Union from typing_extensions import override @@ -176,7 +176,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s def get_metrics( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" - ) -> dict[str, Union[int, str, float, dict[str, float]]]: + ) -> Dict[str, Union[int, str, float, Dict[str, float]]]: r"""Combines progress bar metrics collected from the trainer with standard metrics from get_standard_metrics. Implement this to override the items displayed in the progress bar. @@ -207,7 +207,7 @@ def get_metrics(self, trainer, model): return {**standard_metrics, **pbar_metrics} -def get_standard_metrics(trainer: "pl.Trainer") -> dict[str, Union[int, str]]: +def get_standard_metrics(trainer: "pl.Trainer") -> Dict[str, Union[int, str]]: r"""Returns the standard metrics displayed in the progress bar. Currently, it only includes the version of the experiment when using a logger. @@ -219,7 +219,7 @@ def get_standard_metrics(trainer: "pl.Trainer") -> dict[str, Union[int, str]]: Dictionary with the standard metrics to be displayed in the progress bar. """ - items_dict: dict[str, Union[int, str]] = {} + items_dict: Dict[str, Union[int, str]] = {} if trainer.loggers: from lightning.pytorch.loggers.utilities import _version diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 0a51d99..896de71 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -12,10 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from collections.abc import Generator from dataclasses import dataclass from datetime import timedelta -from typing import Any, Optional, Union, cast +from typing import Any, Dict, Generator, Optional, Union, cast from lightning_utilities.core.imports import RequirementCache from typing_extensions import override @@ -147,15 +146,15 @@ def __init__( metrics_format: str, ): self._trainer = trainer - self._tasks: dict[Union[int, TaskID], Any] = {} + self._tasks: Dict[Union[int, TaskID], Any] = {} self._current_task_id = 0 - self._metrics: dict[Union[str, Style], Any] = {} + self._metrics: Dict[Union[str, Style], Any] = {} self._style = style self._text_delimiter = text_delimiter self._metrics_format = metrics_format super().__init__() - def update(self, metrics: dict[Any, Any]) -> None: + def update(self, metrics: Dict[Any, Any]) -> None: # Called when metrics are ready to be rendered. # This is to prevent render from causing deadlock issues by requesting metrics # in separate threads. @@ -258,7 +257,7 @@ def __init__( refresh_rate: int = 1, leave: bool = False, theme: RichProgressBarTheme = RichProgressBarTheme(), - console_kwargs: Optional[dict[str, Any]] = None, + console_kwargs: Optional[Dict[str, Any]] = None, ) -> None: if not _RICH_AVAILABLE: raise ModuleNotFoundError( @@ -643,7 +642,7 @@ def configure_columns(self, trainer: "pl.Trainer") -> list: ProcessingSpeedColumn(style=self.theme.processing_speed), ] - def __getstate__(self) -> dict: + def __getstate__(self) -> Dict: state = self.__dict__.copy() # both the console and progress object can hold thread lock objects that are not pickleable state["progress"] = None diff --git a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py index 4ef260f..cf9cd71 100644 --- a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py +++ b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py @@ -15,7 +15,7 @@ import math import os import sys -from typing import Any, Optional, Union +from typing import Any, Dict, Optional, Union from typing_extensions import override @@ -115,7 +115,7 @@ def __init__(self, refresh_rate: int = 1, process_position: int = 0, leave: bool self._predict_progress_bar: Optional[_tqdm] = None self._leave = leave - def __getstate__(self) -> dict: + def __getstate__(self) -> Dict: # can't pickle the tqdm objects return {k: v if not isinstance(v, _tqdm) else None for k, v in vars(self).items()} diff --git a/src/lightning/pytorch/callbacks/pruning.py b/src/lightning/pytorch/callbacks/pruning.py index 1517ef6..e83a9de 100644 --- a/src/lightning/pytorch/callbacks/pruning.py +++ b/src/lightning/pytorch/callbacks/pruning.py @@ -18,10 +18,9 @@ import inspect import logging -from collections.abc import Sequence from copy import deepcopy from functools import partial -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import torch.nn.utils.prune as pytorch_prune from lightning_utilities.core.apply_func import apply_to_collection @@ -50,14 +49,14 @@ "random_unstructured": pytorch_prune.RandomUnstructured, } -_PARAM_TUPLE = tuple[nn.Module, str] +_PARAM_TUPLE = Tuple[nn.Module, str] _PARAM_LIST = Sequence[_PARAM_TUPLE] _MODULE_CONTAINERS = (LightningModule, nn.Sequential, nn.ModuleList, nn.ModuleDict) class _LayerRef(TypedDict): data: nn.Module - names: list[tuple[int, str]] + names: List[Tuple[int, str]] class ModelPruning(Callback): @@ -67,7 +66,7 @@ def __init__( self, pruning_fn: Union[Callable, str], parameters_to_prune: _PARAM_LIST = (), - parameter_names: Optional[list[str]] = None, + parameter_names: Optional[List[str]] = None, use_global_unstructured: bool = True, amount: Union[int, float, Callable[[int], Union[int, float]]] = 0.5, apply_pruning: Union[bool, Callable[[int], bool]] = True, @@ -166,8 +165,8 @@ def __init__( self._resample_parameters = resample_parameters self._prune_on_train_epoch_end = prune_on_train_epoch_end self._parameter_names = parameter_names or self.PARAMETER_NAMES - self._global_kwargs: dict[str, Any] = {} - self._original_layers: Optional[dict[int, _LayerRef]] = None + self._global_kwargs: Dict[str, Any] = {} + self._original_layers: Optional[Dict[int, _LayerRef]] = None self._pruning_method_name: Optional[str] = None for name in self._parameter_names: @@ -311,7 +310,7 @@ def _apply_local_pruning(self, amount: float) -> None: for module, name in self._parameters_to_prune: self.pruning_fn(module, name=name, amount=amount) # type: ignore[call-arg] - def _resolve_global_kwargs(self, amount: float) -> dict[str, Any]: + def _resolve_global_kwargs(self, amount: float) -> Dict[str, Any]: self._global_kwargs["amount"] = amount params = set(inspect.signature(self.pruning_fn).parameters) params.discard("self") @@ -323,7 +322,7 @@ def _apply_global_pruning(self, amount: float) -> None: ) @staticmethod - def _get_pruned_stats(module: nn.Module, name: str) -> tuple[int, int]: + def _get_pruned_stats(module: nn.Module, name: str) -> Tuple[int, int]: attr = f"{name}_mask" if not hasattr(module, attr): return 0, 1 @@ -346,7 +345,7 @@ def apply_pruning(self, amount: Union[int, float]) -> None: @rank_zero_only def _log_sparsity_stats( - self, prev: list[tuple[int, int]], curr: list[tuple[int, int]], amount: Union[int, float] = 0 + self, prev: List[Tuple[int, int]], curr: List[Tuple[int, int]], amount: Union[int, float] = 0 ) -> None: total_params = sum(p.numel() for layer, _ in self._parameters_to_prune for p in layer.parameters()) prev_total_zeros = sum(zeros for zeros, _ in prev) @@ -415,7 +414,7 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: LightningModule) -> Non rank_zero_debug("`ModelPruning.on_train_end`. Pruning is made permanent for this checkpoint") self.make_pruning_permanent(pl_module) - def _make_pruning_permanent_on_state_dict(self, pl_module: LightningModule) -> dict[str, Any]: + def _make_pruning_permanent_on_state_dict(self, pl_module: LightningModule) -> Dict[str, Any]: state_dict = pl_module.state_dict() # find the mask and the original weights. @@ -433,7 +432,7 @@ def move_to_cpu(tensor: Tensor) -> Tensor: return apply_to_collection(state_dict, Tensor, move_to_cpu) @override - def on_save_checkpoint(self, trainer: "pl.Trainer", pl_module: LightningModule, checkpoint: dict[str, Any]) -> None: + def on_save_checkpoint(self, trainer: "pl.Trainer", pl_module: LightningModule, checkpoint: Dict[str, Any]) -> None: if self._make_pruning_permanent: rank_zero_debug("`ModelPruning.on_save_checkpoint`. Pruning is made permanent for this checkpoint") # manually prune the weights so training can keep going with the same buffers diff --git a/src/lightning/pytorch/callbacks/rich_model_summary.py b/src/lightning/pytorch/callbacks/rich_model_summary.py index e4027f0..c6c429b 100644 --- a/src/lightning/pytorch/callbacks/rich_model_summary.py +++ b/src/lightning/pytorch/callbacks/rich_model_summary.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Dict, List, Tuple from typing_extensions import override @@ -67,11 +67,11 @@ def __init__(self, max_depth: int = 1, **summarize_kwargs: Any) -> None: @staticmethod @override def summarize( - summary_data: list[tuple[str, list[str]]], + summary_data: List[Tuple[str, List[str]]], total_parameters: int, trainable_parameters: int, model_size: float, - total_training_modes: dict[str, int], + total_training_modes: Dict[str, int], **summarize_kwargs: Any, ) -> None: from rich import get_console diff --git a/src/lightning/pytorch/callbacks/spike.py b/src/lightning/pytorch/callbacks/spike.py index b006acd..725d6f6 100644 --- a/src/lightning/pytorch/callbacks/spike.py +++ b/src/lightning/pytorch/callbacks/spike.py @@ -1,6 +1,5 @@ import os -from collections.abc import Mapping -from typing import Any, Union +from typing import Any, Mapping, Union import torch diff --git a/src/lightning/pytorch/callbacks/stochastic_weight_avg.py b/src/lightning/pytorch/callbacks/stochastic_weight_avg.py index 5643a03..737084c 100644 --- a/src/lightning/pytorch/callbacks/stochastic_weight_avg.py +++ b/src/lightning/pytorch/callbacks/stochastic_weight_avg.py @@ -17,7 +17,7 @@ """ from copy import deepcopy -from typing import Any, Callable, Literal, Optional, Union, cast +from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast import torch from torch import Tensor, nn @@ -39,7 +39,7 @@ class StochasticWeightAveraging(Callback): def __init__( self, - swa_lrs: Union[float, list[float]], + swa_lrs: Union[float, List[float]], swa_epoch_start: Union[int, float] = 0.8, annealing_epochs: int = 10, annealing_strategy: Literal["cos", "linear"] = "cos", @@ -126,10 +126,10 @@ def __init__( self._average_model: Optional[pl.LightningModule] = None self._initialized = False self._swa_scheduler: Optional[LRScheduler] = None - self._scheduler_state: Optional[dict] = None + self._scheduler_state: Optional[Dict] = None self._init_n_averaged = 0 self._latest_update_epoch = -1 - self.momenta: dict[nn.modules.batchnorm._BatchNorm, Optional[float]] = {} + self.momenta: Dict[nn.modules.batchnorm._BatchNorm, Optional[float]] = {} self._max_epochs: int @property @@ -331,7 +331,7 @@ def avg_fn(averaged_model_parameter: Tensor, model_parameter: Tensor, num_averag return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1) @override - def state_dict(self) -> dict[str, Any]: + def state_dict(self) -> Dict[str, Any]: return { "n_averaged": 0 if self.n_averaged is None else self.n_averaged.item(), "latest_update_epoch": self._latest_update_epoch, @@ -340,7 +340,7 @@ def state_dict(self) -> dict[str, Any]: } @override - def load_state_dict(self, state_dict: dict[str, Any]) -> None: + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self._init_n_averaged = state_dict["n_averaged"] self._latest_update_epoch = state_dict["latest_update_epoch"] self._scheduler_state = state_dict["scheduler_state"] diff --git a/src/lightning/pytorch/callbacks/throughput_monitor.py b/src/lightning/pytorch/callbacks/throughput_monitor.py index a49610a..a2d73d8 100644 --- a/src/lightning/pytorch/callbacks/throughput_monitor.py +++ b/src/lightning/pytorch/callbacks/throughput_monitor.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import time -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union import torch from typing_extensions import override @@ -84,9 +84,9 @@ def __init__( self.batch_size_fn = batch_size_fn self.length_fn = length_fn self.available_flops: Optional[int] = None - self._throughputs: dict[RunningStage, Throughput] = {} - self._t0s: dict[RunningStage, float] = {} - self._lengths: dict[RunningStage, int] = {} + self._throughputs: Dict[RunningStage, Throughput] = {} + self._t0s: Dict[RunningStage, float] = {} + self._lengths: Dict[RunningStage, int] = {} @override def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None: diff --git a/src/lightning/pytorch/callbacks/timer.py b/src/lightning/pytorch/callbacks/timer.py index b6b74d2..e1bed4a 100644 --- a/src/lightning/pytorch/callbacks/timer.py +++ b/src/lightning/pytorch/callbacks/timer.py @@ -20,7 +20,7 @@ import re import time from datetime import timedelta -from typing import Any, Optional, Union +from typing import Any, Dict, Optional, Union from typing_extensions import override @@ -83,7 +83,7 @@ class Timer(Callback): def __init__( self, - duration: Optional[Union[str, timedelta, dict[str, int]]] = None, + duration: Optional[Union[str, timedelta, Dict[str, int]]] = None, interval: str = Interval.step, verbose: bool = True, ) -> None: @@ -111,8 +111,8 @@ def __init__( self._duration = duration.total_seconds() if duration is not None else None self._interval = interval self._verbose = verbose - self._start_time: dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage} - self._end_time: dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage} + self._start_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage} + self._end_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage} self._offset = 0 def start_time(self, stage: str = RunningStage.TRAINING) -> Optional[float]: @@ -187,11 +187,11 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) - self._check_time_remaining(trainer) @override - def state_dict(self) -> dict[str, Any]: + def state_dict(self) -> Dict[str, Any]: return {"time_elapsed": {stage.value: self.time_elapsed(stage) for stage in RunningStage}} @override - def load_state_dict(self, state_dict: dict[str, Any]) -> None: + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: time_elapsed = state_dict.get("time_elapsed", {}) self._offset = time_elapsed.get(RunningStage.TRAINING.value, 0) diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py index c79f248..26af335 100644 --- a/src/lightning/pytorch/cli.py +++ b/src/lightning/pytorch/cli.py @@ -14,10 +14,9 @@ import inspect import os import sys -from collections.abc import Iterable from functools import partial, update_wrapper from types import MethodType -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, TypeVar, Union import torch import yaml @@ -66,11 +65,11 @@ def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any # LightningCLI requires the ReduceLROnPlateau defined here, thus it shouldn't accept the one from pytorch: LRSchedulerTypeTuple = (LRScheduler, ReduceLROnPlateau) LRSchedulerTypeUnion = Union[LRScheduler, ReduceLROnPlateau] -LRSchedulerType = Union[type[LRScheduler], type[ReduceLROnPlateau]] +LRSchedulerType = Union[Type[LRScheduler], Type[ReduceLROnPlateau]] # Type aliases intended for convenience of CLI developers -ArgsType = Optional[Union[list[str], dict[str, Any], Namespace]] +ArgsType = Optional[Union[List[str], Dict[str, Any], Namespace]] OptimizerCallable = Callable[[Iterable], Optimizer] LRSchedulerCallable = Callable[[Optimizer], Union[LRScheduler, ReduceLROnPlateau]] @@ -100,24 +99,24 @@ def __init__( if not _JSONARGPARSE_SIGNATURES_AVAILABLE: raise ModuleNotFoundError(f"{_JSONARGPARSE_SIGNATURES_AVAILABLE}") super().__init__(*args, description=description, env_prefix=env_prefix, default_env=default_env, **kwargs) - self.callback_keys: list[str] = [] + self.callback_keys: List[str] = [] # separate optimizers and lr schedulers to know which were added - self._optimizers: dict[str, tuple[Union[type, tuple[type, ...]], str]] = {} - self._lr_schedulers: dict[str, tuple[Union[type, tuple[type, ...]], str]] = {} + self._optimizers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {} + self._lr_schedulers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {} def add_lightning_class_args( self, lightning_class: Union[ Callable[..., Union[Trainer, LightningModule, LightningDataModule, Callback]], - type[Trainer], - type[LightningModule], - type[LightningDataModule], - type[Callback], + Type[Trainer], + Type[LightningModule], + Type[LightningDataModule], + Type[Callback], ], nested_key: str, subclass_mode: bool = False, required: bool = True, - ) -> list[str]: + ) -> List[str]: """Adds arguments from a lightning class to a nested key of the parser. Args: @@ -154,7 +153,7 @@ def add_lightning_class_args( def add_optimizer_args( self, - optimizer_class: Union[type[Optimizer], tuple[type[Optimizer], ...]] = (Optimizer,), + optimizer_class: Union[Type[Optimizer], Tuple[Type[Optimizer], ...]] = (Optimizer,), nested_key: str = "optimizer", link_to: str = "AUTOMATIC", ) -> None: @@ -170,7 +169,7 @@ def add_optimizer_args( assert all(issubclass(o, Optimizer) for o in optimizer_class) else: assert issubclass(optimizer_class, Optimizer) - kwargs: dict[str, Any] = {"instantiate": False, "fail_untyped": False, "skip": {"params"}} + kwargs: Dict[str, Any] = {"instantiate": False, "fail_untyped": False, "skip": {"params"}} if isinstance(optimizer_class, tuple): self.add_subclass_arguments(optimizer_class, nested_key, **kwargs) else: @@ -179,7 +178,7 @@ def add_optimizer_args( def add_lr_scheduler_args( self, - lr_scheduler_class: Union[LRSchedulerType, tuple[LRSchedulerType, ...]] = LRSchedulerTypeTuple, + lr_scheduler_class: Union[LRSchedulerType, Tuple[LRSchedulerType, ...]] = LRSchedulerTypeTuple, nested_key: str = "lr_scheduler", link_to: str = "AUTOMATIC", ) -> None: @@ -196,7 +195,7 @@ def add_lr_scheduler_args( assert all(issubclass(o, LRSchedulerTypeTuple) for o in lr_scheduler_class) else: assert issubclass(lr_scheduler_class, LRSchedulerTypeTuple) - kwargs: dict[str, Any] = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}} + kwargs: Dict[str, Any] = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}} if isinstance(lr_scheduler_class, tuple): self.add_subclass_arguments(lr_scheduler_class, nested_key, **kwargs) else: @@ -306,14 +305,14 @@ class LightningCLI: def __init__( self, - model_class: Optional[Union[type[LightningModule], Callable[..., LightningModule]]] = None, - datamodule_class: Optional[Union[type[LightningDataModule], Callable[..., LightningDataModule]]] = None, - save_config_callback: Optional[type[SaveConfigCallback]] = SaveConfigCallback, - save_config_kwargs: Optional[dict[str, Any]] = None, - trainer_class: Union[type[Trainer], Callable[..., Trainer]] = Trainer, - trainer_defaults: Optional[dict[str, Any]] = None, + model_class: Optional[Union[Type[LightningModule], Callable[..., LightningModule]]] = None, + datamodule_class: Optional[Union[Type[LightningDataModule], Callable[..., LightningDataModule]]] = None, + save_config_callback: Optional[Type[SaveConfigCallback]] = SaveConfigCallback, + save_config_kwargs: Optional[Dict[str, Any]] = None, + trainer_class: Union[Type[Trainer], Callable[..., Trainer]] = Trainer, + trainer_defaults: Optional[Dict[str, Any]] = None, seed_everything_default: Union[bool, int] = True, - parser_kwargs: Optional[Union[dict[str, Any], dict[str, dict[str, Any]]]] = None, + parser_kwargs: Optional[Union[Dict[str, Any], Dict[str, Dict[str, Any]]]] = None, subclass_mode_model: bool = False, subclass_mode_data: bool = False, args: ArgsType = None, @@ -390,12 +389,11 @@ def __init__( self._add_instantiators() self.before_instantiate_classes() self.instantiate_classes() - self.after_instantiate_classes() if self.subcommand is not None: self._run_subcommand(self.subcommand) - def _setup_parser_kwargs(self, parser_kwargs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: + def _setup_parser_kwargs(self, parser_kwargs: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: subcommand_names = self.subcommands().keys() main_kwargs = {k: v for k, v in parser_kwargs.items() if k not in subcommand_names} subparser_kwargs = {k: v for k, v in parser_kwargs.items() if k in subcommand_names} @@ -411,12 +409,12 @@ def init_parser(self, **kwargs: Any) -> LightningArgumentParser: return parser def setup_parser( - self, add_subcommands: bool, main_kwargs: dict[str, Any], subparser_kwargs: dict[str, Any] + self, add_subcommands: bool, main_kwargs: Dict[str, Any], subparser_kwargs: Dict[str, Any] ) -> None: """Initialize and setup the parser, subcommands, and arguments.""" self.parser = self.init_parser(**main_kwargs) if add_subcommands: - self._subcommand_method_arguments: dict[str, list[str]] = {} + self._subcommand_method_arguments: Dict[str, List[str]] = {} self._add_subcommands(self.parser, **subparser_kwargs) else: self._add_arguments(self.parser) @@ -471,7 +469,7 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: """ @staticmethod - def subcommands() -> dict[str, set[str]]: + def subcommands() -> Dict[str, Set[str]]: """Defines the list of available subcommands and the arguments to skip.""" return { "fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"}, @@ -482,7 +480,7 @@ def subcommands() -> dict[str, set[str]]: def _add_subcommands(self, parser: LightningArgumentParser, **kwargs: Any) -> None: """Adds subcommands to the input parser.""" - self._subcommand_parsers: dict[str, LightningArgumentParser] = {} + self._subcommand_parsers: Dict[str, LightningArgumentParser] = {} parser_subcommands = parser.add_subcommands() # the user might have passed a builder function trainer_class = ( @@ -499,11 +497,11 @@ def _add_subcommands(self, parser: LightningArgumentParser, **kwargs: Any) -> No self._subcommand_parsers[subcommand] = subcommand_parser parser_subcommands.add_subcommand(subcommand, subcommand_parser, help=description) - def _prepare_subcommand_parser(self, klass: type, subcommand: str, **kwargs: Any) -> LightningArgumentParser: + def _prepare_subcommand_parser(self, klass: Type, subcommand: str, **kwargs: Any) -> LightningArgumentParser: parser = self.init_parser(**kwargs) self._add_arguments(parser) # subcommand arguments - skip: set[Union[str, int]] = set(self.subcommands()[subcommand]) + skip: Set[Union[str, int]] = set(self.subcommands()[subcommand]) added = parser.add_method_arguments(klass, subcommand, skip=skip) # need to save which arguments were added to pass them to the method later self._subcommand_method_arguments[subcommand] = added @@ -562,9 +560,6 @@ def instantiate_classes(self) -> None: self._add_configure_optimizers_method_to_model(self.subcommand) self.trainer = self.instantiate_trainer() - def after_instantiate_classes(self) -> None: - """Implement to run some code after instantiating the classes.""" - def instantiate_trainer(self, **kwargs: Any) -> Trainer: """Instantiates the trainer. @@ -576,7 +571,7 @@ def instantiate_trainer(self, **kwargs: Any) -> Trainer: trainer_config = {**self._get(self.config_init, "trainer", default={}), **kwargs} return self._instantiate_trainer(trainer_config, extra_callbacks) - def _instantiate_trainer(self, config: dict[str, Any], callbacks: list[Callback]) -> Trainer: + def _instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback]) -> Trainer: key = "callbacks" if key in config: if config[key] is None: @@ -637,8 +632,8 @@ def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) - parser = self._parser(subcommand) def get_automatic( - class_type: Union[type, tuple[type, ...]], register: dict[str, tuple[Union[type, tuple[type, ...]], str]] - ) -> list[str]: + class_type: Union[Type, Tuple[Type, ...]], register: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] + ) -> List[str]: automatic = [] for key, (base_class, link_to) in register.items(): if not isinstance(base_class, tuple): @@ -709,7 +704,7 @@ def _run_subcommand(self, subcommand: str) -> None: if callable(after_fn): after_fn() - def _prepare_subcommand_kwargs(self, subcommand: str) -> dict[str, Any]: + def _prepare_subcommand_kwargs(self, subcommand: str) -> Dict[str, Any]: """Prepares the keyword arguments to pass to the subcommand to run.""" fn_kwargs = { k: v for k, v in self.config_init[subcommand].items() if k in self._subcommand_method_arguments[subcommand] @@ -735,26 +730,26 @@ def _set_seed(self) -> None: self.config["seed_everything"] = config_seed -def _class_path_from_class(class_type: type) -> str: +def _class_path_from_class(class_type: Type) -> str: return class_type.__module__ + "." + class_type.__name__ def _global_add_class_path( - class_type: type, init_args: Optional[Union[Namespace, dict[str, Any]]] = None -) -> dict[str, Any]: + class_type: Type, init_args: Optional[Union[Namespace, Dict[str, Any]]] = None +) -> Dict[str, Any]: if isinstance(init_args, Namespace): init_args = init_args.as_dict() return {"class_path": _class_path_from_class(class_type), "init_args": init_args or {}} -def _add_class_path_generator(class_type: type) -> Callable[[Namespace], dict[str, Any]]: - def add_class_path(init_args: Namespace) -> dict[str, Any]: +def _add_class_path_generator(class_type: Type) -> Callable[[Namespace], Dict[str, Any]]: + def add_class_path(init_args: Namespace) -> Dict[str, Any]: return _global_add_class_path(class_type, init_args) return add_class_path -def instantiate_class(args: Union[Any, tuple[Any, ...]], init: dict[str, Any]) -> Any: +def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any: """Instantiates a class with the given args and init. Args: @@ -795,7 +790,7 @@ def __init__(self, cli: LightningCLI, key: str) -> None: self.cli = cli self.key = key - def __call__(self, class_type: type[ModuleType], *args: Any, **kwargs: Any) -> ModuleType: + def __call__(self, class_type: Type[ModuleType], *args: Any, **kwargs: Any) -> ModuleType: hparams = self.cli.config_dump.get(self.key, {}) if "class_path" in hparams: # To make hparams backwards compatible, and so that it is the same irrespective of subclass_mode, the @@ -813,7 +808,7 @@ def __call__(self, class_type: type[ModuleType], *args: Any, **kwargs: Any) -> M return class_type(*args, **kwargs) -def instantiate_module(class_type: type[ModuleType], config: dict[str, Any]) -> ModuleType: +def instantiate_module(class_type: Type[ModuleType], config: Dict[str, Any]) -> ModuleType: parser = ArgumentParser(exit_on_error=False) if "_class_path" in config: parser.add_subclass_arguments(class_type, "module", fail_untyped=False) diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index 0c7a984..6cb8f79 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -14,8 +14,7 @@ """LightningDataModule for loading DataLoaders with ease.""" import inspect -from collections.abc import Iterable -from typing import IO, Any, Optional, Union, cast +from typing import IO, Any, Dict, Iterable, Optional, Union, cast from lightning_utilities import apply_to_collection from torch.utils.data import DataLoader, Dataset, IterableDataset @@ -148,7 +147,7 @@ def predict_dataloader() -> EVAL_DATALOADERS: datamodule.predict_dataloader = predict_dataloader # type: ignore[method-assign] return datamodule - def state_dict(self) -> dict[str, Any]: + def state_dict(self) -> Dict[str, Any]: """Called when saving a checkpoint, implement to generate and save datamodule state. Returns: @@ -157,7 +156,7 @@ def state_dict(self) -> dict[str, Any]: """ return {} - def load_state_dict(self, state_dict: dict[str, Any]) -> None: + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: """Called when loading a checkpoint, implement to reload datamodule state given datamodule state_dict. Args: diff --git a/src/lightning/pytorch/core/hooks.py b/src/lightning/pytorch/core/hooks.py index 0b0ab14..5495a02 100644 --- a/src/lightning/pytorch/core/hooks.py +++ b/src/lightning/pytorch/core/hooks.py @@ -13,7 +13,7 @@ # limitations under the License. """Various hooks to be used in the Lightning code.""" -from typing import Any, Optional +from typing import Any, Dict, Optional import torch from torch import Tensor @@ -670,7 +670,7 @@ def on_after_batch_transfer(self, batch, dataloader_idx): class CheckpointHooks: """Hooks to be used with Checkpointing.""" - def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None: + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: r"""Called by Lightning to restore your model. If you saved something with :meth:`on_save_checkpoint` this is your chance to restore this. @@ -689,7 +689,7 @@ def on_load_checkpoint(self, checkpoint): """ - def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None: + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: r"""Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save. diff --git a/src/lightning/pytorch/core/mixins/hparams_mixin.py b/src/lightning/pytorch/core/mixins/hparams_mixin.py index 3a01cd2..94ece00 100644 --- a/src/lightning/pytorch/core/mixins/hparams_mixin.py +++ b/src/lightning/pytorch/core/mixins/hparams_mixin.py @@ -15,10 +15,9 @@ import inspect import types from argparse import Namespace -from collections.abc import Iterator, MutableMapping, Sequence from contextlib import contextmanager from contextvars import ContextVar -from typing import Any, Optional, Union +from typing import Any, Iterator, List, MutableMapping, Optional, Sequence, Union from lightning.fabric.utilities.data import AttributeDict from lightning.pytorch.utilities.parsing import save_hyperparameters @@ -42,7 +41,7 @@ def _given_hyperparameters_context(hparams: dict, instantiator: str) -> Iterator class HyperparametersMixin: - __jit_unused_properties__: list[str] = ["hparams", "hparams_initial"] + __jit_unused_properties__: List[str] = ["hparams", "hparams_initial"] def __init__(self) -> None: super().__init__() diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index f1d1da9..d8374ef 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -16,7 +16,6 @@ import logging import numbers import weakref -from collections.abc import Generator, Mapping, Sequence from contextlib import contextmanager from io import BytesIO from pathlib import Path @@ -25,8 +24,14 @@ TYPE_CHECKING, Any, Callable, + Dict, + Generator, + List, Literal, + Mapping, Optional, + Sequence, + Tuple, Union, cast, overload, @@ -81,7 +86,7 @@ log = logging.getLogger(__name__) MODULE_OPTIMIZERS = Union[ - Optimizer, LightningOptimizer, _FabricOptimizer, list[Optimizer], list[LightningOptimizer], list[_FabricOptimizer] + Optimizer, LightningOptimizer, _FabricOptimizer, List[Optimizer], List[LightningOptimizer], List[_FabricOptimizer] ] @@ -95,7 +100,7 @@ class LightningModule( ): # Below is for property support of JIT # since none of these are important when using JIT, we are going to ignore them. - __jit_unused_properties__: list[str] = ( + __jit_unused_properties__: List[str] = ( [ "example_input_array", "on_gpu", @@ -127,19 +132,19 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._trainer: Optional[pl.Trainer] = None # attributes that can be set by user - self._example_input_array: Optional[Union[Tensor, tuple, dict]] = None + self._example_input_array: Optional[Union[Tensor, Tuple, Dict]] = None self._automatic_optimization: bool = True self._strict_loading: Optional[bool] = None # attributes used internally self._current_fx_name: Optional[str] = None - self._param_requires_grad_state: dict[str, bool] = {} - self._metric_attributes: Optional[dict[int, str]] = None - self._compiler_ctx: Optional[dict[str, Any]] = None + self._param_requires_grad_state: Dict[str, bool] = {} + self._metric_attributes: Optional[Dict[int, str]] = None + self._compiler_ctx: Optional[Dict[str, Any]] = None # attributes only used when using fabric self._fabric: Optional[lf.Fabric] = None - self._fabric_optimizers: list[_FabricOptimizer] = [] + self._fabric_optimizers: List[_FabricOptimizer] = [] # access to device mesh in `conigure_model()` hook self._device_mesh: Optional[DeviceMesh] = None @@ -147,10 +152,10 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: @overload def optimizers( self, use_pl_optimizer: Literal[True] = True - ) -> Union[LightningOptimizer, list[LightningOptimizer]]: ... + ) -> Union[LightningOptimizer, List[LightningOptimizer]]: ... @overload - def optimizers(self, use_pl_optimizer: Literal[False]) -> Union[Optimizer, list[Optimizer]]: ... + def optimizers(self, use_pl_optimizer: Literal[False]) -> Union[Optimizer, List[Optimizer]]: ... @overload def optimizers(self, use_pl_optimizer: bool) -> MODULE_OPTIMIZERS: ... @@ -185,7 +190,7 @@ def optimizers(self, use_pl_optimizer: bool = True) -> MODULE_OPTIMIZERS: # multiple opts return opts - def lr_schedulers(self) -> Union[None, list[LRSchedulerPLType], LRSchedulerPLType]: + def lr_schedulers(self) -> Union[None, List[LRSchedulerPLType], LRSchedulerPLType]: """Returns the learning rate scheduler(s) that are being used during training. Useful for manual optimization. Returns: @@ -197,7 +202,7 @@ def lr_schedulers(self) -> Union[None, list[LRSchedulerPLType], LRSchedulerPLTyp return None # ignore other keys "interval", "frequency", etc. - lr_schedulers: list[LRSchedulerPLType] = [config.scheduler for config in self.trainer.lr_scheduler_configs] + lr_schedulers: List[LRSchedulerPLType] = [config.scheduler for config in self.trainer.lr_scheduler_configs] # single scheduler if len(lr_schedulers) == 1: @@ -235,7 +240,7 @@ def fabric(self, fabric: Optional["lf.Fabric"]) -> None: self._fabric = fabric @property - def example_input_array(self) -> Optional[Union[Tensor, tuple, dict]]: + def example_input_array(self) -> Optional[Union[Tensor, Tuple, Dict]]: """The example input array is a specification of what the module can consume in the :meth:`forward` method. The return type is interpreted as follows: @@ -250,7 +255,7 @@ def example_input_array(self) -> Optional[Union[Tensor, tuple, dict]]: return self._example_input_array @example_input_array.setter - def example_input_array(self, example: Optional[Union[Tensor, tuple, dict]]) -> None: + def example_input_array(self, example: Optional[Union[Tensor, Tuple, Dict]]) -> None: self._example_input_array = example @property @@ -313,7 +318,7 @@ def logger(self) -> Optional[Union[Logger, FabricLogger]]: return self._trainer.logger if self._trainer is not None else None @property - def loggers(self) -> Union[list[Logger], list[FabricLogger]]: + def loggers(self) -> Union[List[Logger], List[FabricLogger]]: """Reference to the list of loggers in the Trainer.""" if self._fabric is not None: return self._fabric.loggers @@ -594,7 +599,7 @@ def log_dict( if self._fabric is not None: return self._log_dict_through_fabric(dictionary=dictionary, logger=logger) - kwargs: dict[str, bool] = {} + kwargs: Dict[str, bool] = {} if isinstance(dictionary, MetricCollection): kwargs["keep_base"] = False @@ -660,8 +665,8 @@ def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor return value def all_gather( - self, data: Union[Tensor, dict, list, tuple], group: Optional[Any] = None, sync_grads: bool = False - ) -> Union[Tensor, dict, list, tuple]: + self, data: Union[Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False + ) -> Union[Tensor, Dict, List, Tuple]: r"""Gather tensors or collections of tensors from multiple processes. This method needs to be called on all processes and the tensors need to have the same shape across all @@ -1412,7 +1417,7 @@ def to_torchscript( method: Optional[str] = "script", example_inputs: Optional[Any] = None, **kwargs: Any, - ) -> Union[ScriptModule, dict[str, ScriptModule]]: + ) -> Union[ScriptModule, Dict[str, ScriptModule]]: """By default compiles the whole model to a :class:`~torch.jit.ScriptModule`. If you want to use tracing, please provided the argument ``method='trace'`` and make sure that either the `example_inputs` argument is provided, or the model has :attr:`example_input_array` set. If you would like to customize the modules that are @@ -1589,7 +1594,7 @@ def load_from_checkpoint( return cast(Self, loaded) @override - def __getstate__(self) -> dict[str, Any]: + def __getstate__(self) -> Dict[str, Any]: state = dict(self.__dict__) state["_trainer"] = None return state diff --git a/src/lightning/pytorch/core/optimizer.py b/src/lightning/pytorch/core/optimizer.py index 46126e2..777dca0 100644 --- a/src/lightning/pytorch/core/optimizer.py +++ b/src/lightning/pytorch/core/optimizer.py @@ -11,10 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Generator from contextlib import contextmanager from dataclasses import fields -from typing import Any, Callable, Optional, Union, overload +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union, overload from weakref import proxy import torch @@ -173,7 +172,7 @@ def __getattr__(self, item: Any) -> Any: def _init_optimizers_and_lr_schedulers( model: "pl.LightningModule", -) -> tuple[list[Optimizer], list[LRSchedulerConfig]]: +) -> Tuple[List[Optimizer], List[LRSchedulerConfig]]: """Calls `LightningModule.configure_optimizers` and parses and validates the output.""" from lightning.pytorch.trainer import call @@ -198,8 +197,8 @@ def _init_optimizers_and_lr_schedulers( def _configure_optimizers( - optim_conf: Union[dict[str, Any], list, Optimizer, tuple], -) -> tuple[list, list, Optional[str]]: + optim_conf: Union[Dict[str, Any], List, Optimizer, Tuple], +) -> Tuple[List, List, Optional[str]]: optimizers, lr_schedulers = [], [] monitor = None @@ -247,7 +246,7 @@ def _configure_optimizers( return optimizers, lr_schedulers, monitor -def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> list[LRSchedulerConfig]: +def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> List[LRSchedulerConfig]: """Convert each scheduler into `LRSchedulerConfig` with relevant information, when using automatic optimization.""" lr_scheduler_configs = [] for scheduler in schedulers: @@ -302,7 +301,7 @@ def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str] return lr_scheduler_configs -def _configure_schedulers_manual_opt(schedulers: list) -> list[LRSchedulerConfig]: +def _configure_schedulers_manual_opt(schedulers: list) -> List[LRSchedulerConfig]: """Convert each scheduler into `LRSchedulerConfig` structure with relevant information, when using manual optimization.""" lr_scheduler_configs = [] @@ -327,7 +326,7 @@ def _configure_schedulers_manual_opt(schedulers: list) -> list[LRSchedulerConfig return lr_scheduler_configs -def _validate_scheduler_api(lr_scheduler_configs: list[LRSchedulerConfig], model: "pl.LightningModule") -> None: +def _validate_scheduler_api(lr_scheduler_configs: List[LRSchedulerConfig], model: "pl.LightningModule") -> None: for config in lr_scheduler_configs: scheduler = config.scheduler if not isinstance(scheduler, _Stateful): @@ -348,7 +347,7 @@ def _validate_scheduler_api(lr_scheduler_configs: list[LRSchedulerConfig], model ) -def _validate_multiple_optimizers_support(optimizers: list[Optimizer], model: "pl.LightningModule") -> None: +def _validate_multiple_optimizers_support(optimizers: List[Optimizer], model: "pl.LightningModule") -> None: if is_param_in_hook_signature(model.training_step, "optimizer_idx", explicit=True): raise RuntimeError( "Training with multiple optimizers is only supported with manual optimization. Remove the `optimizer_idx`" @@ -363,7 +362,7 @@ def _validate_multiple_optimizers_support(optimizers: list[Optimizer], model: "p ) -def _validate_optimizers_attached(optimizers: list[Optimizer], lr_scheduler_configs: list[LRSchedulerConfig]) -> None: +def _validate_optimizers_attached(optimizers: List[Optimizer], lr_scheduler_configs: List[LRSchedulerConfig]) -> None: for config in lr_scheduler_configs: if config.scheduler.optimizer not in optimizers: raise MisconfigurationException( @@ -371,7 +370,7 @@ def _validate_optimizers_attached(optimizers: list[Optimizer], lr_scheduler_conf ) -def _validate_optim_conf(optim_conf: dict[str, Any]) -> None: +def _validate_optim_conf(optim_conf: Dict[str, Any]) -> None: valid_keys = {"optimizer", "lr_scheduler", "monitor"} extra_keys = optim_conf.keys() - valid_keys if extra_keys: @@ -388,15 +387,15 @@ def __init__(self) -> None: super().__init__([torch.zeros(1)], {}) @override - def add_param_group(self, param_group: dict[Any, Any]) -> None: + def add_param_group(self, param_group: Dict[Any, Any]) -> None: pass # Do Nothing @override - def load_state_dict(self, state_dict: dict[Any, Any]) -> None: + def load_state_dict(self, state_dict: Dict[Any, Any]) -> None: pass # Do Nothing @override - def state_dict(self) -> dict[str, Any]: + def state_dict(self) -> Dict[str, Any]: return {} # Return Empty @overload diff --git a/src/lightning/pytorch/core/saving.py b/src/lightning/pytorch/core/saving.py index 09d888c..521192f 100644 --- a/src/lightning/pytorch/core/saving.py +++ b/src/lightning/pytorch/core/saving.py @@ -22,7 +22,7 @@ from copy import deepcopy from enum import Enum from pathlib import Path -from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Union +from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union from warnings import warn import torch @@ -51,7 +51,7 @@ def _load_from_checkpoint( - cls: Union[type["pl.LightningModule"], type["pl.LightningDataModule"]], + cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]], checkpoint_path: Union[_PATH, IO], map_location: _MAP_LOCATION_TYPE = None, hparams_file: Optional[_PATH] = None, @@ -115,8 +115,8 @@ def _default_map_location(storage: "UntypedStorage", location: str) -> Optional[ def _load_state( - cls: Union[type["pl.LightningModule"], type["pl.LightningDataModule"]], - checkpoint: dict[str, Any], + cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]], + checkpoint: Dict[str, Any], strict: Optional[bool] = None, **cls_kwargs_new: Any, ) -> Union["pl.LightningModule", "pl.LightningDataModule"]: @@ -200,8 +200,8 @@ def _load_state( def _convert_loaded_hparams( - model_args: dict[str, Any], hparams_type: Optional[Union[Callable, str]] = None -) -> dict[str, Any]: + model_args: Dict[str, Any], hparams_type: Optional[Union[Callable, str]] = None +) -> Dict[str, Any]: """Convert hparams according given type in callable or string (past) format.""" # if not hparams type define if not hparams_type: @@ -243,7 +243,7 @@ def update_hparams(hparams: dict, updates: dict) -> None: hparams.update({k: v}) -def load_hparams_from_tags_csv(tags_csv: _PATH) -> dict[str, Any]: +def load_hparams_from_tags_csv(tags_csv: _PATH) -> Dict[str, Any]: """Load hparams from a file. >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here') @@ -281,7 +281,7 @@ def save_hparams_to_tags_csv(tags_csv: _PATH, hparams: Union[dict, Namespace]) - writer.writerow({"key": k, "value": v}) -def load_hparams_from_yaml(config_yaml: _PATH, use_omegaconf: bool = True) -> dict[str, Any]: +def load_hparams_from_yaml(config_yaml: _PATH, use_omegaconf: bool = True) -> Dict[str, Any]: """Load hparams from a file. Args: diff --git a/src/lightning/pytorch/demos/__init__.py b/src/lightning/pytorch/demos/__init__.py index 1e03e2f..fa91d7c 100644 --- a/src/lightning/pytorch/demos/__init__.py +++ b/src/lightning/pytorch/demos/__init__.py @@ -1,15 +1,2 @@ -from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, DemoModel -from lightning.pytorch.demos.lstm import LightningLSTM, SequenceSampler, SimpleLSTM -from lightning.pytorch.demos.transformer import LightningTransformer, Transformer, WikiText2 - -__all__ = [ - "LightningLSTM", - "SequenceSampler", - "SimpleLSTM", - "LightningTransformer", - "Transformer", - "WikiText2", - "BoringModel", - "BoringDataModule", - "DemoModel", -] +from lightning.pytorch.demos.lstm import LightningLSTM, SequenceSampler, SimpleLSTM # noqa: F401 +from lightning.pytorch.demos.transformer import LightningTransformer, Transformer, WikiText2 # noqa: F401 diff --git a/src/lightning/pytorch/demos/boring_classes.py b/src/lightning/pytorch/demos/boring_classes.py index 589524e..fd26602 100644 --- a/src/lightning/pytorch/demos/boring_classes.py +++ b/src/lightning/pytorch/demos/boring_classes.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Iterator -from typing import Any, Optional +from typing import Any, Dict, Iterator, List, Optional, Tuple import torch import torch.nn as nn @@ -36,7 +35,7 @@ def __init__(self, size: int, length: int): self.len = length self.data = torch.randn(length, size) - def __getitem__(self, index: int) -> dict[str, Tensor]: + def __getitem__(self, index: int) -> Dict[str, Tensor]: a = self.data[index] b = a + 2 return {"a": a, "b": b} @@ -135,7 +134,7 @@ def validation_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT: def test_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT: return {"y": self.step(batch)} - def configure_optimizers(self) -> tuple[list[torch.optim.Optimizer], list[LRScheduler]]: + def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[LRScheduler]]: optimizer = torch.optim.SGD(self.parameters(), lr=0.1) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) return [optimizer], [lr_scheduler] diff --git a/src/lightning/pytorch/demos/lstm.py b/src/lightning/pytorch/demos/lstm.py index 9432dd9..672b61a 100644 --- a/src/lightning/pytorch/demos/lstm.py +++ b/src/lightning/pytorch/demos/lstm.py @@ -5,8 +5,7 @@ """ -from collections.abc import Iterator, Sized -from typing import Optional +from typing import Iterator, List, Optional, Sized, Tuple import torch import torch.nn as nn @@ -38,14 +37,14 @@ def init_weights(self) -> None: nn.init.zeros_(self.decoder.bias) nn.init.uniform_(self.decoder.weight, -0.1, 0.1) - def forward(self, input: Tensor, hidden: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]: + def forward(self, input: Tensor, hidden: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]: emb = self.drop(self.encoder(input)) output, hidden = self.rnn(emb, hidden) output = self.drop(output) decoded = self.decoder(output).view(-1, self.vocab_size) return F.log_softmax(decoded, dim=1), hidden - def init_hidden(self, batch_size: int) -> tuple[Tensor, Tensor]: + def init_hidden(self, batch_size: int) -> Tuple[Tensor, Tensor]: weight = next(self.parameters()) return ( weight.new_zeros(self.nlayers, batch_size, self.nhid), @@ -53,14 +52,14 @@ def init_hidden(self, batch_size: int) -> tuple[Tensor, Tensor]: ) -class SequenceSampler(Sampler[list[int]]): +class SequenceSampler(Sampler[List[int]]): def __init__(self, dataset: Sized, batch_size: int) -> None: super().__init__() self.dataset = dataset self.batch_size = batch_size self.chunk_size = len(self.dataset) // self.batch_size - def __iter__(self) -> Iterator[list[int]]: + def __iter__(self) -> Iterator[List[int]]: n = len(self.dataset) for i in range(self.chunk_size): yield list(range(i, n - (n % self.batch_size), self.chunk_size)) @@ -73,12 +72,12 @@ class LightningLSTM(LightningModule): def __init__(self, vocab_size: int = 33278): super().__init__() self.model = SimpleLSTM(vocab_size=vocab_size) - self.hidden: Optional[tuple[Tensor, Tensor]] = None + self.hidden: Optional[Tuple[Tensor, Tensor]] = None def on_train_epoch_end(self) -> None: self.hidden = None - def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor: + def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: input, target = batch if self.hidden is None: self.hidden = self.model.init_hidden(input.size(0)) diff --git a/src/lightning/pytorch/demos/mnist_datamodule.py b/src/lightning/pytorch/demos/mnist_datamodule.py index 73f46d4..992527a 100644 --- a/src/lightning/pytorch/demos/mnist_datamodule.py +++ b/src/lightning/pytorch/demos/mnist_datamodule.py @@ -16,8 +16,7 @@ import random import time import urllib -from collections.abc import Sized -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Sized, Tuple, Union from urllib.error import HTTPError from warnings import warn @@ -64,7 +63,7 @@ def __init__( data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME self.data, self.targets = self._try_load(os.path.join(self.cached_folder_path, data_file)) - def __getitem__(self, idx: int) -> tuple[Tensor, int]: + def __getitem__(self, idx: int) -> Tuple[Tensor, int]: img = self.data[idx].float().unsqueeze(0) target = int(self.targets[idx]) @@ -100,7 +99,7 @@ def _download(self, data_folder: str) -> None: urllib.request.urlretrieve(url, fpath) # noqa: S310 @staticmethod - def _try_load(path_data: str, trials: int = 30, delta: float = 1.0) -> tuple[Tensor, Tensor]: + def _try_load(path_data: str, trials: int = 30, delta: float = 1.0) -> Tuple[Tensor, Tensor]: """Resolving loading from the same time from multiple concurrent processes.""" res, exception = None, None assert trials, "at least some trial has to be set" diff --git a/src/lightning/pytorch/demos/transformer.py b/src/lightning/pytorch/demos/transformer.py index eca86b4..58cf30c 100644 --- a/src/lightning/pytorch/demos/transformer.py +++ b/src/lightning/pytorch/demos/transformer.py @@ -8,7 +8,7 @@ import math import os from pathlib import Path -from typing import Optional +from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn @@ -119,7 +119,7 @@ def vocab_size(self) -> int: def __len__(self) -> int: return len(self.data) // self.block_size - 1 - def __getitem__(self, index: int) -> tuple[Tensor, Tensor]: + def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: start = index * self.block_size end = start + self.block_size inputs = self.data[start:end] @@ -143,8 +143,8 @@ def download(destination: Path) -> None: class Dictionary: def __init__(self) -> None: - self.word2idx: dict[str, int] = {} - self.idx2word: list[str] = [] + self.word2idx: Dict[str, int] = {} + self.idx2word: List[str] = [] def add_word(self, word: str) -> int: if word not in self.word2idx: @@ -156,7 +156,7 @@ def __len__(self) -> int: return len(self.idx2word) -def tokenize(path: Path) -> tuple[Tensor, Dictionary]: +def tokenize(path: Path) -> Tuple[Tensor, Dictionary]: dictionary = Dictionary() assert os.path.exists(path) @@ -169,10 +169,10 @@ def tokenize(path: Path) -> tuple[Tensor, Dictionary]: # Tokenize file content with open(path, encoding="utf8") as f: - idss: list[Tensor] = [] + idss: List[Tensor] = [] for line in f: words = line.split() + [""] - ids: list[int] = [] + ids: List[int] = [] for word in words: ids.append(dictionary.word2idx[word]) idss.append(torch.tensor(ids).type(torch.int64)) @@ -188,7 +188,7 @@ def __init__(self, vocab_size: int = 33278) -> None: def forward(self, inputs: Tensor, target: Tensor) -> Tensor: return self.model(inputs, target) - def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor: + def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: inputs, target = batch output = self(inputs, target) loss = torch.nn.functional.nll_loss(output, target.view(-1)) diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index 9c05317..277af5c 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -19,8 +19,7 @@ import logging import os from argparse import Namespace -from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union from lightning_utilities.core.imports import RequirementCache from torch import Tensor @@ -306,7 +305,7 @@ def experiment(self) -> Union["Experiment", "ExistingExperiment", "OfflineExperi @override @rank_zero_only - def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: + def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: params = _convert_params(params) params = _flatten_dict(params) self.experiment.log_parameters(params) @@ -411,7 +410,7 @@ def version(self) -> str: return self._future_experiment_key - def __getstate__(self) -> dict[str, Any]: + def __getstate__(self) -> Dict[str, Any]: state = self.__dict__.copy() # Save the experiment id in case an experiment object already exists, diff --git a/src/lightning/pytorch/loggers/csv_logs.py b/src/lightning/pytorch/loggers/csv_logs.py index 8606264..caca0c1 100644 --- a/src/lightning/pytorch/loggers/csv_logs.py +++ b/src/lightning/pytorch/loggers/csv_logs.py @@ -21,7 +21,7 @@ import os from argparse import Namespace -from typing import Any, Optional, Union +from typing import Any, Dict, Optional, Union from typing_extensions import override @@ -52,9 +52,9 @@ class ExperimentWriter(_FabricExperimentWriter): def __init__(self, log_dir: str) -> None: super().__init__(log_dir=log_dir) - self.hparams: dict[str, Any] = {} + self.hparams: Dict[str, Any] = {} - def log_hparams(self, params: dict[str, Any]) -> None: + def log_hparams(self, params: Dict[str, Any]) -> None: """Record hparams.""" self.hparams.update(params) @@ -144,7 +144,7 @@ def save_dir(self) -> str: @override @rank_zero_only - def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: + def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: params = _convert_params(params) self.experiment.log_hparams(params) diff --git a/src/lightning/pytorch/loggers/logger.py b/src/lightning/pytorch/loggers/logger.py index 668fe39..40e8ed8 100644 --- a/src/lightning/pytorch/loggers/logger.py +++ b/src/lightning/pytorch/loggers/logger.py @@ -18,8 +18,7 @@ import statistics from abc import ABC from collections import defaultdict -from collections.abc import Mapping, Sequence -from typing import Any, Callable, Optional +from typing import Any, Callable, Dict, Mapping, Optional, Sequence from typing_extensions import override @@ -102,7 +101,7 @@ def merge_dicts( # pragma: no cover dicts: Sequence[Mapping], agg_key_funcs: Optional[Mapping] = None, default_func: Callable[[Sequence[float]], float] = statistics.mean, -) -> dict: +) -> Dict: """Merge a sequence with dictionaries into one dictionary by aggregating the same keys with some given function. Args: @@ -138,7 +137,7 @@ def merge_dicts( # pragma: no cover """ agg_key_funcs = agg_key_funcs or {} keys = list(functools.reduce(operator.or_, [set(d.keys()) for d in dicts])) - d_out: dict = defaultdict(dict) + d_out: Dict = defaultdict(dict) for k in keys: fn = agg_key_funcs.get(k) values_to_agg = [v for v in [d_in.get(k) for d_in in dicts] if v is not None] diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py index e3d9998..ec990b6 100644 --- a/src/lightning/pytorch/loggers/mlflow.py +++ b/src/lightning/pytorch/loggers/mlflow.py @@ -21,10 +21,9 @@ import re import tempfile from argparse import Namespace -from collections.abc import Mapping from pathlib import Path from time import time -from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Mapping, Optional, Union import yaml from lightning_utilities.core.imports import RequirementCache @@ -118,7 +117,7 @@ def __init__( experiment_name: str = "lightning_logs", run_name: Optional[str] = None, tracking_uri: Optional[str] = os.getenv("MLFLOW_TRACKING_URI"), - tags: Optional[dict[str, Any]] = None, + tags: Optional[Dict[str, Any]] = None, save_dir: Optional[str] = "./mlruns", log_model: Literal[True, False, "all"] = False, prefix: str = "", @@ -141,7 +140,7 @@ def __init__( self._run_id = run_id self.tags = tags self._log_model = log_model - self._logged_model_time: dict[str, float] = {} + self._logged_model_time: Dict[str, float] = {} self._checkpoint_callback: Optional[ModelCheckpoint] = None self._prefix = prefix self._artifact_location = artifact_location @@ -228,7 +227,7 @@ def experiment_id(self) -> Optional[str]: @override @rank_zero_only - def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: + def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: params = _convert_params(params) params = _flatten_dict(params) @@ -250,7 +249,7 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) from mlflow.entities import Metric metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) - metrics_list: list[Metric] = [] + metrics_list: List[Metric] = [] timestamp_ms = int(time() * 1000) for k, v in metrics.items(): diff --git a/src/lightning/pytorch/loggers/neptune.py b/src/lightning/pytorch/loggers/neptune.py index a363f58..691dbe0 100644 --- a/src/lightning/pytorch/loggers/neptune.py +++ b/src/lightning/pytorch/loggers/neptune.py @@ -20,9 +20,8 @@ import logging import os from argparse import Namespace -from collections.abc import Generator from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Set, Union from lightning_utilities.core.imports import RequirementCache from torch import Tensor @@ -287,8 +286,8 @@ def _retrieve_run_data(self) -> None: self._run_name = "offline-name" @property - def _neptune_init_args(self) -> dict: - args: dict = {} + def _neptune_init_args(self) -> Dict: + args: Dict = {} # Backward compatibility in case of previous version retrieval with contextlib.suppress(AttributeError): args = self._neptune_run_kwargs @@ -338,13 +337,13 @@ def _verify_input_arguments( " parameters." ) - def __getstate__(self) -> dict[str, Any]: + def __getstate__(self) -> Dict[str, Any]: state = self.__dict__.copy() # Run instance can't be pickled state["_run_instance"] = None return state - def __setstate__(self, state: dict[str, Any]) -> None: + def __setstate__(self, state: Dict[str, Any]) -> None: import neptune self.__dict__ = state @@ -396,7 +395,7 @@ def run(self) -> "Run": @override @rank_zero_only @_catch_inactive - def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: + def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: r"""Log hyperparameters to the run. Hyperparameters will be logged under the "/hyperparams" namespace. @@ -444,7 +443,7 @@ def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: @override @rank_zero_only @_catch_inactive - def log_metrics(self, metrics: dict[str, Union[Tensor, float]], step: Optional[int] = None) -> None: + def log_metrics(self, metrics: Dict[str, Union[Tensor, float]], step: Optional[int] = None) -> None: """Log metrics (numeric values) in Neptune runs. Args: @@ -564,16 +563,16 @@ def _get_full_model_name(model_path: str, checkpoint_callback: Checkpoint) -> st return model_path.replace(os.sep, "/") @classmethod - def _get_full_model_names_from_exp_structure(cls, exp_structure: dict[str, Any], namespace: str) -> set[str]: + def _get_full_model_names_from_exp_structure(cls, exp_structure: Dict[str, Any], namespace: str) -> Set[str]: """Returns all paths to properties which were already logged in `namespace`""" - structure_keys: list[str] = namespace.split(cls.LOGGER_JOIN_CHAR) + structure_keys: List[str] = namespace.split(cls.LOGGER_JOIN_CHAR) for key in structure_keys: exp_structure = exp_structure[key] uploaded_models_dict = exp_structure return set(cls._dict_paths(uploaded_models_dict)) @classmethod - def _dict_paths(cls, d: dict[str, Any], path_in_build: Optional[str] = None) -> Generator: + def _dict_paths(cls, d: Dict[str, Any], path_in_build: Optional[str] = None) -> Generator: for k, v in d.items(): path = f"{path_in_build}/{k}" if path_in_build is not None else k if not isinstance(v, dict): diff --git a/src/lightning/pytorch/loggers/tensorboard.py b/src/lightning/pytorch/loggers/tensorboard.py index e70c892..88e026f 100644 --- a/src/lightning/pytorch/loggers/tensorboard.py +++ b/src/lightning/pytorch/loggers/tensorboard.py @@ -18,7 +18,7 @@ import os from argparse import Namespace -from typing import Any, Optional, Union +from typing import Any, Dict, Optional, Union from torch import Tensor from typing_extensions import override @@ -108,7 +108,7 @@ def __init__( f"{str(_TENSORBOARD_AVAILABLE)}" ) self._log_graph = log_graph and _TENSORBOARD_AVAILABLE - self.hparams: Union[dict[str, Any], Namespace] = {} + self.hparams: Union[Dict[str, Any], Namespace] = {} @property @override @@ -153,7 +153,7 @@ def save_dir(self) -> str: @override @rank_zero_only def log_hyperparams( - self, params: Union[dict[str, Any], Namespace], metrics: Optional[dict[str, Any]] = None + self, params: Union[Dict[str, Any], Namespace], metrics: Optional[Dict[str, Any]] = None ) -> None: """Record hyperparameters. TensorBoard logs with and without saved hyperparameters are incompatible, the hyperparameters are then not displayed in the TensorBoard. Please delete or move the previously saved logs to diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index c763071..2ff9cbd 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -14,7 +14,7 @@ """Utilities for loggers.""" from pathlib import Path -from typing import Any, Union +from typing import Any, List, Tuple, Union from torch import Tensor @@ -22,14 +22,14 @@ from lightning.pytorch.callbacks import Checkpoint -def _version(loggers: list[Any], separator: str = "_") -> Union[int, str]: +def _version(loggers: List[Any], separator: str = "_") -> Union[int, str]: if len(loggers) == 1: return loggers[0].version # Concatenate versions together, removing duplicates and preserving order return separator.join(dict.fromkeys(str(logger.version) for logger in loggers)) -def _scan_checkpoints(checkpoint_callback: Checkpoint, logged_model_time: dict) -> list[tuple[float, str, float, str]]: +def _scan_checkpoints(checkpoint_callback: Checkpoint, logged_model_time: dict) -> List[Tuple[float, str, float, str]]: """Return the checkpoints to be logged. Args: diff --git a/src/lightning/pytorch/loggers/wandb.py b/src/lightning/pytorch/loggers/wandb.py index 2429748..20f8d02 100644 --- a/src/lightning/pytorch/loggers/wandb.py +++ b/src/lightning/pytorch/loggers/wandb.py @@ -18,9 +18,8 @@ import os from argparse import Namespace -from collections.abc import Mapping from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional, Union import torch.nn as nn from lightning_utilities.core.imports import RequirementCache @@ -321,7 +320,7 @@ def __init__( self._log_model = log_model self._prefix = prefix self._experiment = experiment - self._logged_model_time: dict[str, float] = {} + self._logged_model_time: Dict[str, float] = {} self._checkpoint_callback: Optional[ModelCheckpoint] = None # paths are processed as strings @@ -333,7 +332,7 @@ def __init__( project = project or os.environ.get("WANDB_PROJECT", "lightning_logs") # set wandb init arguments - self._wandb_init: dict[str, Any] = { + self._wandb_init: Dict[str, Any] = { "name": name, "project": project, "dir": save_dir or dir, @@ -349,7 +348,7 @@ def __init__( self._id = self._wandb_init.get("id") self._checkpoint_name = checkpoint_name - def __getstate__(self) -> dict[str, Any]: + def __getstate__(self) -> Dict[str, Any]: import wandb # Hack: If the 'spawn' launch method is used, the logger will get pickled and this `__getstate__` gets called. @@ -422,7 +421,7 @@ def watch( @override @rank_zero_only - def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: + def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: params = _convert_params(params) params = _sanitize_callable_params(params) params = _convert_json_serializable(params) @@ -443,8 +442,8 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) def log_table( self, key: str, - columns: Optional[list[str]] = None, - data: Optional[list[list[Any]]] = None, + columns: Optional[List[str]] = None, + data: Optional[List[List[Any]]] = None, dataframe: Any = None, step: Optional[int] = None, ) -> None: @@ -462,8 +461,8 @@ def log_table( def log_text( self, key: str, - columns: Optional[list[str]] = None, - data: Optional[list[list[str]]] = None, + columns: Optional[List[str]] = None, + data: Optional[List[List[str]]] = None, dataframe: Any = None, step: Optional[int] = None, ) -> None: @@ -476,7 +475,7 @@ def log_text( self.log_table(key, columns, data, dataframe, step) @rank_zero_only - def log_image(self, key: str, images: list[Any], step: Optional[int] = None, **kwargs: Any) -> None: + def log_image(self, key: str, images: List[Any], step: Optional[int] = None, **kwargs: Any) -> None: """Log images (tensors, numpy arrays, PIL Images or file paths). Optional kwargs are lists passed to each image (ex: caption, masks, boxes). @@ -496,7 +495,7 @@ def log_image(self, key: str, images: list[Any], step: Optional[int] = None, **k self.log_metrics(metrics, step) # type: ignore[arg-type] @rank_zero_only - def log_audio(self, key: str, audios: list[Any], step: Optional[int] = None, **kwargs: Any) -> None: + def log_audio(self, key: str, audios: List[Any], step: Optional[int] = None, **kwargs: Any) -> None: r"""Log audios (numpy arrays, or file paths). Args: @@ -522,7 +521,7 @@ def log_audio(self, key: str, audios: list[Any], step: Optional[int] = None, **k self.log_metrics(metrics, step) # type: ignore[arg-type] @rank_zero_only - def log_video(self, key: str, videos: list[Any], step: Optional[int] = None, **kwargs: Any) -> None: + def log_video(self, key: str, videos: List[Any], step: Optional[int] = None, **kwargs: Any) -> None: """Log videos (numpy arrays, or file paths). Args: diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py index d007466..78573c4 100644 --- a/src/lightning/pytorch/loops/evaluation_loop.py +++ b/src/lightning/pytorch/loops/evaluation_loop.py @@ -15,9 +15,8 @@ import shutil import sys from collections import ChainMap, OrderedDict, defaultdict -from collections.abc import Iterable, Iterator from dataclasses import dataclass -from typing import Any, Optional, Union +from typing import Any, DefaultDict, Iterable, Iterator, List, Optional, Tuple, Union from lightning_utilities.core.apply_func import apply_to_collection from torch import Tensor @@ -68,17 +67,17 @@ def __init__( self.verbose = verbose self.inference_mode = inference_mode self.batch_progress = _BatchProgress() # across dataloaders - self._max_batches: list[Union[int, float]] = [] + self._max_batches: List[Union[int, float]] = [] self._results = _ResultCollection(training=False) - self._logged_outputs: list[_OUT_DICT] = [] + self._logged_outputs: List[_OUT_DICT] = [] self._has_run: bool = False self._trainer_fn = trainer_fn self._stage = stage self._data_source = _DataLoaderSource(None, f"{stage.dataloader_prefix}_dataloader") self._combined_loader: Optional[CombinedLoader] = None self._data_fetcher: Optional[_DataFetcher] = None - self._seen_batches_per_dataloader: defaultdict[int, int] = defaultdict(int) + self._seen_batches_per_dataloader: DefaultDict[int, int] = defaultdict(int) self._last_val_dl_reload_epoch = float("-inf") self._module_mode = _ModuleMode() self._restart_stage = RestartStage.NONE @@ -91,7 +90,7 @@ def num_dataloaders(self) -> int: return len(combined_loader.flattened) @property - def max_batches(self) -> list[Union[int, float]]: + def max_batches(self) -> List[Union[int, float]]: """The max number of batches to run per dataloader.""" max_batches = self._max_batches if not self.trainer.sanity_checking: @@ -115,7 +114,7 @@ def _is_sequential(self) -> bool: return self._combined_loader._mode == "sequential" @_no_grad_context - def run(self) -> list[_OUT_DICT]: + def run(self) -> List[_OUT_DICT]: self.setup_data() if self.skip: return [] @@ -281,7 +280,7 @@ def on_run_start(self) -> None: self._on_evaluation_start() self._on_evaluation_epoch_start() - def on_run_end(self) -> list[_OUT_DICT]: + def on_run_end(self) -> List[_OUT_DICT]: """Runs the ``_on_evaluation_epoch_end`` hook.""" # if `done` returned True before any iterations were done, this won't have been called in `on_advance_end` self.trainer._logger_connector.epoch_end_reached() @@ -509,7 +508,7 @@ def _verify_dataloader_idx_requirement(self) -> None: ) @staticmethod - def _get_keys(data: dict) -> Iterable[tuple[str, ...]]: + def _get_keys(data: dict) -> Iterable[Tuple[str, ...]]: for k, v in data.items(): if isinstance(v, dict): for new_key in apply_to_collection(v, dict, _EvaluationLoop._get_keys): @@ -528,7 +527,7 @@ def _find_value(data: dict, target: Iterable[str]) -> Optional[Any]: return _EvaluationLoop._find_value(result, rest) @staticmethod - def _print_results(results: list[_OUT_DICT], stage: str) -> None: + def _print_results(results: List[_OUT_DICT], stage: str) -> None: # remove the dl idx suffix results = [{k.split("/dataloader_idx_")[0]: v for k, v in result.items()} for result in results] metrics_paths = {k for keys in apply_to_collection(results, dict, _EvaluationLoop._get_keys) for k in keys} @@ -545,7 +544,7 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None: term_size = shutil.get_terminal_size(fallback=(120, 30)).columns or 120 max_length = int(min(max(len(max(metrics_strs, key=len)), len(max(headers, key=len)), 25), term_size / 2)) - rows: list[list[Any]] = [[] for _ in metrics_paths] + rows: List[List[Any]] = [[] for _ in metrics_paths] for result in results: for metric, row in zip(metrics_paths, rows): diff --git a/src/lightning/pytorch/loops/fetchers.py b/src/lightning/pytorch/loops/fetchers.py index 92ec95a..e699321 100644 --- a/src/lightning/pytorch/loops/fetchers.py +++ b/src/lightning/pytorch/loops/fetchers.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Iterator -from typing import Any, Optional +from typing import Any, Iterator, List, Optional from typing_extensions import override @@ -98,7 +97,7 @@ def __init__(self, prefetch_batches: int = 1) -> None: if prefetch_batches < 0: raise ValueError("`prefetch_batches` should at least be 0.") self.prefetch_batches = prefetch_batches - self.batches: list[Any] = [] + self.batches: List[Any] = [] @override def __iter__(self) -> "_PrefetchDataFetcher": diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py index 31d6724..e20088a 100644 --- a/src/lightning/pytorch/loops/fit_loop.py +++ b/src/lightning/pytorch/loops/fit_loop.py @@ -13,7 +13,7 @@ # limitations under the License. import logging from dataclasses import dataclass -from typing import Any, Optional, Union +from typing import Any, Dict, List, Optional, Union import torch from typing_extensions import override @@ -104,7 +104,7 @@ def __init__( self._data_source = _DataLoaderSource(None, "train_dataloader") self._combined_loader: Optional[CombinedLoader] = None - self._combined_loader_states_to_load: list[dict[str, Any]] = [] + self._combined_loader_states_to_load: List[Dict[str, Any]] = [] self._data_fetcher: Optional[_DataFetcher] = None self._last_train_dl_reload_epoch = float("-inf") self._restart_stage = RestartStage.NONE @@ -504,14 +504,14 @@ def teardown(self) -> None: self.epoch_loop.teardown() @override - def on_save_checkpoint(self) -> dict: + def on_save_checkpoint(self) -> Dict: state_dict = super().on_save_checkpoint() if self._combined_loader is not None and (loader_states := self._combined_loader._state_dicts()): state_dict["combined_loader"] = loader_states return state_dict @override - def on_load_checkpoint(self, state_dict: dict) -> None: + def on_load_checkpoint(self, state_dict: Dict) -> None: self._combined_loader_states_to_load = state_dict.get("combined_loader", []) super().on_load_checkpoint(state_dict) diff --git a/src/lightning/pytorch/loops/loop.py b/src/lightning/pytorch/loops/loop.py index daad309..111377a 100644 --- a/src/lightning/pytorch/loops/loop.py +++ b/src/lightning/pytorch/loops/loop.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Dict, Optional import lightning.pytorch as pl from lightning.pytorch.loops.progress import _BaseProgress @@ -41,7 +41,7 @@ def restarting(self, restarting: bool) -> None: def reset_restart_stage(self) -> None: pass - def on_save_checkpoint(self) -> dict: + def on_save_checkpoint(self) -> Dict: """Called when saving a model checkpoint, use to persist loop state. Returns: @@ -50,10 +50,10 @@ def on_save_checkpoint(self) -> dict: """ return {} - def on_load_checkpoint(self, state_dict: dict) -> None: + def on_load_checkpoint(self, state_dict: Dict) -> None: """Called when loading a model checkpoint, use to reload loop state.""" - def state_dict(self, destination: Optional[dict] = None, prefix: str = "") -> dict: + def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Dict: """The state dict is determined by the state and progress of this loop and all its children. Args: @@ -77,7 +77,7 @@ def state_dict(self, destination: Optional[dict] = None, prefix: str = "") -> di def load_state_dict( self, - state_dict: dict, + state_dict: Dict, prefix: str = "", ) -> None: """Loads the state of this loop and all its children.""" @@ -88,7 +88,7 @@ def load_state_dict( self.restarting = True self._loaded_from_state_dict = True - def _load_from_state_dict(self, state_dict: dict, prefix: str) -> None: + def _load_from_state_dict(self, state_dict: Dict, prefix: str) -> None: for k, v in self.__dict__.items(): key = prefix + k if key not in state_dict: diff --git a/src/lightning/pytorch/loops/optimization/automatic.py b/src/lightning/pytorch/loops/optimization/automatic.py index e19b576..2ce6aca 100644 --- a/src/lightning/pytorch/loops/optimization/automatic.py +++ b/src/lightning/pytorch/loops/optimization/automatic.py @@ -11,11 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import OrderedDict -from collections.abc import Mapping from dataclasses import dataclass, field from functools import partial -from typing import Any, Callable, Optional +from typing import Any, Callable, Dict, Mapping, Optional, OrderedDict import torch from torch import Tensor @@ -48,7 +46,7 @@ class ClosureResult(OutputResult): closure_loss: Optional[Tensor] loss: Optional[Tensor] = field(init=False, default=None) - extra: dict[str, Any] = field(default_factory=dict) + extra: Dict[str, Any] = field(default_factory=dict) def __post_init__(self) -> None: self._clone_loss() @@ -85,7 +83,7 @@ def from_training_step_output(cls, training_step_output: STEP_OUTPUT, normalize: return cls(closure_loss, extra=extra) @override - def asdict(self) -> dict[str, Any]: + def asdict(self) -> Dict[str, Any]: return {"loss": self.loss, **self.extra} @@ -147,7 +145,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]: return self._result.loss -_OUTPUTS_TYPE = dict[str, Any] +_OUTPUTS_TYPE = Dict[str, Any] class _AutomaticOptimization(_Loop): diff --git a/src/lightning/pytorch/loops/optimization/closure.py b/src/lightning/pytorch/loops/optimization/closure.py index e45262a..4b55016 100644 --- a/src/lightning/pytorch/loops/optimization/closure.py +++ b/src/lightning/pytorch/loops/optimization/closure.py @@ -13,7 +13,7 @@ # limitations under the License. from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Generic, Optional, TypeVar +from typing import Any, Dict, Generic, Optional, TypeVar from lightning.pytorch.utilities.exceptions import MisconfigurationException @@ -22,7 +22,7 @@ @dataclass class OutputResult: - def asdict(self) -> dict[str, Any]: + def asdict(self) -> Dict[str, Any]: raise NotImplementedError diff --git a/src/lightning/pytorch/loops/optimization/manual.py b/src/lightning/pytorch/loops/optimization/manual.py index e1aabcb..d8a4f19 100644 --- a/src/lightning/pytorch/loops/optimization/manual.py +++ b/src/lightning/pytorch/loops/optimization/manual.py @@ -14,7 +14,7 @@ from collections import OrderedDict from contextlib import suppress from dataclasses import dataclass, field -from typing import Any +from typing import Any, Dict from torch import Tensor from typing_extensions import override @@ -40,7 +40,7 @@ class ManualResult(OutputResult): """ - extra: dict[str, Any] = field(default_factory=dict) + extra: Dict[str, Any] = field(default_factory=dict) @classmethod def from_training_step_output(cls, training_step_output: STEP_OUTPUT) -> "ManualResult": @@ -61,11 +61,11 @@ def from_training_step_output(cls, training_step_output: STEP_OUTPUT) -> "Manual return cls(extra=extra) @override - def asdict(self) -> dict[str, Any]: + def asdict(self) -> Dict[str, Any]: return self.extra -_OUTPUTS_TYPE = dict[str, Any] +_OUTPUTS_TYPE = Dict[str, Any] class _ManualOptimization(_Loop): diff --git a/src/lightning/pytorch/loops/prediction_loop.py b/src/lightning/pytorch/loops/prediction_loop.py index 7044cce..9002e62 100644 --- a/src/lightning/pytorch/loops/prediction_loop.py +++ b/src/lightning/pytorch/loops/prediction_loop.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import OrderedDict -from collections.abc import Iterator -from typing import Any, Optional, Union +from typing import Any, Iterator, List, Optional, Union import torch from lightning_utilities import WarningCache @@ -51,17 +50,17 @@ def __init__(self, trainer: "pl.Trainer", inference_mode: bool = True) -> None: super().__init__(trainer) self.inference_mode = inference_mode # dataloaders x batches x samples. used by PredictionWriter - self.epoch_batch_indices: list[list[list[int]]] = [] - self.current_batch_indices: list[int] = [] # used by PredictionWriter + self.epoch_batch_indices: List[List[List[int]]] = [] + self.current_batch_indices: List[int] = [] # used by PredictionWriter self.batch_progress = _Progress() # across dataloaders - self.max_batches: list[Union[int, float]] = [] + self.max_batches: List[Union[int, float]] = [] self._warning_cache = WarningCache() self._data_source = _DataLoaderSource(None, "predict_dataloader") self._combined_loader: Optional[CombinedLoader] = None self._data_fetcher: Optional[_DataFetcher] = None self._results = None # for `trainer._results` access - self._predictions: list[list[Any]] = [] # dataloaders x batches + self._predictions: List[List[Any]] = [] # dataloaders x batches self._return_predictions = False self._module_mode = _ModuleMode() @@ -83,7 +82,7 @@ def return_predictions(self, return_predictions: Optional[bool] = None) -> None: self._return_predictions = return_supported if return_predictions is None else return_predictions @property - def predictions(self) -> list[Any]: + def predictions(self) -> List[Any]: """The cached predictions.""" if self._predictions == []: return self._predictions @@ -298,7 +297,7 @@ def _build_step_args_from_hook_kwargs(self, hook_kwargs: OrderedDict, step_hook_ kwargs.pop("batch_idx", None) return tuple(kwargs.values()) - def _get_batch_indices(self, dataloader: object) -> list[list[int]]: # batches x samples + def _get_batch_indices(self, dataloader: object) -> List[List[int]]: # batches x samples """Returns a reference to the seen batch indices if the dataloader has a batch sampler wrapped by our :class:`~lightning.pytorch.overrides.distributed._IndexBatchSamplerWrapper`.""" batch_sampler = getattr(dataloader, "batch_sampler", None) diff --git a/src/lightning/pytorch/loops/progress.py b/src/lightning/pytorch/loops/progress.py index 42e5de6..6880b24 100644 --- a/src/lightning/pytorch/loops/progress.py +++ b/src/lightning/pytorch/loops/progress.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import asdict, dataclass, field +from typing import Type from typing_extensions import override @@ -173,7 +174,7 @@ def increment_completed(self) -> None: self.current.completed += 1 @classmethod - def from_defaults(cls, tracker_cls: type[_ReadyCompletedTracker], **kwargs: int) -> "_Progress": + def from_defaults(cls, tracker_cls: Type[_ReadyCompletedTracker], **kwargs: int) -> "_Progress": """Utility function to easily create an instance from keyword arguments to both ``Tracker``s.""" return cls(total=tracker_cls(**kwargs), current=tracker_cls(**kwargs)) diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py index 7cdf788..1c749de 100644 --- a/src/lightning/pytorch/loops/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/training_epoch_loop.py @@ -14,7 +14,7 @@ import math from collections import OrderedDict from dataclasses import dataclass -from typing import Any, Optional, Union +from typing import Any, Dict, Optional, Union from typing_extensions import override @@ -390,13 +390,13 @@ def teardown(self) -> None: self.val_loop.teardown() @override - def on_save_checkpoint(self) -> dict: + def on_save_checkpoint(self) -> Dict: state_dict = super().on_save_checkpoint() state_dict["_batches_that_stepped"] = self._batches_that_stepped return state_dict @override - def on_load_checkpoint(self, state_dict: dict) -> None: + def on_load_checkpoint(self, state_dict: Dict) -> None: self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0) def _accumulated_batches_reached(self) -> bool: diff --git a/src/lightning/pytorch/loops/utilities.py b/src/lightning/pytorch/loops/utilities.py index 2aaf877..99ea5c4 100644 --- a/src/lightning/pytorch/loops/utilities.py +++ b/src/lightning/pytorch/loops/utilities.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect -from collections.abc import Generator -from contextlib import AbstractContextManager, contextmanager -from typing import Any, Callable, Optional +from contextlib import contextmanager +from typing import Any, Callable, ContextManager, Generator, Optional, Tuple, Type import torch import torch.distributed as dist @@ -53,7 +52,7 @@ def _parse_loop_limits( min_epochs: Optional[int], max_epochs: Optional[int], trainer: "pl.Trainer", -) -> tuple[int, int]: +) -> Tuple[int, int]: """This utility computes the default values for the minimum and maximum number of steps and epochs given the values the user has selected. @@ -160,7 +159,7 @@ def _decorator(self: _Loop, *args: Any, **kwargs: Any) -> Any: raise TypeError(f"`{type(self).__name__}` needs to be a Loop.") if not hasattr(self, "inference_mode"): raise TypeError(f"`{type(self).__name__}.inference_mode` needs to be defined") - context_manager: type[AbstractContextManager] + context_manager: Type[ContextManager] if _distributed_is_initialized() and dist.get_backend() == "gloo": # gloo backend does not work properly. # https://github.com/Lightning-AI/lightning/pull/12715/files#r854569110 @@ -182,7 +181,7 @@ def _decorator(self: _Loop, *args: Any, **kwargs: Any) -> Any: def _verify_dataloader_idx_requirement( - hooks: tuple[str, ...], is_expected: bool, stage: RunningStage, pl_module: "pl.LightningModule" + hooks: Tuple[str, ...], is_expected: bool, stage: RunningStage, pl_module: "pl.LightningModule" ) -> None: for hook in hooks: fx = getattr(pl_module, hook) diff --git a/src/lightning/pytorch/overrides/distributed.py b/src/lightning/pytorch/overrides/distributed.py index 196008b..e4b6528 100644 --- a/src/lightning/pytorch/overrides/distributed.py +++ b/src/lightning/pytorch/overrides/distributed.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import itertools -from collections.abc import Iterable, Iterator, Sized -from typing import Any, Callable, Optional, Union, cast +from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Sized, Union, cast import torch from torch import Tensor @@ -28,7 +27,7 @@ def _find_tensors( obj: Union[Tensor, list, tuple, dict, Any], -) -> Union[list[Tensor], itertools.chain]: # pragma: no-cover +) -> Union[List[Tensor], itertools.chain]: # pragma: no-cover """Recursively find all tensors contained in the specified object.""" if isinstance(obj, Tensor): return [obj] @@ -202,7 +201,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: assert self.num_samples >= 1 or self.total_size == 0 @override - def __iter__(self) -> Iterator[list[int]]: + def __iter__(self) -> Iterator[List[int]]: if not isinstance(self.dataset, Sized): raise TypeError("The given dataset must implement the `__len__` method.") if self.shuffle: @@ -239,7 +238,7 @@ class _IndexBatchSamplerWrapper: def __init__(self, batch_sampler: _SizedIterable) -> None: # do not call super().__init__() on purpose - self.seen_batch_indices: list[list[int]] = [] + self.seen_batch_indices: List[List[int]] = [] self.__dict__ = { k: v @@ -247,9 +246,9 @@ def __init__(self, batch_sampler: _SizedIterable) -> None: if k not in ("__next__", "__iter__", "__len__", "__getstate__") } self._batch_sampler = batch_sampler - self._iterator: Optional[Iterator[list[int]]] = None + self._iterator: Optional[Iterator[List[int]]] = None - def __next__(self) -> list[int]: + def __next__(self) -> List[int]: assert self._iterator is not None batch = next(self._iterator) self.seen_batch_indices.append(batch) @@ -263,7 +262,7 @@ def __iter__(self) -> Self: def __len__(self) -> int: return len(self._batch_sampler) - def __getstate__(self) -> dict[str, Any]: + def __getstate__(self) -> Dict[str, Any]: state = self.__dict__.copy() state["_iterator"] = None # cannot pickle 'generator' object return state diff --git a/src/lightning/pytorch/plugins/io/wrapper.py b/src/lightning/pytorch/plugins/io/wrapper.py index 548bc1f..6e918b8 100644 --- a/src/lightning/pytorch/plugins/io/wrapper.py +++ b/src/lightning/pytorch/plugins/io/wrapper.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Any, Dict, Optional from typing_extensions import override @@ -66,7 +66,7 @@ def remove_checkpoint(self, *args: Any, **kwargs: Any) -> None: self.checkpoint_io.remove_checkpoint(*args, **kwargs) @override - def load_checkpoint(self, *args: Any, **kwargs: Any) -> dict[str, Any]: + def load_checkpoint(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: """Uses the base ``checkpoint_io`` to load the checkpoint.""" assert self.checkpoint_io is not None return self.checkpoint_io.load_checkpoint(*args, **kwargs) diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py index 75e792a..e63ccd6 100644 --- a/src/lightning/pytorch/plugins/precision/amp.py +++ b/src/lightning/pytorch/plugins/precision/amp.py @@ -9,9 +9,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Generator from contextlib import contextmanager -from typing import Any, Callable, Literal, Optional, Union +from typing import Any, Callable, Dict, Generator, Literal, Optional, Union import torch from torch import Tensor @@ -122,12 +121,12 @@ def forward_context(self) -> Generator[None, None, None]: yield @override - def state_dict(self) -> dict[str, Any]: + def state_dict(self) -> Dict[str, Any]: if self.scaler is not None: return self.scaler.state_dict() return {} @override - def load_state_dict(self, state_dict: dict[str, Any]) -> None: + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: if self.scaler is not None: self.scaler.load_state_dict(state_dict) diff --git a/src/lightning/pytorch/plugins/precision/deepspeed.py b/src/lightning/pytorch/plugins/precision/deepspeed.py index 9225e3b..e1e9028 100644 --- a/src/lightning/pytorch/plugins/precision/deepspeed.py +++ b/src/lightning/pytorch/plugins/precision/deepspeed.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import AbstractContextManager, nullcontext -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from contextlib import nullcontext +from typing import TYPE_CHECKING, Any, Callable, ContextManager, Optional, Union import torch from lightning_utilities import apply_to_collection @@ -80,13 +80,13 @@ def convert_input(self, data: Any) -> Any: return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype) @override - def tensor_init_context(self) -> AbstractContextManager: + def tensor_init_context(self) -> ContextManager: if "true" not in self.precision: return nullcontext() return _DtypeContextManager(self._desired_dtype) @override - def module_init_context(self) -> AbstractContextManager: + def module_init_context(self) -> ContextManager: return self.tensor_init_context() @override diff --git a/src/lightning/pytorch/plugins/precision/double.py b/src/lightning/pytorch/plugins/precision/double.py index efa1aa0..20f493b 100644 --- a/src/lightning/pytorch/plugins/precision/double.py +++ b/src/lightning/pytorch/plugins/precision/double.py @@ -11,9 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Generator -from contextlib import AbstractContextManager, contextmanager -from typing import Any, Literal +from contextlib import contextmanager +from typing import Any, ContextManager, Generator, Literal import torch import torch.nn as nn @@ -38,11 +37,11 @@ def convert_module(self, module: nn.Module) -> nn.Module: return module.double() @override - def tensor_init_context(self) -> AbstractContextManager: + def tensor_init_context(self) -> ContextManager: return _DtypeContextManager(torch.float64) @override - def module_init_context(self) -> AbstractContextManager: + def module_init_context(self) -> ContextManager: return self.tensor_init_context() @override diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py index 7029497..e6c6849 100644 --- a/src/lightning/pytorch/plugins/precision/fsdp.py +++ b/src/lightning/pytorch/plugins/precision/fsdp.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import AbstractContextManager -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Optional import torch from lightning_utilities import apply_to_collection @@ -110,15 +109,15 @@ def mixed_precision_config(self) -> "TorchMixedPrecision": ) @override - def tensor_init_context(self) -> AbstractContextManager: + def tensor_init_context(self) -> ContextManager: return _DtypeContextManager(self._desired_input_dtype) @override - def module_init_context(self) -> AbstractContextManager: + def module_init_context(self) -> ContextManager: return _DtypeContextManager(self.mixed_precision_config.param_dtype or torch.float32) @override - def forward_context(self) -> AbstractContextManager: + def forward_context(self) -> ContextManager: if "mixed" in self.precision: return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16)) return _DtypeContextManager(self._desired_input_dtype) @@ -167,12 +166,12 @@ def optimizer_step( # type: ignore[override] return closure_result @override - def state_dict(self) -> dict[str, Any]: + def state_dict(self) -> Dict[str, Any]: if self.scaler is not None: return self.scaler.state_dict() return {} @override - def load_state_dict(self, state_dict: dict[str, Any]) -> None: + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: if self.scaler is not None: self.scaler.load_state_dict(state_dict) diff --git a/src/lightning/pytorch/plugins/precision/half.py b/src/lightning/pytorch/plugins/precision/half.py index fe9deb4..22dc29b 100644 --- a/src/lightning/pytorch/plugins/precision/half.py +++ b/src/lightning/pytorch/plugins/precision/half.py @@ -11,9 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Generator -from contextlib import AbstractContextManager, contextmanager -from typing import Any, Literal +from contextlib import contextmanager +from typing import Any, ContextManager, Generator, Literal import torch from lightning_utilities import apply_to_collection @@ -44,11 +43,11 @@ def convert_module(self, module: Module) -> Module: return module.to(dtype=self._desired_input_dtype) @override - def tensor_init_context(self) -> AbstractContextManager: + def tensor_init_context(self) -> ContextManager: return _DtypeContextManager(self._desired_input_dtype) @override - def module_init_context(self) -> AbstractContextManager: + def module_init_context(self) -> ContextManager: return self.tensor_init_context() @override diff --git a/src/lightning/pytorch/plugins/precision/precision.py b/src/lightning/pytorch/plugins/precision/precision.py index 327fb2d..51bdddb 100644 --- a/src/lightning/pytorch/plugins/precision/precision.py +++ b/src/lightning/pytorch/plugins/precision/precision.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import contextlib -from collections.abc import Generator from functools import partial -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Generator, List, Optional, Tuple, Union import torch from torch import Tensor @@ -38,8 +37,8 @@ class Precision(FabricPrecision, CheckpointHooks): """ def connect( - self, model: Module, optimizers: list[Optimizer], lr_schedulers: list[Any] - ) -> tuple[Module, list[Optimizer], list[Any]]: + self, model: Module, optimizers: List[Optimizer], lr_schedulers: List[Any] + ) -> Tuple[Module, List[Optimizer], List[Any]]: """Connects this plugin to the accelerator and the training process.""" return model, optimizers, lr_schedulers diff --git a/src/lightning/pytorch/profilers/advanced.py b/src/lightning/pytorch/profilers/advanced.py index 41681fb..467b471 100644 --- a/src/lightning/pytorch/profilers/advanced.py +++ b/src/lightning/pytorch/profilers/advanced.py @@ -20,7 +20,7 @@ import pstats import tempfile from pathlib import Path -from typing import Optional, Union +from typing import Dict, Optional, Tuple, Union from typing_extensions import override @@ -66,7 +66,7 @@ def __init__( If you attempt to stop recording an action which was never started. """ super().__init__(dirpath=dirpath, filename=filename) - self.profiled_actions: dict[str, cProfile.Profile] = {} + self.profiled_actions: Dict[str, cProfile.Profile] = {} self.line_count_restriction = line_count_restriction self.dump_stats = dump_stats @@ -89,10 +89,9 @@ def _dump_stats(self, action_name: str, profile: cProfile.Profile) -> None: dst_fs = get_filesystem(dst_filepath) dst_fs.mkdirs(self.dirpath, exist_ok=True) # temporarily save to local since pstats can only dump into a local file - with ( - tempfile.TemporaryDirectory(prefix="test", suffix=str(rank_zero_only.rank), dir=os.getcwd()) as tmp_dir, - dst_fs.open(dst_filepath, "wb") as dst_file, - ): + with tempfile.TemporaryDirectory( + prefix="test", suffix=str(rank_zero_only.rank), dir=os.getcwd() + ) as tmp_dir, dst_fs.open(dst_filepath, "wb") as dst_file: src_filepath = os.path.join(tmp_dir, "tmp.prof") profile.dump_stats(src_filepath) src_fs = get_filesystem(src_filepath) @@ -116,7 +115,7 @@ def teardown(self, stage: Optional[str]) -> None: super().teardown(stage=stage) self.profiled_actions = {} - def __reduce__(self) -> tuple: + def __reduce__(self) -> Tuple: # avoids `TypeError: cannot pickle 'cProfile.Profile' object` return ( self.__class__, diff --git a/src/lightning/pytorch/profilers/profiler.py b/src/lightning/pytorch/profilers/profiler.py index a09b703..fb44832 100644 --- a/src/lightning/pytorch/profilers/profiler.py +++ b/src/lightning/pytorch/profilers/profiler.py @@ -16,10 +16,9 @@ import logging import os from abc import ABC, abstractmethod -from collections.abc import Generator from contextlib import contextmanager from pathlib import Path -from typing import Any, Callable, Optional, TextIO, Union +from typing import Any, Callable, Dict, Generator, Optional, TextIO, Union from lightning.fabric.utilities.cloud_io import get_filesystem @@ -116,7 +115,7 @@ def describe(self) -> None: self._output_file.flush() self.teardown(stage=self._stage) - def _stats_to_str(self, stats: dict[str, str]) -> str: + def _stats_to_str(self, stats: Dict[str, str]) -> str: stage = f"{self._stage.upper()} " if self._stage is not None else "" output = [stage + "Profiler Report"] for action, value in stats.items(): diff --git a/src/lightning/pytorch/profilers/pytorch.py b/src/lightning/pytorch/profilers/pytorch.py index e264d51..a26b3d3 100644 --- a/src/lightning/pytorch/profilers/pytorch.py +++ b/src/lightning/pytorch/profilers/pytorch.py @@ -16,10 +16,9 @@ import inspect import logging import os -from contextlib import AbstractContextManager from functools import lru_cache, partial from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, List, Optional, Type, Union import torch from torch import Tensor, nn @@ -66,8 +65,8 @@ class RegisterRecordFunction: def __init__(self, model: nn.Module) -> None: self._model = model - self._records: dict[str, record_function] = {} - self._handles: dict[str, list[RemovableHandle]] = {} + self._records: Dict[str, record_function] = {} + self._handles: Dict[str, List[RemovableHandle]] = {} def _start_recording_forward(self, _: nn.Module, input: Tensor, record_name: str) -> Tensor: # Add [pl][module] in name for pytorch profiler to recognize @@ -240,7 +239,7 @@ def __init__( row_limit: int = 20, sort_by_key: Optional[str] = None, record_module_names: bool = True, - table_kwargs: Optional[dict[str, Any]] = None, + table_kwargs: Optional[Dict[str, Any]] = None, **profiler_kwargs: Any, ) -> None: r"""This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of @@ -306,8 +305,8 @@ def __init__( self.function_events: Optional[EventList] = None self._lightning_module: Optional[LightningModule] = None # set by ProfilerConnector self._register: Optional[RegisterRecordFunction] = None - self._parent_profiler: Optional[AbstractContextManager] = None - self._recording_map: dict[str, record_function] = {} + self._parent_profiler: Optional[ContextManager] = None + self._recording_map: Dict[str, record_function] = {} self._start_action_name: Optional[str] = None self._schedule: Optional[ScheduleWrapper] = None @@ -401,8 +400,8 @@ def _default_schedule() -> Optional[Callable]: return torch.profiler.schedule(wait=1, warmup=1, active=3) return None - def _default_activities(self) -> list["ProfilerActivity"]: - activities: list[ProfilerActivity] = [] + def _default_activities(self) -> List["ProfilerActivity"]: + activities: List[ProfilerActivity] = [] if not _KINETO_AVAILABLE: return activities if _TORCH_GREATER_EQUAL_2_4: @@ -531,7 +530,7 @@ def _create_profilers(self) -> None: torch.profiler.profile if _KINETO_AVAILABLE else torch.autograd.profiler.profile ) - def _create_profiler(self, profiler: type[_PROFILER]) -> _PROFILER: + def _create_profiler(self, profiler: Type[_PROFILER]) -> _PROFILER: init_parameters = inspect.signature(profiler.__init__).parameters kwargs = {k: v for k, v in self._profiler_kwargs.items() if k in init_parameters} return profiler(**kwargs) diff --git a/src/lightning/pytorch/profilers/simple.py b/src/lightning/pytorch/profilers/simple.py index 8a53965..eef7b12 100644 --- a/src/lightning/pytorch/profilers/simple.py +++ b/src/lightning/pytorch/profilers/simple.py @@ -18,7 +18,7 @@ import time from collections import defaultdict from pathlib import Path -from typing import Optional, Union +from typing import Dict, List, Optional, Tuple, Union import torch from typing_extensions import override @@ -27,10 +27,10 @@ log = logging.getLogger(__name__) -_TABLE_ROW_EXTENDED = tuple[str, float, int, float, float] -_TABLE_DATA_EXTENDED = list[_TABLE_ROW_EXTENDED] -_TABLE_ROW = tuple[str, float, float] -_TABLE_DATA = list[_TABLE_ROW] +_TABLE_ROW_EXTENDED = Tuple[str, float, int, float, float] +_TABLE_DATA_EXTENDED = List[_TABLE_ROW_EXTENDED] +_TABLE_ROW = Tuple[str, float, float] +_TABLE_DATA = List[_TABLE_ROW] class SimpleProfiler(Profiler): @@ -61,8 +61,8 @@ def __init__( if you attempt to stop recording an action which was never started. """ super().__init__(dirpath=dirpath, filename=filename) - self.current_actions: dict[str, float] = {} - self.recorded_durations: dict = defaultdict(list) + self.current_actions: Dict[str, float] = {} + self.recorded_durations: Dict = defaultdict(list) self.extended = extended self.start_time = time.monotonic() @@ -81,7 +81,7 @@ def stop(self, action_name: str) -> None: duration = end_time - start_time self.recorded_durations[action_name].append(duration) - def _make_report_extended(self) -> tuple[_TABLE_DATA_EXTENDED, float, float]: + def _make_report_extended(self) -> Tuple[_TABLE_DATA_EXTENDED, float, float]: total_duration = time.monotonic() - self.start_time report = [] diff --git a/src/lightning/pytorch/profilers/xla.py b/src/lightning/pytorch/profilers/xla.py index 3e810fb..a85f3a1 100644 --- a/src/lightning/pytorch/profilers/xla.py +++ b/src/lightning/pytorch/profilers/xla.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Dict from typing_extensions import override @@ -44,8 +45,8 @@ def __init__(self, port: int = 9012) -> None: raise ModuleNotFoundError(str(_XLA_AVAILABLE)) super().__init__(dirpath=None, filename=None) self.port = port - self._recording_map: dict = {} - self._step_recoding_map: dict = {} + self._recording_map: Dict = {} + self._step_recoding_map: Dict = {} self._start_trace: bool = False @override diff --git a/src/lightning/pytorch/serve/servable_module.py b/src/lightning/pytorch/serve/servable_module.py index ed7a8a9..f715f4b 100644 --- a/src/lightning/pytorch/serve/servable_module.py +++ b/src/lightning/pytorch/serve/servable_module.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Callable +from typing import Any, Callable, Dict, Tuple import torch from torch import Tensor @@ -56,11 +56,11 @@ def configure_response(self): """ @abstractmethod - def configure_payload(self) -> dict[str, Any]: + def configure_payload(self) -> Dict[str, Any]: """Returns a request payload as a dictionary.""" @abstractmethod - def configure_serialization(self) -> tuple[dict[str, Callable], dict[str, Callable]]: + def configure_serialization(self) -> Tuple[Dict[str, Callable], Dict[str, Callable]]: """Returns a tuple of dictionaries. The first dictionary contains the name of the ``serve_step`` input variables name as its keys @@ -72,7 +72,7 @@ def configure_serialization(self) -> tuple[dict[str, Callable], dict[str, Callab """ @abstractmethod - def serve_step(self, *args: Tensor, **kwargs: Tensor) -> dict[str, Tensor]: + def serve_step(self, *args: Tensor, **kwargs: Tensor) -> Dict[str, Tensor]: r"""Returns the predictions of your model as a dictionary. .. code-block:: python @@ -90,5 +90,5 @@ def serve_step(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: """ @abstractmethod - def configure_response(self) -> dict[str, Any]: + def configure_response(self) -> Dict[str, Any]: """Returns a response to validate the server response.""" diff --git a/src/lightning/pytorch/serve/servable_module_validator.py b/src/lightning/pytorch/serve/servable_module_validator.py index dc92625..0acab20 100644 --- a/src/lightning/pytorch/serve/servable_module_validator.py +++ b/src/lightning/pytorch/serve/servable_module_validator.py @@ -2,7 +2,7 @@ import logging import time from multiprocessing import Process -from typing import Any, Literal, Optional +from typing import Any, Dict, Literal, Optional import requests import torch @@ -136,7 +136,7 @@ def successful(self) -> Optional[bool]: return self.resp.status_code == 200 if self.resp else None @override - def state_dict(self) -> dict[str, Any]: + def state_dict(self) -> Dict[str, Any]: return {"successful": self.successful, "optimization": self.optimization, "server": self.server} @staticmethod @@ -157,7 +157,7 @@ def ping() -> bool: return True @app.post("/serve") - async def serve(payload: dict = Body(...)) -> dict[str, Any]: + async def serve(payload: dict = Body(...)) -> Dict[str, Any]: body = payload["body"] for key, deserializer in deserializers.items(): diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py index fd3f66e..9031b6e 100644 --- a/src/lightning/pytorch/strategies/ddp.py +++ b/src/lightning/pytorch/strategies/ddp.py @@ -14,7 +14,7 @@ import logging from contextlib import nullcontext from datetime import timedelta -from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Union import torch import torch.distributed @@ -71,7 +71,7 @@ class DDPStrategy(ParallelStrategy): def __init__( self, accelerator: Optional["pl.accelerators.Accelerator"] = None, - parallel_devices: Optional[list[torch.device]] = None, + parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[Precision] = None, @@ -133,7 +133,7 @@ def num_processes(self) -> int: @property @override - def distributed_sampler_kwargs(self) -> dict[str, Any]: + def distributed_sampler_kwargs(self) -> Dict[str, Any]: return {"num_replicas": (self.num_nodes * self.num_processes), "rank": self.global_rank} @property @@ -283,7 +283,7 @@ def configure_ddp(self) -> None: self.model = self._setup_model(self.model) self._register_ddp_hooks() - def determine_ddp_device_ids(self) -> Optional[list[int]]: + def determine_ddp_device_ids(self) -> Optional[List[int]]: if self.root_device.type == "cpu": return None return [self.root_device.index] diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index 4fa7711..1eaa5ba 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -17,10 +17,9 @@ import os import platform from collections import OrderedDict -from collections.abc import Generator, Mapping from contextlib import contextmanager from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Mapping, Optional, Tuple, Union import torch from torch.nn import Module @@ -103,9 +102,9 @@ def __init__( reduce_bucket_size: int = 200_000_000, zero_allow_untested_optimizer: bool = True, logging_batch_size_per_gpu: Union[str, int] = "auto", - config: Optional[Union[_PATH, dict[str, Any]]] = None, + config: Optional[Union[_PATH, Dict[str, Any]]] = None, logging_level: int = logging.WARN, - parallel_devices: Optional[list[torch.device]] = None, + parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, loss_scale: float = 0, initial_scale_power: int = 16, @@ -381,8 +380,8 @@ def restore_checkpoint_after_setup(self) -> bool: @override def _setup_model_and_optimizers( - self, model: Module, optimizers: list[Optimizer] - ) -> tuple["deepspeed.DeepSpeedEngine", list[Optimizer]]: + self, model: Module, optimizers: List[Optimizer] + ) -> Tuple["deepspeed.DeepSpeedEngine", List[Optimizer]]: """Setup a model and multiple optimizers together. Currently only a single optimizer is supported. @@ -412,7 +411,7 @@ def _setup_model_and_optimizer( model: Module, optimizer: Optional[Optimizer], lr_scheduler: Optional[Union[LRScheduler, ReduceLROnPlateau]] = None, - ) -> tuple["deepspeed.DeepSpeedEngine", Optimizer]: + ) -> Tuple["deepspeed.DeepSpeedEngine", Optimizer]: """Initialize one model and one optimizer with an optional learning rate scheduler. This calls ``deepspeed.initialize`` internally. @@ -453,7 +452,7 @@ def init_deepspeed(self) -> None: else: self._initialize_deepspeed_inference(self.model) - def _init_optimizers(self) -> tuple[Optimizer, Optional[LRSchedulerConfig]]: + def _init_optimizers(self) -> Tuple[Optimizer, Optional[LRSchedulerConfig]]: assert self.lightning_module is not None optimizers, lr_schedulers = _init_optimizers_and_lr_schedulers(self.lightning_module) if len(optimizers) > 1 or len(lr_schedulers) > 1: @@ -573,7 +572,7 @@ def _initialize_deepspeed_inference(self, model: Module) -> None: @property @override - def distributed_sampler_kwargs(self) -> dict[str, int]: + def distributed_sampler_kwargs(self) -> Dict[str, int]: return {"num_replicas": self.world_size, "rank": self.global_rank} @override @@ -609,7 +608,7 @@ def _multi_device(self) -> bool: return self.num_processes > 1 or self.num_nodes > 1 @override - def save_checkpoint(self, checkpoint: dict, filepath: _PATH, storage_options: Optional[Any] = None) -> None: + def save_checkpoint(self, checkpoint: Dict, filepath: _PATH, storage_options: Optional[Any] = None) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. Args: @@ -646,7 +645,7 @@ def save_checkpoint(self, checkpoint: dict, filepath: _PATH, storage_options: Op self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint, tag="checkpoint") @override - def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: + def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: if self.load_full_weights and self.zero_stage_3: # Broadcast to ensure we load from the rank 0 checkpoint # This doesn't have to be the case when using deepspeed sharded checkpointing @@ -709,9 +708,9 @@ def _restore_zero_state(self, ckpt: Mapping[str, Any], strict: bool) -> None: assert self.lightning_module is not None def load(module: torch.nn.Module, prefix: str = "") -> None: - missing_keys: list[str] = [] - unexpected_keys: list[str] = [] - error_msgs: list[str] = [] + missing_keys: List[str] = [] + unexpected_keys: List[str] = [] + error_msgs: List[str] = [] state_dict = ckpt["state_dict"] # copy state_dict so _load_from_state_dict can modify it @@ -781,7 +780,7 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None: offload_optimizer_device="nvme", ) - def _load_config(self, config: Optional[Union[_PATH, dict[str, Any]]]) -> Optional[dict[str, Any]]: + def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Optional[Dict[str, Any]]: if config is None and self.DEEPSPEED_ENV_VAR in os.environ: rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable") config = os.environ[self.DEEPSPEED_ENV_VAR] @@ -842,7 +841,7 @@ def _create_default_config( overlap_events: bool, thread_count: int, **zero_kwargs: Any, - ) -> dict: + ) -> Dict: cfg = { "activation_checkpointing": { "partition_activations": partition_activations, diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py index bfbf99e..ab6e579 100644 --- a/src/lightning/pytorch/strategies/fsdp.py +++ b/src/lightning/pytorch/strategies/fsdp.py @@ -13,7 +13,6 @@ # limitations under the License. import logging import shutil -from collections.abc import Generator, Mapping from contextlib import contextmanager, nullcontext from datetime import timedelta from pathlib import Path @@ -21,8 +20,15 @@ TYPE_CHECKING, Any, Callable, + Dict, + Generator, + List, Literal, + Mapping, Optional, + Set, + Tuple, + Type, Union, ) @@ -82,7 +88,7 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision, ShardingStrategy from torch.distributed.fsdp.wrap import ModuleWrapPolicy - _POLICY = Union[set[type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy] + _POLICY = Union[Set[Type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy] _SHARDING_STRATEGY = Union[ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]] @@ -142,12 +148,12 @@ class FSDPStrategy(ParallelStrategy): """ strategy_name = "fsdp" - _registered_strategies: list[str] = [] + _registered_strategies: List[str] = [] def __init__( self, accelerator: Optional["pl.accelerators.Accelerator"] = None, - parallel_devices: Optional[list[torch.device]] = None, + parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[Precision] = None, @@ -156,11 +162,11 @@ def __init__( cpu_offload: Union[bool, "CPUOffload", None] = None, mixed_precision: Optional["MixedPrecision"] = None, auto_wrap_policy: Optional["_POLICY"] = None, - activation_checkpointing: Optional[Union[type[Module], list[type[Module]]]] = None, + activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None, activation_checkpointing_policy: Optional["_POLICY"] = None, sharding_strategy: "_SHARDING_STRATEGY" = "FULL_SHARD", state_dict_type: Literal["full", "sharded"] = "full", - device_mesh: Optional[Union[tuple[int], "DeviceMesh"]] = None, + device_mesh: Optional[Union[Tuple[int], "DeviceMesh"]] = None, **kwargs: Any, ) -> None: super().__init__( @@ -236,7 +242,7 @@ def precision_plugin(self, precision_plugin: Optional[FSDPPrecision]) -> None: @property @override - def distributed_sampler_kwargs(self) -> dict: + def distributed_sampler_kwargs(self) -> Dict: return {"num_replicas": (self.num_nodes * self.num_processes), "rank": self.global_rank} @property @@ -449,7 +455,7 @@ def reduce( return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op) return tensor - def _determine_device_ids(self) -> list[int]: + def _determine_device_ids(self) -> List[int]: return [self.root_device.index] @override @@ -475,7 +481,7 @@ def teardown(self) -> None: self.accelerator.teardown() @classmethod - def get_registered_strategies(cls) -> list[str]: + def get_registered_strategies(cls) -> List[str]: return cls._registered_strategies @classmethod @@ -499,7 +505,7 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None: cls._registered_strategies.append("fsdp_cpu_offload") @override - def lightning_module_state_dict(self) -> dict[str, Any]: + def lightning_module_state_dict(self) -> Dict[str, Any]: assert self.model is not None if self._state_dict_type == "sharded": state_dict_ctx = _get_sharded_state_dict_context(self.model) @@ -516,7 +522,7 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = Tr pass @override - def optimizer_state(self, optimizer: Optimizer) -> dict[str, Tensor]: + def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]: from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import OptimStateKeyType @@ -545,7 +551,7 @@ def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: @override def save_checkpoint( - self, checkpoint: dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None + self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None ) -> None: if storage_options is not None: raise TypeError( @@ -580,7 +586,7 @@ def save_checkpoint( raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}") @override - def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: + def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: # broadcast the path from rank 0 to ensure all the states are loaded from a common path path = Path(self.broadcast(checkpoint_path)) diff --git a/src/lightning/pytorch/strategies/launchers/multiprocessing.py b/src/lightning/pytorch/strategies/launchers/multiprocessing.py index aa207a5..05e3fed 100644 --- a/src/lightning/pytorch/strategies/launchers/multiprocessing.py +++ b/src/lightning/pytorch/strategies/launchers/multiprocessing.py @@ -18,7 +18,7 @@ import tempfile from contextlib import suppress from dataclasses import dataclass -from typing import Any, Callable, Literal, NamedTuple, Optional, Union +from typing import Any, Callable, Dict, List, Literal, NamedTuple, Optional, Union import torch import torch.backends.cudnn @@ -80,7 +80,7 @@ def __init__( f"The start method '{self._start_method}' is not available on this platform. Available methods are:" f" {', '.join(mp.get_all_start_methods())}" ) - self.procs: list[mp.Process] = [] + self.procs: List[mp.Process] = [] self._already_fit = False @property @@ -224,7 +224,7 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt return _WorkerOutput(best_model_path, weights_path, trainer.state, results, extra) - def get_extra_results(self, trainer: "pl.Trainer") -> dict[str, Any]: + def get_extra_results(self, trainer: "pl.Trainer") -> Dict[str, Any]: """Gather extra state from the Trainer and return it as a dictionary for sending back to the main process. To avoid issues with memory sharing, we convert tensors to bytes. @@ -242,7 +242,7 @@ def get_extra_results(self, trainer: "pl.Trainer") -> dict[str, Any]: # send tensors as bytes to avoid issues with memory sharing return {"callback_metrics_bytes": buffer.getvalue()} - def update_main_process_results(self, trainer: "pl.Trainer", extra: dict[str, Any]) -> None: + def update_main_process_results(self, trainer: "pl.Trainer", extra: Dict[str, Any]) -> None: """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we convert bytes back to ``torch.Tensor``. @@ -265,7 +265,7 @@ def kill(self, signum: _SIGNUM) -> None: with suppress(ProcessLookupError): os.kill(proc.pid, signum) - def __getstate__(self) -> dict: + def __getstate__(self) -> Dict: state = self.__dict__.copy() state["procs"] = [] # SpawnProcess can't be pickled return state @@ -276,7 +276,7 @@ class _WorkerOutput(NamedTuple): weights_path: Optional[_PATH] trainer_state: TrainerState trainer_results: Any - extra: dict[str, Any] + extra: Dict[str, Any] @dataclass @@ -301,7 +301,7 @@ class _GlobalStateSnapshot: use_deterministic_algorithms: bool use_deterministic_algorithms_warn_only: bool cudnn_benchmark: bool - rng_states: dict[str, Any] + rng_states: Dict[str, Any] @classmethod def capture(cls) -> "_GlobalStateSnapshot": diff --git a/src/lightning/pytorch/strategies/launchers/subprocess_script.py b/src/lightning/pytorch/strategies/launchers/subprocess_script.py index b7ec294..d2035d0 100644 --- a/src/lightning/pytorch/strategies/launchers/subprocess_script.py +++ b/src/lightning/pytorch/strategies/launchers/subprocess_script.py @@ -14,7 +14,7 @@ import logging import os import subprocess -from typing import Any, Callable, Optional +from typing import Any, Callable, List, Optional from lightning_utilities.core.imports import RequirementCache from typing_extensions import override @@ -77,7 +77,7 @@ def __init__(self, cluster_environment: ClusterEnvironment, num_processes: int, self.cluster_environment = cluster_environment self.num_processes = num_processes self.num_nodes = num_nodes - self.procs: list[subprocess.Popen] = [] # launched child subprocesses, does not include the launcher + self.procs: List[subprocess.Popen] = [] # launched child subprocesses, does not include the launcher @property @override diff --git a/src/lightning/pytorch/strategies/model_parallel.py b/src/lightning/pytorch/strategies/model_parallel.py index 82fec20..fb45166 100644 --- a/src/lightning/pytorch/strategies/model_parallel.py +++ b/src/lightning/pytorch/strategies/model_parallel.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import shutil -from collections.abc import Generator, Mapping from contextlib import contextmanager, nullcontext from datetime import timedelta from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Mapping, Optional, Union import torch from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only @@ -115,7 +114,7 @@ def num_processes(self) -> int: @property @override - def distributed_sampler_kwargs(self) -> dict[str, Any]: + def distributed_sampler_kwargs(self) -> Dict[str, Any]: assert self.device_mesh is not None data_parallel_mesh = self.device_mesh["data_parallel"] return {"num_replicas": data_parallel_mesh.size(), "rank": data_parallel_mesh.get_local_rank()} @@ -238,7 +237,7 @@ def reduce( return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op) return tensor - def _determine_device_ids(self) -> list[int]: + def _determine_device_ids(self) -> List[int]: return [self.root_device.index] @override @@ -250,7 +249,7 @@ def teardown(self) -> None: self.accelerator.teardown() @override - def lightning_module_state_dict(self) -> dict[str, Any]: + def lightning_module_state_dict(self) -> Dict[str, Any]: """Collects the state dict of the model. Only returns a non-empty state dict on rank 0 if ``save_distributed_checkpoint=False``. @@ -268,7 +267,7 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = Tr pass @override - def optimizer_state(self, optimizer: Optimizer) -> dict[str, Any]: + def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Any]: """Collects the state of the given optimizer. Only returns a non-empty state dict on rank 0 if ``save_distributed_checkpoint=False``. @@ -297,7 +296,7 @@ def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: @override def save_checkpoint( - self, checkpoint: dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None + self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None ) -> None: if storage_options is not None: raise TypeError( @@ -329,7 +328,7 @@ def save_checkpoint( return super().save_checkpoint(checkpoint=checkpoint, filepath=path) @override - def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: + def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: # broadcast the path from rank 0 to ensure all the states are loaded from a common path path = Path(self.broadcast(checkpoint_path)) state = { diff --git a/src/lightning/pytorch/strategies/parallel.py b/src/lightning/pytorch/strategies/parallel.py index 285d407..5658438 100644 --- a/src/lightning/pytorch/strategies/parallel.py +++ b/src/lightning/pytorch/strategies/parallel.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod -from collections.abc import Generator from contextlib import contextmanager -from typing import Any, Optional +from typing import Any, Dict, Generator, List, Optional import torch from torch import Tensor @@ -34,7 +33,7 @@ class ParallelStrategy(Strategy, ABC): def __init__( self, accelerator: Optional["pl.accelerators.Accelerator"] = None, - parallel_devices: Optional[list[torch.device]] = None, + parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[Precision] = None, @@ -72,15 +71,15 @@ def is_global_zero(self) -> bool: return self.global_rank == 0 @property - def parallel_devices(self) -> Optional[list[torch.device]]: + def parallel_devices(self) -> Optional[List[torch.device]]: return self._parallel_devices @parallel_devices.setter - def parallel_devices(self, parallel_devices: Optional[list[torch.device]]) -> None: + def parallel_devices(self, parallel_devices: Optional[List[torch.device]]) -> None: self._parallel_devices = parallel_devices @property - def distributed_sampler_kwargs(self) -> dict[str, Any]: + def distributed_sampler_kwargs(self) -> Dict[str, Any]: return { "num_replicas": len(self.parallel_devices) if self.parallel_devices is not None else 0, "rank": self.global_rank, diff --git a/src/lightning/pytorch/strategies/strategy.py b/src/lightning/pytorch/strategies/strategy.py index 0a0f52e..314007f 100644 --- a/src/lightning/pytorch/strategies/strategy.py +++ b/src/lightning/pytorch/strategies/strategy.py @@ -13,9 +13,8 @@ # limitations under the License. import logging from abc import ABC, abstractmethod -from collections.abc import Generator, Mapping from contextlib import contextmanager -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, TypeVar, Union import torch from torch import Tensor @@ -62,9 +61,9 @@ def __init__( self._model: Optional[Module] = None self._launcher: Optional[_Launcher] = None self._forward_redirection: _ForwardRedirection = _ForwardRedirection() - self._optimizers: list[Optimizer] = [] - self._lightning_optimizers: list[LightningOptimizer] = [] - self.lr_scheduler_configs: list[LRSchedulerConfig] = [] + self._optimizers: List[Optimizer] = [] + self._lightning_optimizers: List[LightningOptimizer] = [] + self.lr_scheduler_configs: List[LRSchedulerConfig] = [] @property def launcher(self) -> Optional[_Launcher]: @@ -100,11 +99,11 @@ def precision_plugin(self, precision_plugin: Optional[Precision]) -> None: self._precision_plugin = precision_plugin @property - def optimizers(self) -> list[Optimizer]: + def optimizers(self) -> List[Optimizer]: return self._optimizers @optimizers.setter - def optimizers(self, optimizers: list[Optimizer]) -> None: + def optimizers(self, optimizers: List[Optimizer]) -> None: self._optimizers = optimizers self._lightning_optimizers = [LightningOptimizer._to_lightning_optimizer(opt, self) for opt in optimizers] @@ -171,7 +170,7 @@ def setup_precision_plugin(self) -> None: self.optimizers = optimizers self.lr_scheduler_configs = lr_scheduler_configs - def optimizer_state(self, optimizer: Optimizer) -> dict[str, Tensor]: + def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]: """Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom strategies. @@ -238,7 +237,7 @@ def optimizer_step( assert isinstance(model, pl.LightningModule) return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs) - def _setup_model_and_optimizers(self, model: Module, optimizers: list[Optimizer]) -> tuple[Module, list[Optimizer]]: + def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]: """Setup a model and multiple optimizers together. The returned objects are expected to be in the same order they were passed in. The default implementation will @@ -363,7 +362,7 @@ def lightning_module(self) -> Optional["pl.LightningModule"]: """Returns the pure LightningModule without potential wrappers.""" return self._lightning_module - def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: + def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: torch.cuda.empty_cache() return self.checkpoint_io.load_checkpoint(checkpoint_path) @@ -471,13 +470,13 @@ def handles_gradient_accumulation(self) -> bool: """Whether the strategy handles gradient accumulation internally.""" return False - def lightning_module_state_dict(self) -> dict[str, Any]: + def lightning_module_state_dict(self) -> Dict[str, Any]: """Returns model state.""" assert self.lightning_module is not None return self.lightning_module.state_dict() def save_checkpoint( - self, checkpoint: dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None + self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None ) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. @@ -588,13 +587,13 @@ def _reset_optimizers_and_schedulers(self) -> None: self._lightning_optimizers = [] self.lr_scheduler_configs = [] - def __getstate__(self) -> dict: + def __getstate__(self) -> Dict: # `LightningOptimizer` overrides `self.__class__` so they cannot be pickled state = dict(vars(self)) # copy state["_lightning_optimizers"] = [] return state - def __setstate__(self, state: dict) -> None: + def __setstate__(self, state: Dict) -> None: self.__dict__ = state self.optimizers = self.optimizers # re-create the `_lightning_optimizers` diff --git a/src/lightning/pytorch/strategies/xla.py b/src/lightning/pytorch/strategies/xla.py index faffb30..56aae90 100644 --- a/src/lightning/pytorch/strategies/xla.py +++ b/src/lightning/pytorch/strategies/xla.py @@ -13,7 +13,7 @@ # limitations under the License. import io import os -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import torch from torch import Tensor @@ -49,7 +49,7 @@ class XLAStrategy(DDPStrategy): def __init__( self, accelerator: Optional["pl.accelerators.Accelerator"] = None, - parallel_devices: Optional[list[torch.device]] = None, + parallel_devices: Optional[List[torch.device]] = None, checkpoint_io: Optional[Union[XLACheckpointIO, _WrappingCheckpointIO]] = None, precision_plugin: Optional[XLAPrecision] = None, debug: bool = False, @@ -172,7 +172,7 @@ def _setup_model(self, model: Module) -> Module: # type: ignore @property @override - def distributed_sampler_kwargs(self) -> dict[str, int]: + def distributed_sampler_kwargs(self) -> Dict[str, int]: return {"num_replicas": self.world_size, "rank": self.global_rank} @override @@ -295,7 +295,7 @@ def set_world_ranks(self) -> None: @override def save_checkpoint( - self, checkpoint: dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None + self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None ) -> None: import torch_xla.core.xla_model as xm diff --git a/src/lightning/pytorch/trainer/call.py b/src/lightning/pytorch/trainer/call.py index 012d1a2..4c3bc5e 100644 --- a/src/lightning/pytorch/trainer/call.py +++ b/src/lightning/pytorch/trainer/call.py @@ -14,7 +14,7 @@ import logging import signal from copy import deepcopy -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Dict, Optional, Type, Union from packaging.version import Version @@ -115,11 +115,7 @@ def _call_configure_model(trainer: "pl.Trainer") -> None: # we don't normally check for this before calling the hook. it is done here to avoid instantiating the context # managers if is_overridden("configure_model", trainer.lightning_module): - with ( - trainer.strategy.tensor_init_context(), - trainer.strategy.model_sharded_context(), - trainer.precision_plugin.module_init_context(), - ): + with trainer.strategy.tensor_init_context(), trainer.strategy.model_sharded_context(), trainer.precision_plugin.module_init_context(): # noqa: E501 _call_lightning_module_hook(trainer, "configure_model") @@ -226,7 +222,7 @@ def _call_callback_hooks( pl_module._current_fx_name = prev_fx_name -def _call_callbacks_state_dict(trainer: "pl.Trainer") -> dict[str, dict]: +def _call_callbacks_state_dict(trainer: "pl.Trainer") -> Dict[str, dict]: """Called when saving a model checkpoint, calls and returns every callback's `state_dict`, keyed by `Callback.state_key`.""" callback_state_dicts = {} @@ -237,7 +233,7 @@ def _call_callbacks_state_dict(trainer: "pl.Trainer") -> dict[str, dict]: return callback_state_dicts -def _call_callbacks_on_save_checkpoint(trainer: "pl.Trainer", checkpoint: dict[str, Any]) -> None: +def _call_callbacks_on_save_checkpoint(trainer: "pl.Trainer", checkpoint: Dict[str, Any]) -> None: """Called when saving a model checkpoint, calls every callback's `on_save_checkpoint` hook.""" pl_module = trainer.lightning_module if pl_module: @@ -253,7 +249,7 @@ def _call_callbacks_on_save_checkpoint(trainer: "pl.Trainer", checkpoint: dict[s pl_module._current_fx_name = prev_fx_name -def _call_callbacks_on_load_checkpoint(trainer: "pl.Trainer", checkpoint: dict[str, Any]) -> None: +def _call_callbacks_on_load_checkpoint(trainer: "pl.Trainer", checkpoint: Dict[str, Any]) -> None: """Called when loading a model checkpoint. Calls every callback's `on_load_checkpoint` hook. We have a dedicated function for this rather than using @@ -265,7 +261,7 @@ def _call_callbacks_on_load_checkpoint(trainer: "pl.Trainer", checkpoint: dict[s prev_fx_name = pl_module._current_fx_name pl_module._current_fx_name = "on_load_checkpoint" - callback_states: Optional[dict[Union[type, str], dict]] = checkpoint.get("callbacks") + callback_states: Optional[Dict[Union[Type, str], Dict]] = checkpoint.get("callbacks") if callback_states is None: return @@ -289,9 +285,9 @@ def _call_callbacks_on_load_checkpoint(trainer: "pl.Trainer", checkpoint: dict[s pl_module._current_fx_name = prev_fx_name -def _call_callbacks_load_state_dict(trainer: "pl.Trainer", checkpoint: dict[str, Any]) -> None: +def _call_callbacks_load_state_dict(trainer: "pl.Trainer", checkpoint: Dict[str, Any]) -> None: """Called when loading a model checkpoint, calls every callback's `load_state_dict`.""" - callback_states: Optional[dict[Union[type, str], dict]] = checkpoint.get("callbacks") + callback_states: Optional[Dict[Union[Type, str], Dict]] = checkpoint.get("callbacks") if callback_states is None: return diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index 7af8f13..06f3ee3 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -15,7 +15,7 @@ import logging import os from collections import Counter -from typing import Literal, Optional, Union +from typing import Dict, List, Literal, Optional, Union import torch @@ -74,11 +74,11 @@ class _AcceleratorConnector: def __init__( self, - devices: Union[list[int], str, int] = "auto", + devices: Union[List[int], str, int] = "auto", num_nodes: int = 1, accelerator: Union[str, Accelerator] = "auto", strategy: Union[str, Strategy] = "auto", - plugins: Optional[Union[_PLUGIN_INPUT, list[_PLUGIN_INPUT]]] = None, + plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None, precision: Optional[_PRECISION_INPUT] = None, sync_batchnorm: bool = False, benchmark: Optional[bool] = None, @@ -123,7 +123,7 @@ def __init__( self._precision_flag: _PRECISION_INPUT_STR = "32-true" self._precision_plugin_flag: Optional[Precision] = None self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None - self._parallel_devices: list[Union[int, torch.device, str]] = [] + self._parallel_devices: List[Union[int, torch.device, str]] = [] self._layer_sync: Optional[LayerSync] = TorchSyncBatchNorm() if sync_batchnorm else None self.checkpoint_io: Optional[CheckpointIO] = None @@ -166,7 +166,7 @@ def _check_config_and_set_final_flags( strategy: Union[str, Strategy], accelerator: Union[str, Accelerator], precision: Optional[_PRECISION_INPUT], - plugins: Optional[Union[_PLUGIN_INPUT, list[_PLUGIN_INPUT]]], + plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]], sync_batchnorm: bool, ) -> None: """This method checks: @@ -225,7 +225,7 @@ def _check_config_and_set_final_flags( precision_flag = _convert_precision_to_unified_args(precision) if plugins: - plugins_flags_types: dict[str, int] = Counter() + plugins_flags_types: Dict[str, int] = Counter() for plugin in plugins: if isinstance(plugin, Precision): self._precision_plugin_flag = plugin @@ -310,7 +310,7 @@ def _check_config_and_set_final_flags( self._accelerator_flag = "cuda" self._parallel_devices = self._strategy_flag.parallel_devices - def _check_device_config_and_set_final_flags(self, devices: Union[list[int], str, int], num_nodes: int) -> None: + def _check_device_config_and_set_final_flags(self, devices: Union[List[int], str, int], num_nodes: int) -> None: if not isinstance(num_nodes, int) or num_nodes < 1: raise ValueError(f"`num_nodes` must be a positive integer, but got {num_nodes}.") diff --git a/src/lightning/pytorch/trainer/connectors/callback_connector.py b/src/lightning/pytorch/trainer/connectors/callback_connector.py index a60f907..2f2b619 100644 --- a/src/lightning/pytorch/trainer/connectors/callback_connector.py +++ b/src/lightning/pytorch/trainer/connectors/callback_connector.py @@ -14,9 +14,8 @@ import logging import os -from collections.abc import Sequence from datetime import timedelta -from typing import Optional, Union +from typing import Dict, List, Optional, Sequence, Union import lightning.pytorch as pl from lightning.fabric.utilities.registry import _load_external_callbacks @@ -47,12 +46,12 @@ def __init__(self, trainer: "pl.Trainer"): def on_trainer_init( self, - callbacks: Optional[Union[list[Callback], Callback]], + callbacks: Optional[Union[List[Callback], Callback]], enable_checkpointing: bool, enable_progress_bar: bool, default_root_dir: Optional[str], enable_model_summary: bool, - max_time: Optional[Union[str, timedelta, dict[str, int]]] = None, + max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None, ) -> None: # init folder paths for checkpoint + weights save callbacks self.trainer._default_root_dir = default_root_dir or os.getcwd() @@ -140,7 +139,7 @@ def _configure_progress_bar(self, enable_progress_bar: bool = True) -> None: progress_bar_callback = TQDMProgressBar() self.trainer.callbacks.append(progress_bar_callback) - def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, dict[str, int]]] = None) -> None: + def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None) -> None: if max_time is None: return if any(isinstance(cb, Timer) for cb in self.trainer.callbacks): @@ -196,7 +195,7 @@ def _attach_model_callbacks(self) -> None: trainer.callbacks = all_callbacks @staticmethod - def _reorder_callbacks(callbacks: list[Callback]) -> list[Callback]: + def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]: """Moves all the tuner specific callbacks at the beginning of the list and all the `ModelCheckpoint` callbacks to the end of the list. The sequential order within the group of checkpoint callbacks is preserved, as well as the order of all other callbacks. @@ -209,9 +208,9 @@ def _reorder_callbacks(callbacks: list[Callback]) -> list[Callback]: if there were any present in the input. """ - tuner_callbacks: list[Callback] = [] - other_callbacks: list[Callback] = [] - checkpoint_callbacks: list[Callback] = [] + tuner_callbacks: List[Callback] = [] + other_callbacks: List[Callback] = [] + checkpoint_callbacks: List[Callback] = [] for cb in callbacks: if isinstance(cb, (BatchSizeFinder, LearningRateFinder)): @@ -224,7 +223,7 @@ def _reorder_callbacks(callbacks: list[Callback]) -> list[Callback]: return tuner_callbacks + other_callbacks + checkpoint_callbacks -def _validate_callbacks_list(callbacks: list[Callback]) -> None: +def _validate_callbacks_list(callbacks: List[Callback]) -> None: stateful_callbacks = [cb for cb in callbacks if is_overridden("state_dict", instance=cb)] seen_callbacks = set() for callback in stateful_callbacks: diff --git a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py index 71cc5a1..a41f87d 100644 --- a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py +++ b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py @@ -14,7 +14,7 @@ import logging import os import re -from typing import Any, Optional +from typing import Any, Dict, Optional import torch from fsspec.core import url_to_fs @@ -44,7 +44,7 @@ def __init__(self, trainer: "pl.Trainer") -> None: self._ckpt_path: Optional[_PATH] = None # flag to know if the user is changing the checkpoint path statefully. See `trainer.ckpt_path.setter` self._user_managed: bool = False - self._loaded_checkpoint: dict[str, Any] = {} + self._loaded_checkpoint: Dict[str, Any] = {} @property def _hpc_resume_path(self) -> Optional[str]: @@ -491,10 +491,10 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: call._call_lightning_module_hook(trainer, "on_save_checkpoint", checkpoint) return checkpoint - def _get_lightning_module_state_dict(self) -> dict[str, Tensor]: + def _get_lightning_module_state_dict(self) -> Dict[str, Tensor]: return self.trainer.strategy.lightning_module_state_dict() - def _get_loops_state_dict(self) -> dict[str, Any]: + def _get_loops_state_dict(self) -> Dict[str, Any]: return { "fit_loop": self.trainer.fit_loop.state_dict(), "validate_loop": self.trainer.validate_loop.state_dict(), diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py index 3e52730..1e84a2e 100644 --- a/src/lightning/pytorch/trainer/connectors/data_connector.py +++ b/src/lightning/pytorch/trainer/connectors/data_connector.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from collections.abc import Iterable from dataclasses import dataclass, field -from typing import Any, Optional, Union +from typing import Any, Iterable, Optional, Tuple, Union import torch.multiprocessing as mp from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Sampler, SequentialSampler @@ -343,7 +342,7 @@ class _DataHookSelector: model: "pl.LightningModule" datamodule: Optional["pl.LightningDataModule"] - _valid_hooks: tuple[str, ...] = field( + _valid_hooks: Tuple[str, ...] = field( default=("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer") ) diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py b/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py index 0dbdc4e..545749b 100644 --- a/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py +++ b/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union +from typing import Optional, Tuple, Union from typing_extensions import TypedDict @@ -20,8 +20,8 @@ class _FxValidator: class _LogOptions(TypedDict): - allowed_on_step: Union[tuple[bool], tuple[bool, bool]] - allowed_on_epoch: Union[tuple[bool], tuple[bool, bool]] + allowed_on_step: Union[Tuple[bool], Tuple[bool, bool]] + allowed_on_epoch: Union[Tuple[bool], Tuple[bool, bool]] default_on_step: bool default_on_epoch: bool @@ -166,7 +166,7 @@ def check_logging(cls, fx_name: str) -> None: @classmethod def get_default_logging_levels( cls, fx_name: str, on_step: Optional[bool], on_epoch: Optional[bool] - ) -> tuple[bool, bool]: + ) -> Tuple[bool, bool]: """Return default logging levels for given hook.""" fx_config = cls.functions[fx_name] assert fx_config is not None @@ -191,7 +191,7 @@ def check_logging_levels(cls, fx_name: str, on_step: bool, on_epoch: bool) -> No @classmethod def check_logging_and_get_default_levels( cls, fx_name: str, on_step: Optional[bool], on_epoch: Optional[bool] - ) -> tuple[bool, bool]: + ) -> Tuple[bool, bool]: """Check if the given hook name is allowed to log and return logging levels.""" cls.check_logging(fx_name) on_step, on_epoch = cls.get_default_logging_levels(fx_name, on_step, on_epoch) diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py b/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py index ffc99a9..c4ab116 100644 --- a/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py +++ b/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Iterable -from typing import Any, Optional, Union +from typing import Any, Iterable, Optional, Union from lightning_utilities.core.apply_func import apply_to_collection from torch import Tensor diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py index fdde19a..62cc784 100644 --- a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py +++ b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py @@ -11,10 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Generator from dataclasses import dataclass from functools import partial, wraps -from typing import Any, Callable, Optional, Union, cast +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union, cast import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -33,8 +32,8 @@ from lightning.pytorch.utilities.warnings import PossibleUserWarning _VALUE = Union[Metric, Tensor] # Do not include scalars as they were converted to tensors -_OUT_DICT = dict[str, Tensor] -_PBAR_DICT = dict[str, float] +_OUT_DICT = Dict[str, Tensor] +_PBAR_DICT = Dict[str, float] class _METRICS(TypedDict): @@ -334,7 +333,7 @@ def __init__(self, training: bool) -> None: self.dataloader_idx: Optional[int] = None @property - def result_metrics(self) -> list[_ResultMetric]: + def result_metrics(self) -> List[_ResultMetric]: return list(self.values()) def _extract_batch_size(self, value: _ResultMetric, batch_size: Optional[int], meta: _Metadata) -> int: @@ -457,7 +456,7 @@ def valid_items(self) -> Generator: """This function is used to iterate over current valid metrics.""" return ((k, v) for k, v in self.items() if not v.has_reset and self.dataloader_idx == v.meta.dataloader_idx) - def _forked_name(self, result_metric: _ResultMetric, on_step: bool) -> tuple[str, str]: + def _forked_name(self, result_metric: _ResultMetric, on_step: bool) -> Tuple[str, str]: name = result_metric.meta.name forked_name = result_metric.meta.forked_name(on_step) add_dataloader_idx = result_metric.meta.add_dataloader_idx diff --git a/src/lightning/pytorch/trainer/connectors/signal_connector.py b/src/lightning/pytorch/trainer/connectors/signal_connector.py index e63fecd..05a9753 100644 --- a/src/lightning/pytorch/trainer/connectors/signal_connector.py +++ b/src/lightning/pytorch/trainer/connectors/signal_connector.py @@ -5,7 +5,7 @@ import threading from subprocess import call from types import FrameType -from typing import Any, Callable, Union +from typing import Any, Callable, Dict, List, Set, Union import lightning.pytorch as pl from lightning.fabric.plugins.environments import SLURMEnvironment @@ -20,7 +20,7 @@ class _HandlersCompose: - def __init__(self, signal_handlers: Union[list[_HANDLER], _HANDLER]) -> None: + def __init__(self, signal_handlers: Union[List[_HANDLER], _HANDLER]) -> None: if not isinstance(signal_handlers, list): signal_handlers = [signal_handlers] self.signal_handlers = signal_handlers @@ -37,14 +37,14 @@ class _SignalConnector: def __init__(self, trainer: "pl.Trainer") -> None: self.received_sigterm = False self.trainer = trainer - self._original_handlers: dict[_SIGNUM, _HANDLER] = {} + self._original_handlers: Dict[_SIGNUM, _HANDLER] = {} def register_signal_handlers(self) -> None: self.received_sigterm = False self._original_handlers = self._get_current_signal_handlers() - sigusr_handlers: list[_HANDLER] = [] - sigterm_handlers: list[_HANDLER] = [self._sigterm_notifier_fn] + sigusr_handlers: List[_HANDLER] = [] + sigterm_handlers: List[_HANDLER] = [self._sigterm_notifier_fn] environment = self.trainer._accelerator_connector.cluster_environment if isinstance(environment, SLURMEnvironment) and environment.auto_requeue: @@ -123,7 +123,7 @@ def teardown(self) -> None: self._original_handlers = {} @staticmethod - def _get_current_signal_handlers() -> dict[_SIGNUM, _HANDLER]: + def _get_current_signal_handlers() -> Dict[_SIGNUM, _HANDLER]: """Collects the currently assigned signal handlers.""" valid_signals = _SignalConnector._valid_signals() if not _IS_WINDOWS: @@ -132,7 +132,7 @@ def _get_current_signal_handlers() -> dict[_SIGNUM, _HANDLER]: return {signum: signal.getsignal(signum) for signum in valid_signals} @staticmethod - def _valid_signals() -> set[signal.Signals]: + def _valid_signals() -> Set[signal.Signals]: """Returns all valid signals supported on the current platform.""" return signal.valid_signals() @@ -145,7 +145,7 @@ def _register_signal(signum: _SIGNUM, handlers: _HANDLER) -> None: if threading.current_thread() is threading.main_thread(): signal.signal(signum, handlers) # type: ignore[arg-type] - def __getstate__(self) -> dict: + def __getstate__(self) -> Dict: state = self.__dict__.copy() state["_original_handlers"] = {} return state diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 0509f28..23db90f 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -23,10 +23,9 @@ import logging import math import os -from collections.abc import Generator, Iterable from contextlib import contextmanager from datetime import timedelta -from typing import Any, Optional, Union +from typing import Any, Dict, Generator, Iterable, List, Optional, Union from weakref import proxy import torch @@ -91,17 +90,17 @@ def __init__( *, accelerator: Union[str, Accelerator] = "auto", strategy: Union[str, Strategy] = "auto", - devices: Union[list[int], str, int] = "auto", + devices: Union[List[int], str, int] = "auto", num_nodes: int = 1, precision: Optional[_PRECISION_INPUT] = None, logger: Optional[Union[Logger, Iterable[Logger], bool]] = None, - callbacks: Optional[Union[list[Callback], Callback]] = None, + callbacks: Optional[Union[List[Callback], Callback]] = None, fast_dev_run: Union[int, bool] = False, max_epochs: Optional[int] = None, min_epochs: Optional[int] = None, max_steps: int = -1, min_steps: Optional[int] = None, - max_time: Optional[Union[str, timedelta, dict[str, int]]] = None, + max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None, limit_train_batches: Optional[Union[int, float]] = None, limit_val_batches: Optional[Union[int, float]] = None, limit_test_batches: Optional[Union[int, float]] = None, @@ -124,7 +123,7 @@ def __init__( profiler: Optional[Union[Profiler, str]] = None, detect_anomaly: bool = False, barebones: bool = False, - plugins: Optional[Union[_PLUGIN_INPUT, list[_PLUGIN_INPUT]]] = None, + plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None, sync_batchnorm: bool = False, reload_dataloaders_every_n_epochs: int = 0, default_root_dir: Optional[_PATH] = None, @@ -473,7 +472,7 @@ def __init__( setup._init_profiler(self, profiler) # init logger flags - self._loggers: list[Logger] + self._loggers: List[Logger] self._logger_connector.on_trainer_init(logger, log_every_n_steps) # init debugging flags @@ -1150,7 +1149,7 @@ def num_nodes(self) -> int: return getattr(self.strategy, "num_nodes", 1) @property - def device_ids(self) -> list[int]: + def device_ids(self) -> List[int]: """List of device indexes per node.""" devices = ( self.strategy.parallel_devices @@ -1177,15 +1176,15 @@ def lightning_module(self) -> "pl.LightningModule": return self.strategy.lightning_module # type: ignore[return-value] @property - def optimizers(self) -> list[Optimizer]: + def optimizers(self) -> List[Optimizer]: return self.strategy.optimizers @optimizers.setter - def optimizers(self, new_optims: list[Optimizer]) -> None: + def optimizers(self, new_optims: List[Optimizer]) -> None: self.strategy.optimizers = new_optims @property - def lr_scheduler_configs(self) -> list[LRSchedulerConfig]: + def lr_scheduler_configs(self) -> List[LRSchedulerConfig]: return self.strategy.lr_scheduler_configs @property @@ -1248,7 +1247,7 @@ def training_step(self, batch, batch_idx): return self.strategy.is_global_zero @property - def distributed_sampler_kwargs(self) -> Optional[dict[str, Any]]: + def distributed_sampler_kwargs(self) -> Optional[Dict[str, Any]]: if isinstance(self.strategy, ParallelStrategy): return self.strategy.distributed_sampler_kwargs return None @@ -1281,7 +1280,7 @@ def early_stopping_callback(self) -> Optional[EarlyStopping]: return callbacks[0] if len(callbacks) > 0 else None @property - def early_stopping_callbacks(self) -> list[EarlyStopping]: + def early_stopping_callbacks(self) -> List[EarlyStopping]: """A list of all instances of :class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping` found in the Trainer.callbacks list.""" return [c for c in self.callbacks if isinstance(c, EarlyStopping)] @@ -1294,7 +1293,7 @@ def checkpoint_callback(self) -> Optional[Checkpoint]: return callbacks[0] if len(callbacks) > 0 else None @property - def checkpoint_callbacks(self) -> list[Checkpoint]: + def checkpoint_callbacks(self) -> List[Checkpoint]: """A list of all instances of :class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint` found in the Trainer.callbacks list.""" return [c for c in self.callbacks if isinstance(c, Checkpoint)] @@ -1523,14 +1522,14 @@ def num_training_batches(self) -> Union[int, float]: return self.fit_loop.max_batches @property - def num_sanity_val_batches(self) -> list[Union[int, float]]: + def num_sanity_val_batches(self) -> List[Union[int, float]]: """The number of validation batches that will be used during the sanity-checking part of ``trainer.fit()``.""" max_batches = self.fit_loop.epoch_loop.val_loop.max_batches # re-compute the `min` in case this is called outside the sanity-checking stage return [min(self.num_sanity_val_steps, batches) for batches in max_batches] @property - def num_val_batches(self) -> list[Union[int, float]]: + def num_val_batches(self) -> List[Union[int, float]]: """The number of validation batches that will be used during ``trainer.fit()`` or ``trainer.validate()``.""" if self.state.fn == TrainerFn.VALIDATING: return self.validate_loop.max_batches @@ -1539,12 +1538,12 @@ def num_val_batches(self) -> list[Union[int, float]]: return self.fit_loop.epoch_loop.val_loop._max_batches @property - def num_test_batches(self) -> list[Union[int, float]]: + def num_test_batches(self) -> List[Union[int, float]]: """The number of test batches that will be used during ``trainer.test()``.""" return self.test_loop.max_batches @property - def num_predict_batches(self) -> list[Union[int, float]]: + def num_predict_batches(self) -> List[Union[int, float]]: """The number of prediction batches that will be used during ``trainer.predict()``.""" return self.predict_loop.max_batches @@ -1585,7 +1584,7 @@ def logger(self, logger: Optional[Logger]) -> None: self.loggers = [logger] @property - def loggers(self) -> list[Logger]: + def loggers(self) -> List[Logger]: """The list of :class:`~lightning.pytorch.loggers.logger.Logger` used. .. code-block:: python @@ -1597,7 +1596,7 @@ def loggers(self) -> list[Logger]: return self._loggers @loggers.setter - def loggers(self, loggers: Optional[list[Logger]]) -> None: + def loggers(self, loggers: Optional[List[Logger]]) -> None: self._loggers = loggers if loggers else [] @property diff --git a/src/lightning/pytorch/tuner/batch_size_scaling.py b/src/lightning/pytorch/tuner/batch_size_scaling.py index 99badd8..6618f7e 100644 --- a/src/lightning/pytorch/tuner/batch_size_scaling.py +++ b/src/lightning/pytorch/tuner/batch_size_scaling.py @@ -15,7 +15,7 @@ import os import uuid from copy import deepcopy -from typing import Any, Optional +from typing import Any, Dict, Optional, Tuple import lightning.pytorch as pl from lightning.pytorch.utilities.memory import garbage_collection_cuda, is_oom_error @@ -98,7 +98,7 @@ def _scale_batch_size( return new_size -def __scale_batch_dump_params(trainer: "pl.Trainer") -> dict[str, Any]: +def __scale_batch_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]: dumped_params = { "loggers": trainer.loggers, "callbacks": trainer.callbacks, @@ -138,7 +138,7 @@ def __scale_batch_reset_params(trainer: "pl.Trainer", steps_per_trial: int) -> N loop.verbose = False -def __scale_batch_restore_params(trainer: "pl.Trainer", params: dict[str, Any]) -> None: +def __scale_batch_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) -> None: # TODO: There are more states that needs to be reset (#4512 and #4870) trainer.loggers = params["loggers"] trainer.callbacks = params["callbacks"] @@ -169,7 +169,7 @@ def _run_power_scaling( new_size: int, batch_arg_name: str, max_trials: int, - params: dict[str, Any], + params: Dict[str, Any], ) -> int: """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered.""" # this flag is used to determine whether the previously scaled batch size, right before OOM, was a success or not @@ -211,7 +211,7 @@ def _run_binary_scaling( new_size: int, batch_arg_name: str, max_trials: int, - params: dict[str, Any], + params: Dict[str, Any], ) -> int: """Batch scaling mode where the size is initially is doubled at each iteration until an OOM error is encountered. @@ -276,7 +276,7 @@ def _adjust_batch_size( factor: float = 1.0, value: Optional[int] = None, desc: Optional[str] = None, -) -> tuple[int, bool]: +) -> Tuple[int, bool]: """Helper function for adjusting the batch size. Args: @@ -328,7 +328,7 @@ def _reset_dataloaders(trainer: "pl.Trainer") -> None: loop.epoch_loop.val_loop.setup_data() -def _try_loop_run(trainer: "pl.Trainer", params: dict[str, Any]) -> None: +def _try_loop_run(trainer: "pl.Trainer", params: Dict[str, Any]) -> None: loop = trainer._active_loop assert loop is not None loop.load_state_dict(deepcopy(params["loop_state_dict"])) diff --git a/src/lightning/pytorch/tuner/lr_finder.py b/src/lightning/pytorch/tuner/lr_finder.py index b50bedb..d756d3d 100644 --- a/src/lightning/pytorch/tuner/lr_finder.py +++ b/src/lightning/pytorch/tuner/lr_finder.py @@ -16,7 +16,7 @@ import os import uuid from copy import deepcopy -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import torch from lightning_utilities.core.imports import RequirementCache @@ -101,7 +101,7 @@ def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int) - self.lr_max = lr_max self.num_training = num_training - self.results: dict[str, Any] = {} + self.results: Dict[str, Any] = {} self._total_batch_idx = 0 # for debug purpose def _exchange_scheduler(self, trainer: "pl.Trainer") -> None: @@ -310,7 +310,7 @@ def _lr_find( return lr_finder -def __lr_finder_dump_params(trainer: "pl.Trainer") -> dict[str, Any]: +def __lr_finder_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]: return { "optimizers": trainer.strategy.optimizers, "lr_scheduler_configs": trainer.strategy.lr_scheduler_configs, @@ -335,7 +335,7 @@ def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_sto trainer.limit_val_batches = num_training -def __lr_finder_restore_params(trainer: "pl.Trainer", params: dict[str, Any]) -> None: +def __lr_finder_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) -> None: trainer.strategy.optimizers = params["optimizers"] trainer.strategy.lr_scheduler_configs = params["lr_scheduler_configs"] trainer.callbacks = params["callbacks"] @@ -376,8 +376,8 @@ def __init__( self.num_training = num_training self.early_stop_threshold = early_stop_threshold self.beta = beta - self.losses: list[float] = [] - self.lrs: list[float] = [] + self.losses: List[float] = [] + self.lrs: List[float] = [] self.avg_loss = 0.0 self.best_loss = 0.0 self.progress_bar_refresh_rate = progress_bar_refresh_rate @@ -463,7 +463,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in super().__init__(optimizer, last_epoch) @override - def get_lr(self) -> list[float]: + def get_lr(self) -> List[float]: curr_iter = self.last_epoch + 1 r = curr_iter / self.num_iter @@ -475,7 +475,7 @@ def get_lr(self) -> list[float]: return val @property - def lr(self) -> Union[float, list[float]]: + def lr(self) -> Union[float, List[float]]: return self._lr @@ -500,7 +500,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in super().__init__(optimizer, last_epoch) @override - def get_lr(self) -> list[float]: + def get_lr(self) -> List[float]: curr_iter = self.last_epoch + 1 r = curr_iter / self.num_iter @@ -512,11 +512,11 @@ def get_lr(self) -> list[float]: return val @property - def lr(self) -> Union[float, list[float]]: + def lr(self) -> Union[float, List[float]]: return self._lr -def _try_loop_run(trainer: "pl.Trainer", params: dict[str, Any]) -> None: +def _try_loop_run(trainer: "pl.Trainer", params: Dict[str, Any]) -> None: loop = trainer.fit_loop loop.load_state_dict(deepcopy(params["loop_state_dict"])) loop.restarting = False diff --git a/src/lightning/pytorch/utilities/_pytree.py b/src/lightning/pytorch/utilities/_pytree.py index a0c7236..f5f48b4 100644 --- a/src/lightning/pytorch/utilities/_pytree.py +++ b/src/lightning/pytorch/utilities/_pytree.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, List, Tuple from torch.utils._pytree import SUPPORTED_NODES, LeafSpec, PyTree, TreeSpec, _get_node_type, tree_unflatten @@ -15,7 +15,7 @@ def _is_leaf_or_primitive_container(pytree: PyTree) -> bool: return all(isinstance(child, (int, float, str)) for child in child_pytrees) -def _tree_flatten(pytree: PyTree) -> tuple[list[Any], TreeSpec]: +def _tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]: """Copy of :func:`torch.utils._pytree.tree_flatten` using our custom leaf function.""" if _is_leaf_or_primitive_container(pytree): return [pytree], LeafSpec() @@ -24,8 +24,8 @@ def _tree_flatten(pytree: PyTree) -> tuple[list[Any], TreeSpec]: flatten_fn = SUPPORTED_NODES[node_type].flatten_fn child_pytrees, context = flatten_fn(pytree) - result: list[Any] = [] - children_specs: list[TreeSpec] = [] + result: List[Any] = [] + children_specs: List[TreeSpec] = [] for child in child_pytrees: flat, child_spec = _tree_flatten(child) result += flat @@ -34,6 +34,6 @@ def _tree_flatten(pytree: PyTree) -> tuple[list[Any], TreeSpec]: return result, TreeSpec(node_type, context, children_specs) -def _map_and_unflatten(fn: Any, values: list[Any], spec: TreeSpec) -> PyTree: +def _map_and_unflatten(fn: Any, values: List[Any], spec: TreeSpec) -> PyTree: """Utility function to apply a function and unflatten it.""" return tree_unflatten([fn(i) for i in values], spec) diff --git a/src/lightning/pytorch/utilities/argparse.py b/src/lightning/pytorch/utilities/argparse.py index 1e01297..eb7273b 100644 --- a/src/lightning/pytorch/utilities/argparse.py +++ b/src/lightning/pytorch/utilities/argparse.py @@ -19,12 +19,12 @@ from ast import literal_eval from contextlib import suppress from functools import wraps -from typing import Any, Callable, TypeVar, cast +from typing import Any, Callable, Type, TypeVar, cast _T = TypeVar("_T", bound=Callable[..., Any]) -def _parse_env_variables(cls: type, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace: +def _parse_env_variables(cls: Type, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace: """Parse environment arguments if they are defined. Examples: diff --git a/src/lightning/pytorch/utilities/combined_loader.py b/src/lightning/pytorch/utilities/combined_loader.py index 9c89c99..9b0ceb0 100644 --- a/src/lightning/pytorch/utilities/combined_loader.py +++ b/src/lightning/pytorch/utilities/combined_loader.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import contextlib -from collections.abc import Iterable, Iterator -from typing import Any, Callable, Literal, Optional, Union +from collections.abc import Iterable +from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Tuple, Type, Union from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter from typing_extensions import Self, TypedDict, override @@ -22,15 +22,15 @@ from lightning.fabric.utilities.types import _Stateful from lightning.pytorch.utilities._pytree import _map_and_unflatten, _tree_flatten, tree_unflatten -_ITERATOR_RETURN = tuple[Any, int, int] # batch, batch_idx, dataloader_idx +_ITERATOR_RETURN = Tuple[Any, int, int] # batch, batch_idx, dataloader_idx class _ModeIterator(Iterator[_ITERATOR_RETURN]): - def __init__(self, iterables: list[Iterable], limits: Optional[list[Union[int, float]]] = None) -> None: + def __init__(self, iterables: List[Iterable], limits: Optional[List[Union[int, float]]] = None) -> None: if limits is not None and len(limits) != len(iterables): raise ValueError(f"Mismatch in number of limits ({len(limits)}) and number of iterables ({len(iterables)})") self.iterables = iterables - self.iterators: list[Iterator] = [] + self.iterators: List[Iterator] = [] self._idx = 0 # what would be batch_idx self.limits = limits @@ -51,7 +51,7 @@ def reset(self) -> None: self.iterators = [] self._idx = 0 - def __getstate__(self) -> dict[str, Any]: + def __getstate__(self) -> Dict[str, Any]: state = self.__dict__.copy() # workaround an inconvenient `NotImplementedError`: @@ -65,9 +65,9 @@ def __getstate__(self) -> dict[str, Any]: class _MaxSizeCycle(_ModeIterator): - def __init__(self, iterables: list[Iterable], limits: Optional[list[Union[int, float]]] = None) -> None: + def __init__(self, iterables: List[Iterable], limits: Optional[List[Union[int, float]]] = None) -> None: super().__init__(iterables, limits) - self._consumed: list[bool] = [] + self._consumed: List[bool] = [] @override def __next__(self) -> _ITERATOR_RETURN: @@ -121,7 +121,7 @@ def __len__(self) -> int: class _Sequential(_ModeIterator): - def __init__(self, iterables: list[Iterable], limits: Optional[list[Union[int, float]]] = None) -> None: + def __init__(self, iterables: List[Iterable], limits: Optional[List[Union[int, float]]] = None) -> None: super().__init__(iterables, limits) self._iterator_idx = 0 # what would be dataloader_idx @@ -206,8 +206,8 @@ def __len__(self) -> int: class _CombinationMode(TypedDict): - fn: Callable[[list[int]], int] - iterator: type[_ModeIterator] + fn: Callable[[List[int]], int] + iterator: Type[_ModeIterator] _SUPPORTED_MODES = { @@ -288,7 +288,7 @@ def __init__(self, iterables: Any, mode: _LITERAL_SUPPORTED_MODES = "min_size") self._flattened, self._spec = _tree_flatten(iterables) self._mode = mode self._iterator: Optional[_ModeIterator] = None - self._limits: Optional[list[Union[int, float]]] = None + self._limits: Optional[List[Union[int, float]]] = None @property def iterables(self) -> Any: @@ -306,12 +306,12 @@ def batch_sampler(self) -> Any: return _map_and_unflatten(lambda x: getattr(x, "batch_sampler", None), self.flattened, self._spec) @property - def flattened(self) -> list[Any]: + def flattened(self) -> List[Any]: """Return the flat list of iterables.""" return self._flattened @flattened.setter - def flattened(self, flattened: list[Any]) -> None: + def flattened(self, flattened: List[Any]) -> None: """Setter to conveniently update the list of iterables.""" if len(flattened) != len(self._flattened): raise ValueError( @@ -322,12 +322,12 @@ def flattened(self, flattened: list[Any]) -> None: self._flattened = flattened @property - def limits(self) -> Optional[list[Union[int, float]]]: + def limits(self) -> Optional[List[Union[int, float]]]: """Optional limits per iterator.""" return self._limits @limits.setter - def limits(self, limits: Optional[Union[int, float, list[Union[int, float]]]]) -> None: + def limits(self, limits: Optional[Union[int, float, List[Union[int, float]]]]) -> None: if isinstance(limits, (int, float)): limits = [limits] * len(self.flattened) elif isinstance(limits, list) and len(limits) != len(self.flattened): @@ -375,11 +375,11 @@ def _dataset_length(self) -> int: fn = _SUPPORTED_MODES[self._mode]["fn"] return fn(lengths) - def _state_dicts(self) -> list[dict[str, Any]]: + def _state_dicts(self) -> List[Dict[str, Any]]: """Returns the list of state dicts for iterables in `self.flattened` that are stateful.""" return [loader.state_dict() for loader in self.flattened if isinstance(loader, _Stateful)] - def _load_state_dicts(self, states: list[dict[str, Any]]) -> None: + def _load_state_dicts(self, states: List[Dict[str, Any]]) -> None: """Loads the state dicts for iterables in `self.flattened` that are stateful.""" if not states: return @@ -401,5 +401,5 @@ def _shutdown_workers_and_reset_iterator(dataloader: object) -> None: dataloader._iterator = None -def _get_iterables_lengths(iterables: list[Iterable]) -> list[Union[int, float]]: +def _get_iterables_lengths(iterables: List[Iterable]) -> List[Union[int, float]]: return [(float("inf") if (length := sized_len(iterable)) is None else length) for iterable in iterables] diff --git a/src/lightning/pytorch/utilities/consolidate_checkpoint.py b/src/lightning/pytorch/utilities/consolidate_checkpoint.py index 0dcf587..6f150ba 100644 --- a/src/lightning/pytorch/utilities/consolidate_checkpoint.py +++ b/src/lightning/pytorch/utilities/consolidate_checkpoint.py @@ -1,5 +1,5 @@ import re -from typing import Any +from typing import Any, Dict import torch @@ -7,7 +7,7 @@ from lightning.fabric.utilities.load import _load_distributed_checkpoint -def _format_checkpoint(checkpoint: dict[str, Any]) -> dict[str, Any]: +def _format_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any]: """Converts the special FSDP checkpoint format to the standard format the Lightning Trainer can load.""" # Rename the model key checkpoint["state_dict"] = checkpoint.pop("model") diff --git a/src/lightning/pytorch/utilities/data.py b/src/lightning/pytorch/utilities/data.py index 5c14561..b58142b 100644 --- a/src/lightning/pytorch/utilities/data.py +++ b/src/lightning/pytorch/utilities/data.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect -from collections.abc import Generator, Iterable, Mapping, Sized from dataclasses import fields -from typing import Any, Optional, Union +from typing import Any, Dict, Generator, Iterable, Mapping, Optional, Sized, Tuple, Union import torch from lightning_utilities.core.apply_func import is_dataclass_instance @@ -140,7 +139,7 @@ def _get_dataloader_init_args_and_kwargs( dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None, -) -> tuple[tuple[Any], dict[str, Any]]: +) -> Tuple[Tuple[Any], Dict[str, Any]]: if not isinstance(dataloader, DataLoader): raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`") @@ -234,7 +233,7 @@ def _dataloader_init_kwargs_resolve_sampler( dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None, -) -> dict[str, Any]: +) -> Dict[str, Any]: """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its re- instantiation. diff --git a/src/lightning/pytorch/utilities/grads.py b/src/lightning/pytorch/utilities/grads.py index 08a0230..f200d89 100644 --- a/src/lightning/pytorch/utilities/grads.py +++ b/src/lightning/pytorch/utilities/grads.py @@ -13,13 +13,13 @@ # limitations under the License. """Utilities to describe gradients.""" -from typing import Union +from typing import Dict, Union import torch from torch.nn import Module -def grad_norm(module: Module, norm_type: Union[float, int, str], group_separator: str = "/") -> dict[str, float]: +def grad_norm(module: Module, norm_type: Union[float, int, str], group_separator: str = "/") -> Dict[str, float]: """Compute each parameter's gradient's norm and their overall norm. The overall norm is computed over all gradients together, as if they diff --git a/src/lightning/pytorch/utilities/migration/migration.py b/src/lightning/pytorch/utilities/migration/migration.py index 5db942b..6a5a914 100644 --- a/src/lightning/pytorch/utilities/migration/migration.py +++ b/src/lightning/pytorch/utilities/migration/migration.py @@ -31,17 +31,17 @@ """ import re -from typing import Any, Callable +from typing import Any, Callable, Dict, List from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch.callbacks.early_stopping import EarlyStopping from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint from lightning.pytorch.utilities.rank_zero import rank_zero_warn -_CHECKPOINT = dict[str, Any] +_CHECKPOINT = Dict[str, Any] -def _migration_index() -> dict[str, list[Callable[[_CHECKPOINT], _CHECKPOINT]]]: +def _migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]: """Migration functions returned here will get executed in the order they are listed.""" return { "0.10.0": [_migrate_model_checkpoint_early_stopping], @@ -133,7 +133,7 @@ def _migrate_loop_batches_that_stepped(checkpoint: _CHECKPOINT) -> _CHECKPOINT: return checkpoint -def _get_fit_loop_initial_state_1_6_0() -> dict: +def _get_fit_loop_initial_state_1_6_0() -> Dict: return { "epoch_loop.batch_loop.manual_loop.optim_step_progress": { "current": {"completed": 0, "ready": 0}, diff --git a/src/lightning/pytorch/utilities/migration/utils.py b/src/lightning/pytorch/utilities/migration/utils.py index 2c5656e..1537c26 100644 --- a/src/lightning/pytorch/utilities/migration/utils.py +++ b/src/lightning/pytorch/utilities/migration/utils.py @@ -18,7 +18,7 @@ import threading import warnings from types import ModuleType, TracebackType -from typing import Any, Optional +from typing import Any, Dict, List, Optional, Tuple, Type from packaging.version import Version from typing_extensions import override @@ -32,13 +32,13 @@ from lightning.pytorch.utilities.rank_zero import rank_zero_warn _log = logging.getLogger(__name__) -_CHECKPOINT = dict[str, Any] +_CHECKPOINT = Dict[str, Any] _lock = threading.Lock() def migrate_checkpoint( checkpoint: _CHECKPOINT, target_version: Optional[str] = None -) -> tuple[_CHECKPOINT, dict[str, list[str]]]: +) -> Tuple[_CHECKPOINT, Dict[str, List[str]]]: """Applies Lightning version migrations to a checkpoint dictionary. Args: @@ -121,7 +121,7 @@ class _FaultTolerantMode(LightningEnum): def __exit__( self, - exc_type: Optional[type[BaseException]], + exc_type: Optional[Type[BaseException]], exc_value: Optional[BaseException], exc_traceback: Optional[TracebackType], ) -> None: diff --git a/src/lightning/pytorch/utilities/model_helpers.py b/src/lightning/pytorch/utilities/model_helpers.py index 44591aa..36adedf 100644 --- a/src/lightning/pytorch/utilities/model_helpers.py +++ b/src/lightning/pytorch/utilities/model_helpers.py @@ -15,7 +15,7 @@ import inspect import logging import os -from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, Type, TypeVar from lightning_utilities.core.imports import RequirementCache from torch import nn @@ -26,7 +26,7 @@ _log = logging.getLogger(__name__) -def is_overridden(method_name: str, instance: Optional[object] = None, parent: Optional[type[object]] = None) -> bool: +def is_overridden(method_name: str, instance: Optional[object] = None, parent: Optional[Type[object]] = None) -> bool: if instance is None: # if `self.lightning_module` was passed as instance, it can be `None` return False @@ -65,7 +65,7 @@ class _ModuleMode: """Captures the ``nn.Module.training`` (bool) mode of every submodule, and allows it to be restored later on.""" def __init__(self) -> None: - self.mode: dict[str, bool] = {} + self.mode: Dict[str, bool] = {} def capture(self, module: nn.Module) -> None: self.mode.clear() @@ -108,10 +108,10 @@ class _restricted_classmethod_impl(Generic[_T, _P, _R_co]): """Drop-in replacement for @classmethod, but raises an exception when the decorated method is called on an instance instead of a class type.""" - def __init__(self, method: Callable[Concatenate[type[_T], _P], _R_co]) -> None: + def __init__(self, method: Callable[Concatenate[Type[_T], _P], _R_co]) -> None: self.method = method - def __get__(self, instance: Optional[_T], cls: type[_T]) -> Callable[_P, _R_co]: + def __get__(self, instance: Optional[_T], cls: Type[_T]) -> Callable[_P, _R_co]: # The wrapper ensures that the method can be inspected, but not called on an instance @functools.wraps(self.method) def wrapper(*args: Any, **kwargs: Any) -> _R_co: diff --git a/src/lightning/pytorch/utilities/model_summary/model_summary.py b/src/lightning/pytorch/utilities/model_summary/model_summary.py index 6a5baf2..c40dc94 100644 --- a/src/lightning/pytorch/utilities/model_summary/model_summary.py +++ b/src/lightning/pytorch/utilities/model_summary/model_summary.py @@ -17,7 +17,7 @@ import logging import math from collections import OrderedDict -from typing import Any, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -73,8 +73,8 @@ def __init__(self, module: nn.Module) -> None: super().__init__() self._module = module self._hook_handle = self._register_hook() - self._in_size: Optional[Union[str, list]] = None - self._out_size: Optional[Union[str, list]] = None + self._in_size: Optional[Union[str, List]] = None + self._out_size: Optional[Union[str, List]] = None def __del__(self) -> None: self.detach_hook() @@ -121,11 +121,11 @@ def detach_hook(self) -> None: self._hook_handle.remove() @property - def in_size(self) -> Union[str, list]: + def in_size(self) -> Union[str, List]: return self._in_size or UNKNOWN_SIZE @property - def out_size(self) -> Union[str, list]: + def out_size(self) -> Union[str, List]: return self._out_size or UNKNOWN_SIZE @property @@ -221,8 +221,8 @@ def __init__(self, model: "pl.LightningModule", max_depth: int = 1) -> None: self._precision_megabytes = (precision / 8.0) * 1e-6 @property - def named_modules(self) -> list[tuple[str, nn.Module]]: - mods: list[tuple[str, nn.Module]] + def named_modules(self) -> List[Tuple[str, nn.Module]]: + mods: List[Tuple[str, nn.Module]] if self._max_depth == 0: mods = [] elif self._max_depth == 1: @@ -234,31 +234,31 @@ def named_modules(self) -> list[tuple[str, nn.Module]]: return mods @property - def layer_names(self) -> list[str]: + def layer_names(self) -> List[str]: return list(self._layer_summary.keys()) @property - def layer_types(self) -> list[str]: + def layer_types(self) -> List[str]: return [layer.layer_type for layer in self._layer_summary.values()] @property - def in_sizes(self) -> list: + def in_sizes(self) -> List: return [layer.in_size for layer in self._layer_summary.values()] @property - def out_sizes(self) -> list: + def out_sizes(self) -> List: return [layer.out_size for layer in self._layer_summary.values()] @property - def param_nums(self) -> list[int]: + def param_nums(self) -> List[int]: return [layer.num_parameters for layer in self._layer_summary.values()] @property - def training_modes(self) -> list[bool]: + def training_modes(self) -> List[bool]: return [layer.training for layer in self._layer_summary.values()] @property - def total_training_modes(self) -> dict[str, int]: + def total_training_modes(self) -> Dict[str, int]: modes = [layer.training for layer in self._model.modules()] modes = modes[1:] # exclude the root module return {"train": modes.count(True), "eval": modes.count(False)} @@ -279,7 +279,7 @@ def total_layer_params(self) -> int: def model_size(self) -> float: return self.total_parameters * self._precision_megabytes - def summarize(self) -> dict[str, LayerSummary]: + def summarize(self) -> Dict[str, LayerSummary]: summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules) if self._model.example_input_array is not None: self._forward_example_input() @@ -318,7 +318,7 @@ def _forward_example_input(self) -> None: model(input_) mode.restore(model) - def _get_summary_data(self) -> list[tuple[str, list[str]]]: + def _get_summary_data(self) -> List[Tuple[str, List[str]]]: """Makes a summary listing with: Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size @@ -341,7 +341,7 @@ def _get_summary_data(self) -> list[tuple[str, list[str]]]: return arrays - def _add_leftover_params_to_summary(self, arrays: list[tuple[str, list[str]]], total_leftover_params: int) -> None: + def _add_leftover_params_to_summary(self, arrays: List[Tuple[str, List[str]]], total_leftover_params: int) -> None: """Add summary of params not associated with module or layer to model summary.""" layer_summaries = dict(arrays) layer_summaries[" "].append(" ") @@ -368,7 +368,7 @@ def __repr__(self) -> str: return str(self) -def parse_batch_shape(batch: Any) -> Union[str, list]: +def parse_batch_shape(batch: Any) -> Union[str, List]: if hasattr(batch, "shape"): return list(batch.shape) @@ -382,8 +382,8 @@ def _format_summary_table( total_parameters: int, trainable_parameters: int, model_size: float, - total_training_modes: dict[str, int], - *cols: tuple[str, list[str]], + total_training_modes: Dict[str, int], + *cols: Tuple[str, List[str]], ) -> str: """Takes in a number of arrays, each specifying a column in the summary table, and combines them all into one big string defining the summary table that are nicely formatted.""" diff --git a/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py b/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py index 5038aeb..57d9ae5 100644 --- a/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py +++ b/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py @@ -14,6 +14,7 @@ """Utilities that can be used with Deepspeed.""" from collections import OrderedDict +from typing import Dict, List, Tuple import torch from lightning_utilities.core.imports import RequirementCache @@ -53,7 +54,7 @@ def partitioned_size(p: Parameter) -> int: class DeepSpeedSummary(ModelSummary): @override - def summarize(self) -> dict[str, DeepSpeedLayerSummary]: # type: ignore[override] + def summarize(self) -> Dict[str, DeepSpeedLayerSummary]: # type: ignore[override] summary = OrderedDict((name, DeepSpeedLayerSummary(module)) for name, module in self.named_modules) if self._model.example_input_array is not None: self._forward_example_input() @@ -82,11 +83,11 @@ def trainable_parameters(self) -> int: ) @property - def parameters_per_layer(self) -> list[int]: + def parameters_per_layer(self) -> List[int]: return [layer.average_shard_parameters for layer in self._layer_summary.values()] @override - def _get_summary_data(self) -> list[tuple[str, list[str]]]: + def _get_summary_data(self) -> List[Tuple[str, List[str]]]: """Makes a summary listing with: Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size @@ -111,7 +112,7 @@ def _get_summary_data(self) -> list[tuple[str, list[str]]]: return arrays @override - def _add_leftover_params_to_summary(self, arrays: list[tuple[str, list[str]]], total_leftover_params: int) -> None: + def _add_leftover_params_to_summary(self, arrays: List[Tuple[str, List[str]]], total_leftover_params: int) -> None: """Add summary of params not associated with module or layer to model summary.""" super()._add_leftover_params_to_summary(arrays, total_leftover_params) layer_summaries = dict(arrays) diff --git a/src/lightning/pytorch/utilities/parameter_tying.py b/src/lightning/pytorch/utilities/parameter_tying.py index da0309b..8680285 100644 --- a/src/lightning/pytorch/utilities/parameter_tying.py +++ b/src/lightning/pytorch/utilities/parameter_tying.py @@ -18,17 +18,17 @@ """ -from typing import Optional +from typing import Dict, List, Optional from torch import nn -def find_shared_parameters(module: nn.Module) -> list[str]: +def find_shared_parameters(module: nn.Module) -> List[str]: """Returns a list of names of shared parameters set in the module.""" return _find_shared_parameters(module) -def _find_shared_parameters(module: nn.Module, tied_parameters: Optional[dict] = None, prefix: str = "") -> list[str]: +def _find_shared_parameters(module: nn.Module, tied_parameters: Optional[Dict] = None, prefix: str = "") -> List[str]: if tied_parameters is None: tied_parameters = {} for name, param in module._parameters.items(): diff --git a/src/lightning/pytorch/utilities/parsing.py b/src/lightning/pytorch/utilities/parsing.py index 16eef55..0f4460a 100644 --- a/src/lightning/pytorch/utilities/parsing.py +++ b/src/lightning/pytorch/utilities/parsing.py @@ -17,9 +17,8 @@ import inspect import pickle import types -from collections.abc import MutableMapping, Sequence from dataclasses import fields, is_dataclass -from typing import Any, Literal, Optional, Union +from typing import Any, Dict, List, Literal, MutableMapping, Optional, Sequence, Tuple, Type, Union from torch import nn @@ -49,7 +48,7 @@ def clean_namespace(hparams: MutableMapping) -> None: del hparams[k] -def parse_class_init_keys(cls: type) -> tuple[str, Optional[str], Optional[str]]: +def parse_class_init_keys(cls: Type) -> Tuple[str, Optional[str], Optional[str]]: """Parse key words for standard ``self``, ``*args`` and ``**kwargs``. Examples: @@ -61,7 +60,7 @@ def parse_class_init_keys(cls: type) -> tuple[str, Optional[str], Optional[str]] ('self', 'my_args', 'my_kwargs') """ - init_parameters = inspect.signature(cls.__init__).parameters # type: ignore[misc] + init_parameters = inspect.signature(cls.__init__).parameters # docs claims the params are always ordered # https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters init_params = list(init_parameters.values()) @@ -69,7 +68,7 @@ def parse_class_init_keys(cls: type) -> tuple[str, Optional[str], Optional[str]] n_self = init_params[0].name def _get_first_if_any( - params: list[inspect.Parameter], + params: List[inspect.Parameter], param_type: Literal[inspect._ParameterKind.VAR_POSITIONAL, inspect._ParameterKind.VAR_KEYWORD], ) -> Optional[str]: for p in params: @@ -83,13 +82,13 @@ def _get_first_if_any( return n_self, n_args, n_kwargs -def get_init_args(frame: types.FrameType) -> dict[str, Any]: # pragma: no-cover +def get_init_args(frame: types.FrameType) -> Dict[str, Any]: # pragma: no-cover """For backwards compatibility: #16369.""" _, local_args = _get_init_args(frame) return local_args -def _get_init_args(frame: types.FrameType) -> tuple[Optional[Any], dict[str, Any]]: +def _get_init_args(frame: types.FrameType) -> Tuple[Optional[Any], Dict[str, Any]]: _, _, _, local_vars = inspect.getargvalues(frame) if "__class__" not in local_vars: return None, {} @@ -110,10 +109,10 @@ def _get_init_args(frame: types.FrameType) -> tuple[Optional[Any], dict[str, Any def collect_init_args( frame: types.FrameType, - path_args: list[dict[str, Any]], + path_args: List[Dict[str, Any]], inside: bool = False, - classes: tuple[type, ...] = (), -) -> list[dict[str, Any]]: + classes: Tuple[Type, ...] = (), +) -> List[Dict[str, Any]]: """Recursively collects the arguments passed to the child constructors in the inheritance tree. Args: @@ -148,7 +147,7 @@ def save_hyperparameters( *args: Any, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None, - given_hparams: Optional[dict[str, Any]] = None, + given_hparams: Optional[Dict[str, Any]] = None, ) -> None: """See :meth:`~lightning.pytorch.LightningModule.save_hyperparameters`""" @@ -233,14 +232,14 @@ class AttributeDict(_AttributeDict): """ -def _lightning_get_all_attr_holders(model: "pl.LightningModule", attribute: str) -> list[Any]: +def _lightning_get_all_attr_holders(model: "pl.LightningModule", attribute: str) -> List[Any]: """Special attribute finding for Lightning. Gets all of the objects or dicts that holds attribute. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. """ - holders: list[Any] = [] + holders: List[Any] = [] # Check if attribute in model if hasattr(model, attribute): diff --git a/src/lightning/pytorch/utilities/seed.py b/src/lightning/pytorch/utilities/seed.py index 7250ba5..4ba9e7f 100644 --- a/src/lightning/pytorch/utilities/seed.py +++ b/src/lightning/pytorch/utilities/seed.py @@ -13,8 +13,8 @@ # limitations under the License. """Utilities to help with reproducibility of models.""" -from collections.abc import Generator from contextlib import contextmanager +from typing import Generator from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states diff --git a/src/lightning/pytorch/utilities/testing/_runif.py b/src/lightning/pytorch/utilities/testing/_runif.py index 9c46913..03b3afd 100644 --- a/src/lightning/pytorch/utilities/testing/_runif.py +++ b/src/lightning/pytorch/utilities/testing/_runif.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Dict, List, Optional, Tuple from lightning_utilities.core.imports import RequirementCache @@ -42,7 +42,7 @@ def _runif_reasons( psutil: bool = False, sklearn: bool = False, onnx: bool = False, -) -> tuple[list[str], dict[str, bool]]: +) -> Tuple[List[str], Dict[str, bool]]: """Construct reasons for pytest skipif. Args: diff --git a/src/lightning/pytorch/utilities/types.py b/src/lightning/pytorch/utilities/types.py index 8fccfa7..c1b971e 100644 --- a/src/lightning/pytorch/utilities/types.py +++ b/src/lightning/pytorch/utilities/types.py @@ -17,13 +17,19 @@ - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no leading `_`) """ -from collections.abc import Generator, Iterator, Mapping, Sequence from contextlib import contextmanager from dataclasses import dataclass from typing import ( Any, + Generator, + Iterator, + List, + Mapping, Optional, Protocol, + Sequence, + Tuple, + Type, TypedDict, Union, runtime_checkable, @@ -41,8 +47,8 @@ _NUMBER = Union[int, float] _METRIC = Union[Metric, Tensor, _NUMBER] STEP_OUTPUT = Optional[Union[Tensor, Mapping[str, Any]]] -_EVALUATE_OUTPUT = list[Mapping[str, float]] # 1 dict per DataLoader -_PREDICT_OUTPUT = Union[list[Any], list[list[Any]]] +_EVALUATE_OUTPUT = List[Mapping[str, float]] # 1 dict per DataLoader +_PREDICT_OUTPUT = Union[List[Any], List[List[Any]]] TRAIN_DATALOADERS = Any # any iterable or collection of iterables EVAL_DATALOADERS = Any # any iterable or collection of iterables @@ -54,7 +60,7 @@ class DistributedDataParallel(Protocol): def __init__( self, module: torch.nn.Module, - device_ids: Optional[list[Union[int, torch.device]]] = None, + device_ids: Optional[List[Union[int, torch.device]]] = None, output_device: Optional[Union[int, torch.device]] = None, dim: int = 0, broadcast_buffers: bool = True, @@ -73,7 +79,7 @@ def no_sync(self) -> Generator: ... # todo: improve LRSchedulerType naming/typing LRSchedulerTypeTuple = (LRScheduler, ReduceLROnPlateau) LRSchedulerTypeUnion = Union[LRScheduler, ReduceLROnPlateau] -LRSchedulerType = Union[type[LRScheduler], type[ReduceLROnPlateau]] +LRSchedulerType = Union[Type[LRScheduler], Type[ReduceLROnPlateau]] LRSchedulerPLType = Union[LRScheduler, ReduceLROnPlateau] @@ -113,7 +119,7 @@ class OptimizerLRSchedulerConfig(TypedDict): Union[ Optimizer, Sequence[Optimizer], - tuple[Sequence[Optimizer], Sequence[Union[LRSchedulerTypeUnion, LRSchedulerConfig]]], + Tuple[Sequence[Optimizer], Sequence[Union[LRSchedulerTypeUnion, LRSchedulerConfig]]], OptimizerLRSchedulerConfig, Sequence[OptimizerLRSchedulerConfig], ] diff --git a/src/lightning/pytorch/utilities/upgrade_checkpoint.py b/src/lightning/pytorch/utilities/upgrade_checkpoint.py index 04cf000..87ad603 100644 --- a/src/lightning/pytorch/utilities/upgrade_checkpoint.py +++ b/src/lightning/pytorch/utilities/upgrade_checkpoint.py @@ -16,6 +16,7 @@ from argparse import ArgumentParser, Namespace from pathlib import Path from shutil import copyfile +from typing import List import torch from tqdm import tqdm @@ -28,7 +29,7 @@ def _upgrade(args: Namespace) -> None: path = Path(args.path).absolute() extension: str = args.extension if args.extension.startswith(".") else f".{args.extension}" - files: list[Path] = [] + files: List[Path] = [] if not path.exists(): _log.error( diff --git a/src/lightning_fabric/__setup__.py b/src/lightning_fabric/__setup__.py index a55e1f2..8fe0bc0 100644 --- a/src/lightning_fabric/__setup__.py +++ b/src/lightning_fabric/__setup__.py @@ -3,7 +3,7 @@ from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path from types import ModuleType -from typing import Any +from typing import Any, Dict from pkg_resources import parse_requirements from setuptools import find_packages @@ -29,7 +29,7 @@ def _load_assistant() -> ModuleType: return _load_py_module("assistant", location) -def _prepare_extras() -> dict[str, Any]: +def _prepare_extras() -> Dict[str, Any]: assistant = _load_assistant() # https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-extras # Define package extras. These are only installed if you specify them. @@ -49,7 +49,7 @@ def _prepare_extras() -> dict[str, Any]: return extras -def _setup_args() -> dict[str, Any]: +def _setup_args() -> Dict[str, Any]: assistant = _load_assistant() about = _load_py_module("about", os.path.join(_PACKAGE_ROOT, "__about__.py")) version = _load_py_module("version", os.path.join(_PACKAGE_ROOT, "__version__.py")) @@ -73,7 +73,7 @@ def _setup_args() -> dict[str, Any]: "include_package_data": True, "zip_safe": False, "keywords": ["deep learning", "pytorch", "AI"], - "python_requires": ">=3.9", + "python_requires": ">=3.8", "setup_requires": ["wheel"], "install_requires": assistant.load_requirements( _PATH_REQUIREMENTS, unfreeze="none" if _FREEZE_REQUIREMENTS else "all" @@ -105,6 +105,7 @@ def _setup_args() -> dict[str, Any]: # Specify the Python versions you support here. In particular, ensure # that you indicate whether you support Python 2, Python 3 or both. "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", diff --git a/src/pytorch_lightning/__setup__.py b/src/pytorch_lightning/__setup__.py index 6677b46..7eedace 100644 --- a/src/pytorch_lightning/__setup__.py +++ b/src/pytorch_lightning/__setup__.py @@ -3,7 +3,7 @@ from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path from types import ModuleType -from typing import Any +from typing import Any, Dict from pkg_resources import parse_requirements from setuptools import find_packages @@ -29,7 +29,7 @@ def _load_assistant() -> ModuleType: return _load_py_module("assistant", location) -def _prepare_extras() -> dict[str, Any]: +def _prepare_extras() -> Dict[str, Any]: assistant = _load_assistant() # https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-extras # Define package extras. These are only installed if you specify them. @@ -49,7 +49,7 @@ def _prepare_extras() -> dict[str, Any]: return extras -def _setup_args() -> dict[str, Any]: +def _setup_args() -> Dict[str, Any]: assistant = _load_assistant() about = _load_py_module("about", os.path.join(_PACKAGE_ROOT, "__about__.py")) version = _load_py_module("version", os.path.join(_PACKAGE_ROOT, "__version__.py")) @@ -80,7 +80,7 @@ def _setup_args() -> dict[str, Any]: "long_description_content_type": "text/markdown", "zip_safe": False, "keywords": ["deep learning", "pytorch", "AI"], - "python_requires": ">=3.9", + "python_requires": ">=3.8", "setup_requires": ["wheel"], # TODO: aggregate pytorch and lite requirements as we include its source code directly in this package. # this is not a problem yet because lite's base requirements are all included in pytorch's base requirements @@ -107,6 +107,7 @@ def _setup_args() -> dict[str, Any]: "Operating System :: OS Independent", # Specify the Python versions you support here. "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", diff --git a/tests/run_standalone_tests.sh b/tests/run_standalone_tests.sh index 9aa54f7..75a52e1 100755 --- a/tests/run_standalone_tests.sh +++ b/tests/run_standalone_tests.sh @@ -17,7 +17,7 @@ set -e # Batch size for testing: Determines how many standalone test invocations run in parallel # It can be set through the env variable PL_STANDALONE_TESTS_BATCH_SIZE and defaults to 6 if not set -test_batch_size="${PL_STANDALONE_TESTS_BATCH_SIZE:-6}" +test_batch_size="${PL_STANDALONE_TESTS_BATCH_SIZE:-3}" source="${PL_STANDALONE_TESTS_SOURCE:-"lightning"}" # this is the directory where the tests are located test_dir=$1 # parse the first argument diff --git a/tests/tests_fabric/accelerators/test_cuda.py b/tests/tests_fabric/accelerators/test_cuda.py index 0aed367..e323ada 100644 --- a/tests/tests_fabric/accelerators/test_cuda.py +++ b/tests/tests_fabric/accelerators/test_cuda.py @@ -121,32 +121,27 @@ def test_tf32_message(_, __, ___, caplog, monkeypatch): def test_find_usable_cuda_devices_error_handling(): """Test error handling for edge cases when using `find_usable_cuda_devices`.""" # Asking for GPUs if no GPUs visible - with ( - mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=0), - pytest.raises(ValueError, match="You requested to find 2 devices but there are no visible CUDA"), + with mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=0), pytest.raises( + ValueError, match="You requested to find 2 devices but there are no visible CUDA" ): find_usable_cuda_devices(2) # Asking for more GPUs than are visible - with ( - mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=1), - pytest.raises(ValueError, match="this machine only has 1 GPUs"), + with mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=1), pytest.raises( + ValueError, match="this machine only has 1 GPUs" ): find_usable_cuda_devices(2) # All GPUs are unusable tensor_mock = Mock(side_effect=RuntimeError) # simulate device placement fails - with ( - mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2), - mock.patch("lightning.fabric.accelerators.cuda.torch.tensor", tensor_mock), - pytest.raises(RuntimeError, match=escape("The devices [0, 1] are occupied by other processes")), - ): + with mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2), mock.patch( + "lightning.fabric.accelerators.cuda.torch.tensor", tensor_mock + ), pytest.raises(RuntimeError, match=escape("The devices [0, 1] are occupied by other processes")): find_usable_cuda_devices(2) # Request for as many GPUs as there are, no error should be raised - with ( - mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=5), - mock.patch("lightning.fabric.accelerators.cuda.torch.tensor"), + with mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=5), mock.patch( + "lightning.fabric.accelerators.cuda.torch.tensor" ): assert find_usable_cuda_devices(-1) == [0, 1, 2, 3, 4] diff --git a/tests/tests_fabric/accelerators/test_registry.py b/tests/tests_fabric/accelerators/test_registry.py index 2544df1..e8f39b6 100644 --- a/tests/tests_fabric/accelerators/test_registry.py +++ b/tests/tests_fabric/accelerators/test_registry.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Dict import torch from lightning.fabric.accelerators import ACCELERATOR_REGISTRY, Accelerator @@ -30,7 +30,7 @@ def __init__(self, param1, param2): def setup_device(self, device: torch.device) -> None: pass - def get_device_stats(self, device: torch.device) -> dict[str, Any]: + def get_device_stats(self, device: torch.device) -> Dict[str, Any]: pass def teardown(self) -> None: diff --git a/tests/tests_fabric/accelerators/test_xla.py b/tests/tests_fabric/accelerators/test_xla.py index 7a906c8..1af7d7e 100644 --- a/tests/tests_fabric/accelerators/test_xla.py +++ b/tests/tests_fabric/accelerators/test_xla.py @@ -44,8 +44,3 @@ def test_get_parallel_devices_raises(tpu_available): XLAAccelerator.get_parallel_devices(5) with pytest.raises(ValueError, match="Could not parse.*anything-else'"): XLAAccelerator.get_parallel_devices("anything-else") - - -@pytest.mark.skipif(not _XLA_AVAILABLE, reason="test requires torch_xla to be present") -def test_instantiate_xla_accelerator(): - _ = XLAAccelerator() diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py index 889decd..5fdc61a 100644 --- a/tests/tests_fabric/conftest.py +++ b/tests/tests_fabric/conftest.py @@ -14,8 +14,8 @@ import os import sys import threading -from concurrent.futures.process import _ExecutorManagerThread from pathlib import Path +from typing import List from unittest.mock import Mock import lightning.fabric @@ -25,6 +25,9 @@ from lightning.fabric.strategies.launchers.subprocess_script import _ChildProcessObserver from lightning.fabric.utilities.distributed import _destroy_dist_connection +if sys.version_info >= (3, 9): + from concurrent.futures.process import _ExecutorManagerThread + @pytest.fixture(autouse=True) def preserve_global_rank_variable(): @@ -66,6 +69,7 @@ def restore_env_variables(): "OMP_NUM_THREADS", # set by our launchers # set by torchdynamo "TRITON_CACHE_DIR", + "TORCHINDUCTOR_CACHE_DIR", } leaked_vars.difference_update(allowlist) assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}" @@ -197,7 +201,7 @@ def leave_no_artifacts_behind(): assert not difference, f"Test left artifacts behind: {difference}" -def pytest_collection_modifyitems(items: list[pytest.Function], config: pytest.Config) -> None: +def pytest_collection_modifyitems(items: List[pytest.Function], config: pytest.Config) -> None: """An adaptation of `tests/tests_pytorch/conftest.py::pytest_collection_modifyitems`""" initial_size = len(items) conditions = [] diff --git a/tests/tests_fabric/helpers/datasets.py b/tests/tests_fabric/helpers/datasets.py index ee14b21..211e1f3 100644 --- a/tests/tests_fabric/helpers/datasets.py +++ b/tests/tests_fabric/helpers/datasets.py @@ -1,4 +1,4 @@ -from collections.abc import Iterator +from typing import Iterator import torch from torch import Tensor diff --git a/tests/tests_fabric/plugins/collectives/test_torch_collective.py b/tests/tests_fabric/plugins/collectives/test_torch_collective.py index b4c223e..c8deb9d 100644 --- a/tests/tests_fabric/plugins/collectives/test_torch_collective.py +++ b/tests/tests_fabric/plugins/collectives/test_torch_collective.py @@ -29,16 +29,13 @@ @contextlib.contextmanager def check_destroy_group(): - with ( - mock.patch( - "lightning.fabric.plugins.collectives.torch_collective.TorchCollective.new_group", - wraps=TorchCollective.new_group, - ) as mock_new, - mock.patch( - "lightning.fabric.plugins.collectives.torch_collective.TorchCollective.destroy_group", - wraps=TorchCollective.destroy_group, - ) as mock_destroy, - ): + with mock.patch( + "lightning.fabric.plugins.collectives.torch_collective.TorchCollective.new_group", + wraps=TorchCollective.new_group, + ) as mock_new, mock.patch( + "lightning.fabric.plugins.collectives.torch_collective.TorchCollective.destroy_group", + wraps=TorchCollective.destroy_group, + ) as mock_destroy: yield # 0 to account for tests that mock distributed # -1 to account for destroying the default process group @@ -158,10 +155,9 @@ def test_repeated_create_and_destroy(): with pytest.raises(RuntimeError, match="TorchCollective` already owns a group"): collective.create_group() - with ( - mock.patch.dict("torch.distributed.distributed_c10d._pg_map", {collective.group: ("", None)}), - mock.patch("torch.distributed.destroy_process_group") as destroy_mock, - ): + with mock.patch.dict("torch.distributed.distributed_c10d._pg_map", {collective.group: ("", None)}), mock.patch( + "torch.distributed.destroy_process_group" + ) as destroy_mock: collective.teardown() # this would be called twice if `init_process_group` wasn't patched. once for the group and once for the default # group @@ -304,11 +300,9 @@ def test_collective_manages_default_group(): assert TorchCollective.manages_default_group - with ( - mock.patch.object(collective, "_group") as mock_group, - mock.patch.dict("torch.distributed.distributed_c10d._pg_map", {mock_group: ("", None)}), - mock.patch("torch.distributed.destroy_process_group") as destroy_mock, - ): + with mock.patch.object(collective, "_group") as mock_group, mock.patch.dict( + "torch.distributed.distributed_c10d._pg_map", {mock_group: ("", None)} + ), mock.patch("torch.distributed.destroy_process_group") as destroy_mock: collective.teardown() destroy_mock.assert_called_once_with(mock_group) diff --git a/tests/tests_fabric/plugins/environments/test_lsf.py b/tests/tests_fabric/plugins/environments/test_lsf.py index 4e60d96..b444f6f 100644 --- a/tests/tests_fabric/plugins/environments/test_lsf.py +++ b/tests/tests_fabric/plugins/environments/test_lsf.py @@ -41,9 +41,8 @@ def test_empty_lsb_djob_rankfile(): def test_missing_lsb_job_id(tmp_path): """Test an error when the job id cannot be found.""" - with ( - mock.patch.dict(os.environ, {"LSB_DJOB_RANKFILE": _make_rankfile(tmp_path)}), - pytest.raises(ValueError, match="Could not find job id in environment variable LSB_JOBID"), + with mock.patch.dict(os.environ, {"LSB_DJOB_RANKFILE": _make_rankfile(tmp_path)}), pytest.raises( + ValueError, match="Could not find job id in environment variable LSB_JOBID" ): LSFEnvironment() diff --git a/tests/tests_fabric/plugins/environments/test_slurm.py b/tests/tests_fabric/plugins/environments/test_slurm.py index f237478..73457ed 100644 --- a/tests/tests_fabric/plugins/environments/test_slurm.py +++ b/tests/tests_fabric/plugins/environments/test_slurm.py @@ -155,9 +155,8 @@ def test_srun_variable_validation(): """Test that we raise useful errors when `srun` variables are misconfigured.""" with mock.patch.dict(os.environ, {"SLURM_NTASKS": "1"}): SLURMEnvironment() - with ( - mock.patch.dict(os.environ, {"SLURM_NTASKS": "2"}), - pytest.raises(RuntimeError, match="You set `--ntasks=2` in your SLURM"), + with mock.patch.dict(os.environ, {"SLURM_NTASKS": "2"}), pytest.raises( + RuntimeError, match="You set `--ntasks=2` in your SLURM" ): SLURMEnvironment() diff --git a/tests/tests_fabric/strategies/launchers/test_multiprocessing.py b/tests/tests_fabric/strategies/launchers/test_multiprocessing.py index 6c595fb..b63d443 100644 --- a/tests/tests_fabric/strategies/launchers/test_multiprocessing.py +++ b/tests/tests_fabric/strategies/launchers/test_multiprocessing.py @@ -100,11 +100,8 @@ def test_check_for_bad_cuda_fork(mp_mock, _, start_method): def test_check_for_missing_main_guard(): launcher = _MultiProcessingLauncher(strategy=Mock(), start_method="spawn") - with ( - mock.patch( - "lightning.fabric.strategies.launchers.multiprocessing.mp.current_process", - return_value=Mock(_inheriting=True), # pretend that main is importing itself - ), - pytest.raises(RuntimeError, match="requires that your script guards the main"), - ): + with mock.patch( + "lightning.fabric.strategies.launchers.multiprocessing.mp.current_process", + return_value=Mock(_inheriting=True), # pretend that main is importing itself + ), pytest.raises(RuntimeError, match="requires that your script guards the main"): launcher.launch(function=Mock()) diff --git a/tests/tests_fabric/strategies/test_ddp.py b/tests/tests_fabric/strategies/test_ddp.py index b98d5f8..56d9875 100644 --- a/tests/tests_fabric/strategies/test_ddp.py +++ b/tests/tests_fabric/strategies/test_ddp.py @@ -58,12 +58,9 @@ def test_ddp_no_backward_sync(): strategy = DDPStrategy() assert isinstance(strategy._backward_sync_control, _DDPBackwardSyncControl) - with ( - pytest.raises( - TypeError, match="is only possible if the module passed to .* is wrapped in `DistributedDataParallel`" - ), - strategy._backward_sync_control.no_backward_sync(Mock(), True), - ): + with pytest.raises( + TypeError, match="is only possible if the module passed to .* is wrapped in `DistributedDataParallel`" + ), strategy._backward_sync_control.no_backward_sync(Mock(), True): pass module = MagicMock(spec=DistributedDataParallel) diff --git a/tests/tests_fabric/strategies/test_deepspeed_integration.py b/tests/tests_fabric/strategies/test_deepspeed_integration.py index 4811599..3be535e 100644 --- a/tests/tests_fabric/strategies/test_deepspeed_integration.py +++ b/tests/tests_fabric/strategies/test_deepspeed_integration.py @@ -404,11 +404,9 @@ def test_deepspeed_init_module_with_stages_1_2(stage, empty_init): fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy, precision="bf16-true") fabric.launch() - with ( - mock.patch("deepspeed.zero.Init") as zero_init_mock, - mock.patch("torch.Tensor.uniform_") as init_mock, - fabric.init_module(empty_init=empty_init), - ): + with mock.patch("deepspeed.zero.Init") as zero_init_mock, mock.patch( + "torch.Tensor.uniform_" + ) as init_mock, fabric.init_module(empty_init=empty_init): model = BoringModel() zero_init_mock.assert_called_with(enabled=False, remote_device=None, config_dict_or_path=ANY) diff --git a/tests/tests_fabric/strategies/test_fsdp.py b/tests/tests_fabric/strategies/test_fsdp.py index cb6542c..0c46e7a 100644 --- a/tests/tests_fabric/strategies/test_fsdp.py +++ b/tests/tests_fabric/strategies/test_fsdp.py @@ -133,12 +133,9 @@ def test_no_backward_sync(): strategy = FSDPStrategy() assert isinstance(strategy._backward_sync_control, _FSDPBackwardSyncControl) - with ( - pytest.raises( - TypeError, match="is only possible if the module passed to .* is wrapped in `FullyShardedDataParallel`" - ), - strategy._backward_sync_control.no_backward_sync(Mock(), True), - ): + with pytest.raises( + TypeError, match="is only possible if the module passed to .* is wrapped in `FullyShardedDataParallel`" + ), strategy._backward_sync_control.no_backward_sync(Mock(), True): pass module = MagicMock(spec=FullyShardedDataParallel) @@ -175,12 +172,9 @@ def __init__(self): assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy) strategy._parallel_devices = [torch.device("cuda", 0)] - with ( - mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock), - mock.patch( - "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing" - ) as apply_mock, - ): + with mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock), mock.patch( + "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing" + ) as apply_mock: wrapped = strategy.setup_module(Model()) apply_mock.assert_called_with(wrapped, checkpoint_wrapper_fn=ANY, **strategy._activation_checkpointing_kwargs) diff --git a/tests/tests_fabric/strategies/test_model_parallel_integration.py b/tests/tests_fabric/strategies/test_model_parallel_integration.py index dfbdb16..b04a29b 100644 --- a/tests/tests_fabric/strategies/test_model_parallel_integration.py +++ b/tests/tests_fabric/strategies/test_model_parallel_integration.py @@ -29,6 +29,13 @@ from tests_fabric.helpers.runif import RunIf +@pytest.fixture() +def distributed(): + yield + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + class FeedForward(nn.Module): def __init__(self): super().__init__() @@ -81,7 +88,7 @@ def _parallelize_feed_forward_fsdp2_tp(model, device_mesh): @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4) -def test_setup_device_mesh(): +def test_setup_device_mesh(distributed): from torch.distributed.device_mesh import DeviceMesh for dp_size, tp_size in ((1, 4), (4, 1), (2, 2)): @@ -116,11 +123,28 @@ def test_setup_device_mesh(): assert fabric.strategy.device_mesh.size(1) == 4 +def _parallelize_with_compile(parallelize): + def fn(model, device_mesh): + model = parallelize(model, device_mesh) + return torch.compile(model) + + return fn + + @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=2) -def test_tensor_parallel(): +@pytest.mark.parametrize( + "compile", + [True, False], +) +def test_tensor_parallel(distributed, compile): from torch.distributed._tensor import DTensor - strategy = ModelParallelStrategy(parallelize_fn=_parallelize_feed_forward_tp) + parallelize = _parallelize_feed_forward_tp + + if compile: + parallelize = _parallelize_with_compile(parallelize) + + strategy = ModelParallelStrategy(parallelize_fn=parallelize) fabric = Fabric(accelerator="auto", devices=2, strategy=strategy) fabric.launch() @@ -161,9 +185,18 @@ def test_tensor_parallel(): @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4) -def test_fsdp2_tensor_parallel(): +@pytest.mark.parametrize( + "compile", + [True, False], +) +def test_fsdp2_tensor_parallel(distributed, compile): from torch.distributed._tensor import DTensor + parallelize = _parallelize_feed_forward_fsdp2_tp + + if compile: + parallelize = _parallelize_with_compile(parallelize) + strategy = ModelParallelStrategy( parallelize_fn=_parallelize_feed_forward_fsdp2_tp, data_parallel_size=2, @@ -238,6 +271,7 @@ def _train(fabric, model=None, optimizer=None): @RunIf(min_torch="2.4", min_cuda_gpus=4, standalone=True) +@pytest.mark.filterwarnings("ignore::UserWarning") @pytest.mark.parametrize( "precision", [ @@ -245,7 +279,7 @@ def _train(fabric, model=None, optimizer=None): pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True)), ], ) -def test_train_save_load(precision, tmp_path): +def test_train_save_load(distributed, precision, tmp_path): """Test 2D-parallel training, saving and loading precision settings.""" strategy = ModelParallelStrategy( _parallelize_feed_forward_fsdp2_tp, @@ -303,7 +337,7 @@ def test_train_save_load(precision, tmp_path): @pytest.mark.filterwarnings("ignore::FutureWarning") @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True) -def test_save_full_state_dict(tmp_path): +def test_save_full_state_dict(distributed, tmp_path): """Test that ModelParallelStrategy saves the full state into a single file with `save_distributed_checkpoint=False`.""" from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict @@ -404,7 +438,7 @@ def test_save_full_state_dict(tmp_path): @pytest.mark.filterwarnings("ignore::FutureWarning") @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True) -def test_load_full_state_dict_into_sharded_model(tmp_path): +def test_load_full_state_dict_into_sharded_model(distributed, tmp_path): """Test that the strategy can load a full-state checkpoint into a distributed model.""" fabric = Fabric(accelerator="cuda", devices=1) fabric.seed_everything(0) @@ -450,7 +484,7 @@ def test_load_full_state_dict_into_sharded_model(tmp_path): @RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True) @pytest.mark.parametrize("move_to_device", [True, False]) @mock.patch("lightning.fabric.wrappers._FabricModule") -def test_setup_module_move_to_device(fabric_module_mock, move_to_device): +def test_setup_module_move_to_device(fabric_module_mock, move_to_device, distributed): """Test that `move_to_device` does nothing, ModelParallel decides which device parameters get moved to which device (sharding).""" from torch.distributed._tensor import DTensor @@ -482,7 +516,7 @@ def test_setup_module_move_to_device(fabric_module_mock, move_to_device): pytest.param("bf16-true", torch.bfloat16, marks=RunIf(bf16_cuda=True)), ], ) -def test_module_init_context(precision, expected_dtype): +def test_module_init_context(distributed, precision, expected_dtype): """Test that the module under the init-context gets moved to the right device and dtype.""" strategy = ModelParallelStrategy(parallelize_fn=_parallelize_feed_forward_fsdp2) fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy, precision=precision) @@ -505,7 +539,7 @@ def _run_setup_assertions(empty_init, expected_device): @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True) -def test_save_filter(tmp_path): +def test_save_filter(distributed, tmp_path): strategy = ModelParallelStrategy( parallelize_fn=_parallelize_feed_forward_fsdp2, save_distributed_checkpoint=False, @@ -558,7 +592,7 @@ def _parallelize_single_linear_tp_fsdp2(model, device_mesh): "val", ], ) -def test_clip_gradients(clip_type, precision): +def test_clip_gradients(distributed, clip_type, precision): strategy = ModelParallelStrategy(_parallelize_single_linear_tp_fsdp2) fabric = Fabric(accelerator="auto", devices=2, precision=precision, strategy=strategy) fabric.launch() @@ -600,7 +634,7 @@ def test_clip_gradients(clip_type, precision): @RunIf(min_torch="2.4", min_cuda_gpus=4, standalone=True) -def test_save_sharded_and_consolidate_and_load(tmp_path): +def test_save_sharded_and_consolidate_and_load(distributed, tmp_path): """Test the consolidation of a distributed (DTensor) checkpoint into a single file.""" strategy = ModelParallelStrategy( _parallelize_feed_forward_fsdp2_tp, @@ -657,7 +691,7 @@ def test_save_sharded_and_consolidate_and_load(tmp_path): @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True) -def test_load_raw_module_state(): +def test_load_raw_module_state(distributed): from torch.distributed.device_mesh import init_device_mesh from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module diff --git a/tests/tests_fabric/strategies/test_xla_fsdp.py b/tests/tests_fabric/strategies/test_xla_fsdp.py index 879a55c..e2864b6 100644 --- a/tests/tests_fabric/strategies/test_xla_fsdp.py +++ b/tests/tests_fabric/strategies/test_xla_fsdp.py @@ -48,12 +48,9 @@ def test_xla_fsdp_no_backward_sync(): strategy = XLAFSDPStrategy() assert isinstance(strategy._backward_sync_control, _XLAFSDPBackwardSyncControl) - with ( - pytest.raises( - TypeError, match="is only possible if the module passed to .* is wrapped in `XlaFullyShardedDataParallel`" - ), - strategy._backward_sync_control.no_backward_sync(object(), True), - ): + with pytest.raises( + TypeError, match="is only possible if the module passed to .* is wrapped in `XlaFullyShardedDataParallel`" + ), strategy._backward_sync_control.no_backward_sync(object(), True): pass module = MagicMock(spec=XlaFullyShardedDataParallel) diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py index 8a6e920..08d6dbb 100644 --- a/tests/tests_fabric/test_connector.py +++ b/tests/tests_fabric/test_connector.py @@ -15,7 +15,7 @@ import os import sys from contextlib import nullcontext -from typing import Any +from typing import Any, Dict from unittest import mock from unittest.mock import Mock @@ -165,7 +165,7 @@ class Accel(Accelerator): def setup_device(self, device: torch.device) -> None: pass - def get_device_stats(self, device: torch.device) -> dict[str, Any]: + def get_device_stats(self, device: torch.device) -> Dict[str, Any]: pass def teardown(self) -> None: @@ -960,33 +960,28 @@ def test_arguments_from_environment_collision(): with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"}): _Connector(accelerator="cuda") - with ( - mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu", "LT_CLI_USED": "1"}), - pytest.raises(ValueError, match="`Fabric\\(accelerator='cuda', ...\\)` but .* `--accelerator=cpu`"), + with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu", "LT_CLI_USED": "1"}), pytest.raises( + ValueError, match="`Fabric\\(accelerator='cuda', ...\\)` but .* `--accelerator=cpu`" ): _Connector(accelerator="cuda") - with ( - mock.patch.dict(os.environ, {"LT_STRATEGY": "ddp", "LT_CLI_USED": "1"}), - pytest.raises(ValueError, match="`Fabric\\(strategy='ddp_spawn', ...\\)` but .* `--strategy=ddp`"), + with mock.patch.dict(os.environ, {"LT_STRATEGY": "ddp", "LT_CLI_USED": "1"}), pytest.raises( + ValueError, match="`Fabric\\(strategy='ddp_spawn', ...\\)` but .* `--strategy=ddp`" ): _Connector(strategy="ddp_spawn") - with ( - mock.patch.dict(os.environ, {"LT_DEVICES": "2", "LT_CLI_USED": "1"}), - pytest.raises(ValueError, match="`Fabric\\(devices=3, ...\\)` but .* `--devices=2`"), + with mock.patch.dict(os.environ, {"LT_DEVICES": "2", "LT_CLI_USED": "1"}), pytest.raises( + ValueError, match="`Fabric\\(devices=3, ...\\)` but .* `--devices=2`" ): _Connector(devices=3) - with ( - mock.patch.dict(os.environ, {"LT_NUM_NODES": "3", "LT_CLI_USED": "1"}), - pytest.raises(ValueError, match="`Fabric\\(num_nodes=2, ...\\)` but .* `--num_nodes=3`"), + with mock.patch.dict(os.environ, {"LT_NUM_NODES": "3", "LT_CLI_USED": "1"}), pytest.raises( + ValueError, match="`Fabric\\(num_nodes=2, ...\\)` but .* `--num_nodes=3`" ): _Connector(num_nodes=2) - with ( - mock.patch.dict(os.environ, {"LT_PRECISION": "16-mixed", "LT_CLI_USED": "1"}), - pytest.raises(ValueError, match="`Fabric\\(precision='64-true', ...\\)` but .* `--precision=16-mixed`"), + with mock.patch.dict(os.environ, {"LT_PRECISION": "16-mixed", "LT_CLI_USED": "1"}), pytest.raises( + ValueError, match="`Fabric\\(precision='64-true', ...\\)` but .* `--precision=16-mixed`" ): _Connector(precision="64-true") diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py index 7bb6b29..70d04d5 100644 --- a/tests/tests_fabric/test_fabric.py +++ b/tests/tests_fabric/test_fabric.py @@ -746,10 +746,9 @@ def test_no_backward_sync(): # pretend that the strategy does not support skipping backward sync fabric._strategy = Mock(spec=ParallelStrategy, _backward_sync_control=None) - with ( - pytest.warns(PossibleUserWarning, match="The `ParallelStrategy` does not support skipping the"), - fabric.no_backward_sync(model), - ): + with pytest.warns( + PossibleUserWarning, match="The `ParallelStrategy` does not support skipping the" + ), fabric.no_backward_sync(model): pass # for single-device strategies, it becomes a no-op without warning diff --git a/tests/tests_fabric/utilities/test_consolidate_checkpoint.py b/tests/tests_fabric/utilities/test_consolidate_checkpoint.py index 2584aab..216b77e 100644 --- a/tests/tests_fabric/utilities/test_consolidate_checkpoint.py +++ b/tests/tests_fabric/utilities/test_consolidate_checkpoint.py @@ -41,9 +41,8 @@ def test_parse_cli_args(args, expected): def test_process_cli_args(tmp_path, caplog, monkeypatch): # PyTorch version < 2.3 monkeypatch.setattr(lightning.fabric.utilities.consolidate_checkpoint, "_TORCH_GREATER_EQUAL_2_3", False) - with ( - caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), - pytest.raises(SystemExit), + with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises( + SystemExit ): _process_cli_args(Namespace()) assert "requires PyTorch >= 2.3." in caplog.text @@ -52,9 +51,8 @@ def test_process_cli_args(tmp_path, caplog, monkeypatch): # Checkpoint does not exist checkpoint_folder = Path("does/not/exist") - with ( - caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), - pytest.raises(SystemExit), + with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises( + SystemExit ): _process_cli_args(Namespace(checkpoint_folder=checkpoint_folder)) assert f"checkpoint folder does not exist: {checkpoint_folder}" in caplog.text @@ -63,9 +61,8 @@ def test_process_cli_args(tmp_path, caplog, monkeypatch): # Checkpoint exists but is not a folder file = tmp_path / "checkpoint_file" file.touch() - with ( - caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), - pytest.raises(SystemExit), + with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises( + SystemExit ): _process_cli_args(Namespace(checkpoint_folder=file)) assert "checkpoint path must be a folder" in caplog.text @@ -74,9 +71,8 @@ def test_process_cli_args(tmp_path, caplog, monkeypatch): # Checkpoint exists but is not an FSDP checkpoint folder = tmp_path / "checkpoint_folder" folder.mkdir() - with ( - caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), - pytest.raises(SystemExit), + with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises( + SystemExit ): _process_cli_args(Namespace(checkpoint_folder=folder)) assert "Only FSDP-sharded checkpoints saved with Lightning are supported" in caplog.text @@ -93,9 +89,8 @@ def test_process_cli_args(tmp_path, caplog, monkeypatch): # Checkpoint is a FSDP folder, output file already exists file = tmp_path / "ouput_file" file.touch() - with ( - caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), - pytest.raises(SystemExit), + with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises( + SystemExit ): _process_cli_args(Namespace(checkpoint_folder=folder, output_file=file)) assert "path for the converted checkpoint already exists" in caplog.text diff --git a/tests/tests_fabric/utilities/test_distributed.py b/tests/tests_fabric/utilities/test_distributed.py index f5a78a1..cc6c23b 100644 --- a/tests/tests_fabric/utilities/test_distributed.py +++ b/tests/tests_fabric/utilities/test_distributed.py @@ -215,10 +215,9 @@ def test_infinite_barrier(): # distributed available barrier = _InfiniteBarrier() - with ( - mock.patch("lightning.fabric.utilities.distributed._distributed_is_initialized", return_value=True), - mock.patch("lightning.fabric.utilities.distributed.torch.distributed") as dist_mock, - ): + with mock.patch( + "lightning.fabric.utilities.distributed._distributed_is_initialized", return_value=True + ), mock.patch("lightning.fabric.utilities.distributed.torch.distributed") as dist_mock: barrier.__enter__() dist_mock.new_group.assert_called_once() assert barrier.barrier == barrier.group.monitored_barrier diff --git a/tests/tests_fabric/utilities/test_throughput.py b/tests/tests_fabric/utilities/test_throughput.py index d410d07..eefadb2 100644 --- a/tests/tests_fabric/utilities/test_throughput.py +++ b/tests/tests_fabric/utilities/test_throughput.py @@ -39,9 +39,8 @@ def test_get_available_flops(xla_available): with pytest.warns(match="not found for 'CocoNut"), mock.patch("torch.cuda.get_device_name", return_value="CocoNut"): assert get_available_flops(torch.device("cuda"), torch.bfloat16) is None - with ( - pytest.warns(match="t4' does not support torch.bfloat"), - mock.patch("torch.cuda.get_device_name", return_value="t4"), + with pytest.warns(match="t4' does not support torch.bfloat"), mock.patch( + "torch.cuda.get_device_name", return_value="t4" ): assert get_available_flops(torch.device("cuda"), torch.bfloat16) is None diff --git a/tests/tests_pytorch/accelerators/test_common.py b/tests/tests_pytorch/accelerators/test_common.py index 6967bff..7654125 100644 --- a/tests/tests_pytorch/accelerators/test_common.py +++ b/tests/tests_pytorch/accelerators/test_common.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Dict import torch from lightning.pytorch import Trainer @@ -24,7 +24,7 @@ class TestAccelerator(Accelerator): def setup_device(self, device: torch.device) -> None: pass - def get_device_stats(self, device: torch.device) -> dict[str, Any]: + def get_device_stats(self, device: torch.device) -> Dict[str, Any]: pass def teardown(self) -> None: diff --git a/tests/tests_pytorch/accelerators/test_cpu.py b/tests/tests_pytorch/accelerators/test_cpu.py index 8445560..cd34fe3 100644 --- a/tests/tests_pytorch/accelerators/test_cpu.py +++ b/tests/tests_pytorch/accelerators/test_cpu.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Any, Union +from typing import Any, Dict, Union from unittest.mock import Mock import lightning.pytorch as pl @@ -53,7 +53,7 @@ def setup(self, trainer: "pl.Trainer") -> None: def restore_checkpoint_after_setup(self) -> bool: return restore_after_pre_setup - def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> dict[str, Any]: + def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: assert self.setup_called == restore_after_pre_setup return super().load_checkpoint(checkpoint_path) diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py index 89c1eff..b8d3d6d 100644 --- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py @@ -141,12 +141,9 @@ def on_train_start(self) -> None: model = TestModel() - with ( - mock.patch( - "lightning.pytorch.callbacks.progress.rich_progress.Progress.stop", autospec=True - ) as mock_progress_stop, - pytest.raises(SystemExit), - ): + with mock.patch( + "lightning.pytorch.callbacks.progress.rich_progress.Progress.stop", autospec=True + ) as mock_progress_stop, pytest.raises(SystemExit): progress_bar = RichProgressBar() trainer = Trainer( default_root_dir=tmp_path, diff --git a/tests/tests_pytorch/callbacks/test_device_stats_monitor.py b/tests/tests_pytorch/callbacks/test_device_stats_monitor.py index f1d999f..aacee95 100644 --- a/tests/tests_pytorch/callbacks/test_device_stats_monitor.py +++ b/tests/tests_pytorch/callbacks/test_device_stats_monitor.py @@ -14,7 +14,7 @@ import csv import os import re -from typing import Optional +from typing import Dict, Optional from unittest import mock from unittest.mock import Mock @@ -40,7 +40,7 @@ def test_device_stats_gpu_from_torch(tmp_path): class DebugLogger(CSVLogger): @rank_zero_only - def log_metrics(self, metrics: dict[str, float], step: Optional[int] = None) -> None: + def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: fields = [ "allocated_bytes.all.freed", "inactive_split.all.peak", @@ -74,7 +74,7 @@ def test_device_stats_cpu(cpu_stats_mock, tmp_path, cpu_stats): CPU_METRIC_KEYS = (_CPU_VM_PERCENT, _CPU_SWAP_PERCENT, _CPU_PERCENT) class DebugLogger(CSVLogger): - def log_metrics(self, metrics: dict[str, float], step: Optional[int] = None) -> None: + def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: enabled = cpu_stats is not False for f in CPU_METRIC_KEYS: has_cpu_metrics = any(f in h for h in metrics) diff --git a/tests/tests_pytorch/callbacks/test_early_stopping.py b/tests/tests_pytorch/callbacks/test_early_stopping.py index a3d56bb..75f331a 100644 --- a/tests/tests_pytorch/callbacks/test_early_stopping.py +++ b/tests/tests_pytorch/callbacks/test_early_stopping.py @@ -15,7 +15,7 @@ import math import os import pickle -from typing import Optional +from typing import List, Optional from unittest import mock from unittest.mock import Mock @@ -407,7 +407,7 @@ def on_train_end(self) -> None: ) def test_multiple_early_stopping_callbacks( tmp_path, - callbacks: list[EarlyStopping], + callbacks: List[EarlyStopping], expected_stop_epoch: int, check_on_train_epoch_end: bool, strategy: str, diff --git a/tests/tests_pytorch/callbacks/test_model_summary.py b/tests/tests_pytorch/callbacks/test_model_summary.py index 215176e..b42907d 100644 --- a/tests/tests_pytorch/callbacks/test_model_summary.py +++ b/tests/tests_pytorch/callbacks/test_model_summary.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, List, Tuple from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelSummary @@ -45,7 +45,7 @@ def test_custom_model_summary_callback_summarize(tmp_path): class CustomModelSummary(ModelSummary): @staticmethod def summarize( - summary_data: list[tuple[str, list[str]]], + summary_data: List[Tuple[str, List[str]]], total_parameters: int, trainable_parameters: int, model_size: float, diff --git a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py index 8d3a180..d57ac76 100644 --- a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py +++ b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py @@ -13,9 +13,8 @@ # limitations under the License. import logging import os -from contextlib import AbstractContextManager from pathlib import Path -from typing import Optional +from typing import ContextManager, Optional from unittest import mock import pytest @@ -383,5 +382,5 @@ def test_misconfiguration_error_with_sharded_model(tmp_path, strategy: str): trainer.fit(model) -def _backward_patch(trainer: Trainer) -> AbstractContextManager: +def _backward_patch(trainer: Trainer) -> ContextManager: return mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward) diff --git a/tests/tests_pytorch/callbacks/test_throughput_monitor.py b/tests/tests_pytorch/callbacks/test_throughput_monitor.py index 4867134..a74efba 100644 --- a/tests/tests_pytorch/callbacks/test_throughput_monitor.py +++ b/tests/tests_pytorch/callbacks/test_throughput_monitor.py @@ -43,9 +43,8 @@ def test_throughput_monitor_fit(tmp_path): ) # these timing results are meant to precisely match the `test_throughput_monitor` test in fabric timings = [0.0] + [0.5 + i for i in range(1, 6)] - with ( - mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100), - mock.patch("time.perf_counter", side_effect=timings), + with mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100), mock.patch( + "time.perf_counter", side_effect=timings ): trainer.fit(model) @@ -180,9 +179,8 @@ def test_throughput_monitor_fit_gradient_accumulation(log_every_n_steps, tmp_pat enable_progress_bar=False, ) timings = [0.0] + [0.5 + i for i in range(1, 11)] - with ( - mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100), - mock.patch("time.perf_counter", side_effect=timings), + with mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100), mock.patch( + "time.perf_counter", side_effect=timings ): trainer.fit(model) diff --git a/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py b/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py index c07400e..3ae7d6b 100644 --- a/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py @@ -92,10 +92,9 @@ def test_trainer_save_checkpoint_storage_options(tmp_path, xla_available): io_mock.assert_called_with(ANY, instance_path, storage_options=None) checkpoint_mock = Mock() - with ( - mock.patch.object(trainer.strategy, "save_checkpoint") as save_mock, - mock.patch.object(trainer._checkpoint_connector, "dump_checkpoint", return_value=checkpoint_mock) as dump_mock, - ): + with mock.patch.object(trainer.strategy, "save_checkpoint") as save_mock, mock.patch.object( + trainer._checkpoint_connector, "dump_checkpoint", return_value=checkpoint_mock + ) as dump_mock: trainer.save_checkpoint(instance_path, True) dump_mock.assert_called_with(True) save_mock.assert_called_with(checkpoint_mock, instance_path, storage_options=None) diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py index ea52075..78e81c7 100644 --- a/tests/tests_pytorch/conftest.py +++ b/tests/tests_pytorch/conftest.py @@ -15,10 +15,10 @@ import signal import sys import threading -from concurrent.futures.process import _ExecutorManagerThread from functools import partial from http.server import SimpleHTTPRequestHandler from pathlib import Path +from typing import List from unittest.mock import Mock import lightning.fabric @@ -35,6 +35,9 @@ from tests_pytorch import _PATH_DATASETS +if sys.version_info >= (3, 9): + from concurrent.futures.process import _ExecutorManagerThread + @pytest.fixture(scope="session") def datadir(): @@ -320,7 +323,7 @@ def leave_no_artifacts_behind(): assert not difference, f"Test left artifacts behind: {difference}" -def pytest_collection_modifyitems(items: list[pytest.Function], config: pytest.Config) -> None: +def pytest_collection_modifyitems(items: List[pytest.Function], config: pytest.Config) -> None: initial_size = len(items) conditions = [] filtered, skipped = 0, 0 diff --git a/tests/tests_pytorch/core/test_datamodules.py b/tests/tests_pytorch/core/test_datamodules.py index 5f46815..65fccb6 100644 --- a/tests/tests_pytorch/core/test_datamodules.py +++ b/tests/tests_pytorch/core/test_datamodules.py @@ -14,7 +14,7 @@ import pickle from argparse import Namespace from dataclasses import dataclass -from typing import Any +from typing import Any, Dict from unittest import mock from unittest.mock import Mock, PropertyMock, call @@ -187,10 +187,10 @@ def validation_step(self, batch, batch_idx): return out class CustomBoringDataModule(BoringDataModule): - def state_dict(self) -> dict[str, Any]: + def state_dict(self) -> Dict[str, Any]: return {"my": "state_dict"} - def load_state_dict(self, state_dict: dict[str, Any]) -> None: + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self.my_state_dict = state_dict dm = CustomBoringDataModule() diff --git a/tests/tests_pytorch/core/test_lightning_optimizer.py b/tests/tests_pytorch/core/test_lightning_optimizer.py index b25b7ae..8ab6eca 100644 --- a/tests/tests_pytorch/core/test_lightning_optimizer.py +++ b/tests/tests_pytorch/core/test_lightning_optimizer.py @@ -110,10 +110,9 @@ def configure_optimizers(self): default_root_dir=tmp_path, limit_train_batches=8, limit_val_batches=1, max_epochs=1, enable_model_summary=False ) - with ( - patch.multiple(torch.optim.SGD, zero_grad=DEFAULT, step=DEFAULT) as sgd, - patch.multiple(torch.optim.Adam, zero_grad=DEFAULT, step=DEFAULT) as adam, - ): + with patch.multiple(torch.optim.SGD, zero_grad=DEFAULT, step=DEFAULT) as sgd, patch.multiple( + torch.optim.Adam, zero_grad=DEFAULT, step=DEFAULT + ) as adam: trainer.fit(model) assert sgd["step"].call_count == 4 diff --git a/tests/tests_pytorch/core/test_metric_result_integration.py b/tests/tests_pytorch/core/test_metric_result_integration.py index dcb3f71..004d979 100644 --- a/tests/tests_pytorch/core/test_metric_result_integration.py +++ b/tests/tests_pytorch/core/test_metric_result_integration.py @@ -625,9 +625,8 @@ def test_logger_sync_dist(distributed_env, log_val): else nullcontext() ) - with ( - warning_ctx(PossibleUserWarning, match=r"recommended to use `self.log\('bar', ..., sync_dist=True\)`"), - patch_ctx, - ): + with warning_ctx( + PossibleUserWarning, match=r"recommended to use `self.log\('bar', ..., sync_dist=True\)`" + ), patch_ctx: value = _ResultCollection._get_cache(result_metric, on_step=False) assert value == 0.5 diff --git a/tests/tests_pytorch/helpers/datasets.py b/tests/tests_pytorch/helpers/datasets.py index 014fb37..9b1d4ec 100644 --- a/tests/tests_pytorch/helpers/datasets.py +++ b/tests/tests_pytorch/helpers/datasets.py @@ -16,8 +16,7 @@ import random import time import urllib.request -from collections.abc import Sequence -from typing import Optional +from typing import Optional, Sequence, Tuple import torch from torch import Tensor @@ -64,7 +63,7 @@ def __init__( data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME self.data, self.targets = self._try_load(os.path.join(self.cached_folder_path, data_file)) - def __getitem__(self, idx: int) -> tuple[Tensor, int]: + def __getitem__(self, idx: int) -> Tuple[Tensor, int]: img = self.data[idx].float().unsqueeze(0) target = int(self.targets[idx]) diff --git a/tests/tests_pytorch/loggers/test_logger.py b/tests/tests_pytorch/loggers/test_logger.py index dcdd504..4d74e04 100644 --- a/tests/tests_pytorch/loggers/test_logger.py +++ b/tests/tests_pytorch/loggers/test_logger.py @@ -14,7 +14,7 @@ import pickle from argparse import Namespace from copy import deepcopy -from typing import Any, Optional +from typing import Any, Dict, Optional from unittest.mock import patch import numpy as np @@ -252,12 +252,12 @@ def __init__(self, param_one, param_two): @patch("lightning.pytorch.loggers.tensorboard.TensorBoardLogger.log_hyperparams") def test_log_hyperparams_key_collision(_, tmp_path): class TestModel(BoringModel): - def __init__(self, hparams: dict[str, Any]) -> None: + def __init__(self, hparams: Dict[str, Any]) -> None: super().__init__() self.save_hyperparameters(hparams) class TestDataModule(BoringDataModule): - def __init__(self, hparams: dict[str, Any]) -> None: + def __init__(self, hparams: Dict[str, Any]) -> None: super().__init__() self.save_hyperparameters(hparams) diff --git a/tests/tests_pytorch/loops/optimization/test_automatic_loop.py b/tests/tests_pytorch/loops/optimization/test_automatic_loop.py index 2fb04d0..0ea6290 100644 --- a/tests/tests_pytorch/loops/optimization/test_automatic_loop.py +++ b/tests/tests_pytorch/loops/optimization/test_automatic_loop.py @@ -11,9 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Iterator, Mapping from contextlib import nullcontext -from typing import Generic, TypeVar +from typing import Dict, Generic, Iterator, Mapping, TypeVar import pytest import torch @@ -50,8 +49,8 @@ def test_closure_result_apply_accumulation(): class OutputMapping(Generic[T], Mapping[str, T]): - def __init__(self, d: dict[str, T]) -> None: - self.d: dict[str, T] = d + def __init__(self, d: Dict[str, T]) -> None: + self.d: Dict[str, T] = d def __iter__(self) -> Iterator[str]: return iter(self.d) diff --git a/tests/tests_pytorch/loops/test_fetchers.py b/tests/tests_pytorch/loops/test_fetchers.py index 75b25e3..763a6de 100644 --- a/tests/tests_pytorch/loops/test_fetchers.py +++ b/tests/tests_pytorch/loops/test_fetchers.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import Counter -from collections.abc import Iterator -from typing import Any +from typing import Any, Iterator import pytest import torch diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index 1820ca3..8d94275 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -12,10 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from collections.abc import Iterator from copy import deepcopy from dataclasses import dataclass -from typing import Any +from typing import Any, Dict, Iterator from unittest.mock import ANY, Mock import pytest @@ -88,10 +87,10 @@ def advance(self) -> None: self.outputs.append(value) - def state_dict(self) -> dict: + def state_dict(self) -> Dict: return {"iteration_count": self.iteration_count, "outputs": self.outputs} - def load_state_dict(self, state_dict: dict) -> None: + def load_state_dict(self, state_dict: Dict) -> None: self.iteration_count = state_dict["iteration_count"] self.outputs = state_dict["outputs"] @@ -141,10 +140,10 @@ def advance(self) -> None: return loop.run() - def on_save_checkpoint(self) -> dict: + def on_save_checkpoint(self) -> Dict: return {"a": self.a} - def on_load_checkpoint(self, state_dict: dict) -> None: + def on_load_checkpoint(self, state_dict: Dict) -> None: self.a = state_dict["a"] trainer = Trainer() diff --git a/tests/tests_pytorch/models/test_restore.py b/tests/tests_pytorch/models/test_restore.py index 64f70b1..fe7e3fb 100644 --- a/tests/tests_pytorch/models/test_restore.py +++ b/tests/tests_pytorch/models/test_restore.py @@ -15,9 +15,8 @@ import logging as log import os import pickle -from collections.abc import Mapping from copy import deepcopy -from typing import Generic, TypeVar +from typing import Generic, Mapping, TypeVar import cloudpickle import pytest diff --git a/tests/tests_pytorch/overrides/test_distributed.py b/tests/tests_pytorch/overrides/test_distributed.py index 3e2fba5..29eb6d6 100644 --- a/tests/tests_pytorch/overrides/test_distributed.py +++ b/tests/tests_pytorch/overrides/test_distributed.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Iterable +from typing import Iterable import pytest import torch diff --git a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py index 58baa47..185a767 100644 --- a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py +++ b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py @@ -13,7 +13,7 @@ # limitations under the License. import os from pathlib import Path -from typing import Any, Optional +from typing import Any, Dict, Optional from unittest.mock import MagicMock, Mock import torch @@ -27,10 +27,10 @@ class CustomCheckpointIO(CheckpointIO): - def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: + def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: torch.save(checkpoint, path) - def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> dict[str, Any]: + def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> Dict[str, Any]: return torch.load(path, weights_only=True) def remove_checkpoint(self, path: _PATH) -> None: diff --git a/tests/tests_pytorch/serve/test_servable_module_validator.py b/tests/tests_pytorch/serve/test_servable_module_validator.py index ec4dd88..7c883c4 100644 --- a/tests/tests_pytorch/serve/test_servable_module_validator.py +++ b/tests/tests_pytorch/serve/test_servable_module_validator.py @@ -1,3 +1,5 @@ +from typing import Dict + import pytest import torch from lightning.pytorch import Trainer @@ -19,7 +21,7 @@ def serialize(x): return {"x": deserialize}, {"output": serialize} - def serve_step(self, x: Tensor) -> dict[str, Tensor]: + def serve_step(self, x: Tensor) -> Dict[str, Tensor]: assert torch.equal(x, torch.arange(32, dtype=torch.float)) return {"output": torch.tensor([0, 1])} diff --git a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py index 394d827..b0462c0 100644 --- a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py +++ b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py @@ -209,13 +209,10 @@ def test_memory_sharing_disabled(tmp_path): def test_check_for_missing_main_guard(): launcher = _MultiProcessingLauncher(strategy=Mock(), start_method="spawn") - with ( - mock.patch( - "lightning.pytorch.strategies.launchers.multiprocessing.mp.current_process", - return_value=Mock(_inheriting=True), # pretend that main is importing itself - ), - pytest.raises(RuntimeError, match="requires that your script guards the main"), - ): + with mock.patch( + "lightning.pytorch.strategies.launchers.multiprocessing.mp.current_process", + return_value=Mock(_inheriting=True), # pretend that main is importing itself + ), pytest.raises(RuntimeError, match="requires that your script guards the main"): launcher.launch(function=Mock()) diff --git a/tests/tests_pytorch/strategies/test_custom_strategy.py b/tests/tests_pytorch/strategies/test_custom_strategy.py index 347dacb..7f7d018 100644 --- a/tests/tests_pytorch/strategies/test_custom_strategy.py +++ b/tests/tests_pytorch/strategies/test_custom_strategy.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from collections.abc import Mapping -from typing import Any +from typing import Any, Mapping import pytest import torch diff --git a/tests/tests_pytorch/strategies/test_deepspeed.py b/tests/tests_pytorch/strategies/test_deepspeed.py index 73697ea..be9428f 100644 --- a/tests/tests_pytorch/strategies/test_deepspeed.py +++ b/tests/tests_pytorch/strategies/test_deepspeed.py @@ -15,7 +15,7 @@ import json import os from re import escape -from typing import Any +from typing import Any, Dict from unittest import mock from unittest.mock import ANY, Mock @@ -48,7 +48,7 @@ def configure_model(self) -> None: if self.layer is None: self.layer = torch.nn.Linear(32, 2) - def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None: + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: self.configure_model() @@ -73,7 +73,7 @@ def configure_model(self) -> None: if self.layer is None: self.layer = torch.nn.Linear(32, 2) - def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None: + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: self.configure_model() @property @@ -623,7 +623,7 @@ def configure_optimizers(self): lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99) return [optimizer], [{"scheduler": lr_scheduler, "interval": "step"}] - def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None: + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: if not hasattr(self, "model"): self.configure_model() diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index 2aee68f..aec01b8 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -444,12 +444,9 @@ def __init__(self): strategy._parallel_devices = [torch.device("cuda", 0)] strategy._lightning_module = model strategy._process_group = Mock() - with ( - mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock), - mock.patch( - "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing" - ) as apply_mock, - ): + with mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock), mock.patch( + "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing" + ) as apply_mock: wrapped = strategy._setup_model(model) apply_mock.assert_called_with(wrapped, checkpoint_wrapper_fn=ANY, **strategy._activation_checkpointing_kwargs) diff --git a/tests/tests_pytorch/strategies/test_model_parallel_integration.py b/tests/tests_pytorch/strategies/test_model_parallel_integration.py index 57d2739..9dcbcc8 100644 --- a/tests/tests_pytorch/strategies/test_model_parallel_integration.py +++ b/tests/tests_pytorch/strategies/test_model_parallel_integration.py @@ -78,10 +78,26 @@ def _parallelize_feed_forward_fsdp2_tp(model, device_mesh): return model +def _parallelize_with_compile(parallelize): + def fn(model, device_mesh): + model = parallelize(model, device_mesh) + return torch.compile(model) + + return fn + + +@pytest.fixture() +def distributed(): + yield + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + class TemplateModel(LightningModule): - def __init__(self): + def __init__(self, compile=False): super().__init__() self.model = FeedForward() + self._compile = compile def training_step(self, batch): output = self.model(batch) @@ -98,21 +114,30 @@ def configure_optimizers(self): class FSDP2Model(TemplateModel): def configure_model(self): - _parallelize_feed_forward_fsdp2(self.model, device_mesh=self.device_mesh) + parallelize = _parallelize_feed_forward_fsdp2_tp + if self._compile: + parallelize = _parallelize_with_compile(parallelize) + parallelize(self.model, device_mesh=self.device_mesh) class TensorParallelModel(TemplateModel): def configure_model(self): - _parallelize_feed_forward_tp(self.model, device_mesh=self.device_mesh) + parallelize = _parallelize_feed_forward_tp + if self._compile: + parallelize = _parallelize_with_compile(parallelize) + parallelize(self.model, device_mesh=self.device_mesh) class FSDP2TensorParallelModel(TemplateModel): def configure_model(self): - _parallelize_feed_forward_fsdp2_tp(self.model, device_mesh=self.device_mesh) + parallelize = _parallelize_feed_forward_fsdp2_tp + if self._compile: + parallelize = _parallelize_with_compile(parallelize) + parallelize(self.model, device_mesh=self.device_mesh) @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4) -def test_setup_device_mesh(): +def test_setup_device_mesh(distributed): from torch.distributed.device_mesh import DeviceMesh for dp_size, tp_size in ((1, 4), (4, 1), (2, 2)): @@ -169,7 +194,11 @@ def configure_model(self): @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=2) -def test_tensor_parallel(): +@pytest.mark.parametrize( + "compile", + [True, False], +) +def test_tensor_parallel(distributed, compile): from torch.distributed._tensor import DTensor class Model(TensorParallelModel): @@ -204,13 +233,17 @@ def training_step(self, batch): seed_everything(0) with trainer.init_module(empty_init=True): - model = Model() + model = Model(compile=compile) trainer.fit(model) @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4) -def test_fsdp2_tensor_parallel(): +@pytest.mark.parametrize( + "compile", + [True, False], +) +def test_fsdp2_tensor_parallel(distributed, compile): from torch.distributed._tensor import DTensor class Model(FSDP2TensorParallelModel): @@ -261,13 +294,13 @@ def training_step(self, batch): seed_everything(0) with trainer.init_module(empty_init=True): - model = Model() + model = Model(compile=compile) trainer.fit(model) @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True) -def test_modules_without_parameters(tmp_path): +def test_modules_without_parameters(distributed, tmp_path): """Test that TorchMetrics get moved to the device despite not having any parameters.""" class MetricsModel(TensorParallelModel): @@ -306,7 +339,11 @@ def training_step(self, batch): pytest.param("bf16-true", torch.bfloat16, marks=RunIf(bf16_cuda=True)), ], ) -def test_module_init_context(precision, expected_dtype, tmp_path): +@pytest.mark.parametrize( + "compile", + [True, False], +) +def test_module_init_context(distributed, compile, precision, expected_dtype, tmp_path): """Test that the module under the init-context gets moved to the right device and dtype.""" class Model(FSDP2Model): @@ -329,7 +366,7 @@ def _run_setup_assertions(empty_init, expected_device): logger=False, ) with trainer.init_module(empty_init=empty_init): - model = Model() + model = Model(compile=compile) # The model is on the CPU/meta-device until after `ModelParallelStrategy.setup()` assert model.model.w1.weight.device == expected_device @@ -345,7 +382,7 @@ def _run_setup_assertions(empty_init, expected_device): @RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True) @pytest.mark.parametrize("save_distributed_checkpoint", [True, False]) -def test_strategy_state_dict(tmp_path, save_distributed_checkpoint): +def test_strategy_state_dict(distributed, tmp_path, save_distributed_checkpoint): """Test that the strategy returns the correct state dict of the LightningModule.""" model = FSDP2Model() correct_state_dict = model.state_dict() # State dict before wrapping @@ -378,7 +415,7 @@ def test_strategy_state_dict(tmp_path, save_distributed_checkpoint): @RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True) -def test_load_full_state_checkpoint_into_regular_model(tmp_path): +def test_load_full_state_checkpoint_into_regular_model(distributed, tmp_path): """Test that a full-state checkpoint saved from a distributed model can be loaded back into a regular model.""" # Save a regular full-state checkpoint from a distributed model @@ -420,7 +457,7 @@ def test_load_full_state_checkpoint_into_regular_model(tmp_path): @pytest.mark.filterwarnings("ignore::FutureWarning") @RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True) -def test_load_standard_checkpoint_into_distributed_model(tmp_path): +def test_load_standard_checkpoint_into_distributed_model(distributed, tmp_path): """Test that a regular checkpoint (weights and optimizer states) can be loaded into a distributed model.""" # Save a regular DDP checkpoint @@ -461,7 +498,7 @@ def test_load_standard_checkpoint_into_distributed_model(tmp_path): @pytest.mark.filterwarnings("ignore::FutureWarning") @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True) -def test_save_load_sharded_state_dict(tmp_path): +def test_save_load_sharded_state_dict(distributed, tmp_path): """Test saving and loading with the distributed state dict format.""" class CheckpointModel(FSDP2Model): diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index d106a05..cdec778 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -20,7 +20,7 @@ from contextlib import ExitStack, contextmanager, redirect_stdout from io import StringIO from pathlib import Path -from typing import Callable, Optional, Union +from typing import Callable, List, Optional, Union from unittest import mock from unittest.mock import ANY @@ -127,7 +127,7 @@ def _model_builder(model_param: int) -> Model: def _trainer_builder( - limit_train_batches: int, fast_dev_run: bool = False, callbacks: Optional[Union[list[Callback], Callback]] = None + limit_train_batches: int, fast_dev_run: bool = False, callbacks: Optional[Union[List[Callback], Callback]] = None ) -> Trainer: return Trainer(limit_train_batches=limit_train_batches, fast_dev_run=fast_dev_run, callbacks=callbacks) @@ -409,9 +409,8 @@ def test_lightning_cli_config_and_subclass_mode(cleandir): with open(config_path, "w") as f: f.write(yaml.dump(input_config)) - with ( - mock.patch("sys.argv", ["any.py", "--config", config_path]), - mock_subclasses(LightningDataModule, DataDirDataModule), + with mock.patch("sys.argv", ["any.py", "--config", config_path]), mock_subclasses( + LightningDataModule, DataDirDataModule ): cli = LightningCLI( BoringModel, @@ -462,12 +461,9 @@ def test_lightning_cli_help(): cli_args = ["any.py", "fit", "--data.help=DataDirDataModule"] out = StringIO() - with ( - mock.patch("sys.argv", cli_args), - redirect_stdout(out), - mock_subclasses(LightningDataModule, DataDirDataModule), - pytest.raises(SystemExit), - ): + with mock.patch("sys.argv", cli_args), redirect_stdout(out), mock_subclasses( + LightningDataModule, DataDirDataModule + ), pytest.raises(SystemExit): any_model_any_data_cli() assert ("--data.data_dir" in out.getvalue()) or ("--data.init_args.data_dir" in out.getvalue()) @@ -478,8 +474,8 @@ def test_lightning_cli_print_config(): "any.py", "predict", "--seed_everything=1234", - "--model=lightning.pytorch.demos.BoringModel", - "--data=lightning.pytorch.demos.BoringDataModule", + "--model=lightning.pytorch.demos.boring_classes.BoringModel", + "--data=lightning.pytorch.demos.boring_classes.BoringDataModule", "--print_config", ] out = StringIO() @@ -492,8 +488,8 @@ def test_lightning_cli_print_config(): outval = yaml.safe_load(text) assert outval["seed_everything"] == 1234 - assert outval["model"]["class_path"] == "lightning.pytorch.demos.BoringModel" - assert outval["data"]["class_path"] == "lightning.pytorch.demos.BoringDataModule" + assert outval["model"]["class_path"] == "lightning.pytorch.demos.boring_classes.BoringModel" + assert outval["data"]["class_path"] == "lightning.pytorch.demos.boring_classes.BoringDataModule" assert outval["ckpt_path"] is None @@ -526,7 +522,7 @@ def __init__(self, submodule1: LightningModule, submodule2: LightningModule, mai @pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason=str(_TORCHVISION_AVAILABLE)) def test_lightning_cli_torch_modules(cleandir): class TestModule(BoringModel): - def __init__(self, activation: torch.nn.Module = None, transform: Optional[list[torch.nn.Module]] = None): + def __init__(self, activation: torch.nn.Module = None, transform: Optional[List[torch.nn.Module]] = None): super().__init__() self.activation = activation self.transform = transform @@ -613,9 +609,8 @@ def on_fit_start(self): def test_cli_distributed_save_config_callback(cleandir, logger, strategy): from torch.multiprocessing import ProcessRaisedException - with ( - mock.patch("sys.argv", ["any.py", "fit"]), - pytest.raises((MisconfigurationException, ProcessRaisedException), match=r"Error on fit start"), + with mock.patch("sys.argv", ["any.py", "fit"]), pytest.raises( + (MisconfigurationException, ProcessRaisedException), match=r"Error on fit start" ): LightningCLI( EarlyExitTestModel, @@ -715,14 +710,12 @@ def train_dataloader(self): ... from lightning.pytorch.trainer.configuration_validator import __verify_train_val_loop_configuration - with ( - mock.patch("sys.argv", ["any.py", "fit", "--optimizer=Adam"]), - mock.patch("lightning.pytorch.Trainer._run_stage") as run, - mock.patch( - "lightning.pytorch.trainer.configuration_validator.__verify_train_val_loop_configuration", - wraps=__verify_train_val_loop_configuration, - ) as verify, - ): + with mock.patch("sys.argv", ["any.py", "fit", "--optimizer=Adam"]), mock.patch( + "lightning.pytorch.Trainer._run_stage" + ) as run, mock.patch( + "lightning.pytorch.trainer.configuration_validator.__verify_train_val_loop_configuration", + wraps=__verify_train_val_loop_configuration, + ) as verify: cli = LightningCLI(BoringModel) run.assert_called_once() verify.assert_called_once_with(cli.trainer, cli.model) @@ -1095,18 +1088,15 @@ def __init__(self, foo, bar=5): @_xfail_python_ge_3_11_9 def test_lightning_cli_model_short_arguments(): - with ( - mock.patch("sys.argv", ["any.py", "fit", "--model=BoringModel"]), - mock.patch("lightning.pytorch.Trainer._fit_impl") as run, - mock_subclasses(LightningModule, BoringModel, TestModel), - ): + with mock.patch("sys.argv", ["any.py", "fit", "--model=BoringModel"]), mock.patch( + "lightning.pytorch.Trainer._fit_impl" + ) as run, mock_subclasses(LightningModule, BoringModel, TestModel): cli = LightningCLI(trainer_defaults={"fast_dev_run": 1}) assert isinstance(cli.model, BoringModel) run.assert_called_once_with(cli.model, ANY, ANY, ANY, ANY) - with ( - mock.patch("sys.argv", ["any.py", "--model=TestModel", "--model.foo", "123"]), - mock_subclasses(LightningModule, BoringModel, TestModel), + with mock.patch("sys.argv", ["any.py", "--model=TestModel", "--model.foo", "123"]), mock_subclasses( + LightningModule, BoringModel, TestModel ): cli = LightningCLI(run=False) assert isinstance(cli.model, TestModel) @@ -1124,18 +1114,15 @@ def __init__(self, foo, bar=5): @_xfail_python_ge_3_11_9 def test_lightning_cli_datamodule_short_arguments(): # with set model - with ( - mock.patch("sys.argv", ["any.py", "fit", "--data=BoringDataModule"]), - mock.patch("lightning.pytorch.Trainer._fit_impl") as run, - mock_subclasses(LightningDataModule, BoringDataModule), - ): + with mock.patch("sys.argv", ["any.py", "fit", "--data=BoringDataModule"]), mock.patch( + "lightning.pytorch.Trainer._fit_impl" + ) as run, mock_subclasses(LightningDataModule, BoringDataModule): cli = LightningCLI(BoringModel, trainer_defaults={"fast_dev_run": 1}) assert isinstance(cli.datamodule, BoringDataModule) run.assert_called_once_with(ANY, ANY, ANY, cli.datamodule, ANY) - with ( - mock.patch("sys.argv", ["any.py", "--data=MyDataModule", "--data.foo", "123"]), - mock_subclasses(LightningDataModule, MyDataModule), + with mock.patch("sys.argv", ["any.py", "--data=MyDataModule", "--data.foo", "123"]), mock_subclasses( + LightningDataModule, MyDataModule ): cli = LightningCLI(BoringModel, run=False) assert isinstance(cli.datamodule, MyDataModule) @@ -1143,22 +1130,17 @@ def test_lightning_cli_datamodule_short_arguments(): assert cli.datamodule.bar == 5 # with configurable model - with ( - mock.patch("sys.argv", ["any.py", "fit", "--model", "BoringModel", "--data=BoringDataModule"]), - mock.patch("lightning.pytorch.Trainer._fit_impl") as run, - mock_subclasses(LightningModule, BoringModel), - mock_subclasses(LightningDataModule, BoringDataModule), - ): + with mock.patch("sys.argv", ["any.py", "fit", "--model", "BoringModel", "--data=BoringDataModule"]), mock.patch( + "lightning.pytorch.Trainer._fit_impl" + ) as run, mock_subclasses(LightningModule, BoringModel), mock_subclasses(LightningDataModule, BoringDataModule): cli = LightningCLI(trainer_defaults={"fast_dev_run": 1}) assert isinstance(cli.model, BoringModel) assert isinstance(cli.datamodule, BoringDataModule) run.assert_called_once_with(cli.model, ANY, ANY, cli.datamodule, ANY) - with ( - mock.patch("sys.argv", ["any.py", "--model", "BoringModel", "--data=MyDataModule"]), - mock_subclasses(LightningModule, BoringModel), - mock_subclasses(LightningDataModule, MyDataModule), - ): + with mock.patch("sys.argv", ["any.py", "--model", "BoringModel", "--data=MyDataModule"]), mock_subclasses( + LightningModule, BoringModel + ), mock_subclasses(LightningDataModule, MyDataModule): cli = LightningCLI(run=False) assert isinstance(cli.model, BoringModel) assert isinstance(cli.datamodule, MyDataModule) @@ -1325,10 +1307,9 @@ def __init__(self, opt1_config: dict, opt2_config: dict, sch_config: dict): def test_lightning_cli_config_with_subcommand(): config = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}} - with ( - mock.patch("sys.argv", ["any.py", f"--config={config}"]), - mock.patch("lightning.pytorch.Trainer.test", autospec=True) as test_mock, - ): + with mock.patch("sys.argv", ["any.py", f"--config={config}"]), mock.patch( + "lightning.pytorch.Trainer.test", autospec=True + ) as test_mock: cli = LightningCLI(BoringModel) test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar") @@ -1341,10 +1322,9 @@ def test_lightning_cli_config_before_subcommand(): "test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}, } - with ( - mock.patch("sys.argv", ["any.py", f"--config={config}", "test"]), - mock.patch("lightning.pytorch.Trainer.test", autospec=True) as test_mock, - ): + with mock.patch("sys.argv", ["any.py", f"--config={config}", "test"]), mock.patch( + "lightning.pytorch.Trainer.test", autospec=True + ) as test_mock: cli = LightningCLI(BoringModel) test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar") @@ -1354,10 +1334,9 @@ def test_lightning_cli_config_before_subcommand(): assert save_config_callback.config.trainer.limit_test_batches == 1 assert save_config_callback.parser.subcommand == "test" - with ( - mock.patch("sys.argv", ["any.py", f"--config={config}", "validate"]), - mock.patch("lightning.pytorch.Trainer.validate", autospec=True) as validate_mock, - ): + with mock.patch("sys.argv", ["any.py", f"--config={config}", "validate"]), mock.patch( + "lightning.pytorch.Trainer.validate", autospec=True + ) as validate_mock: cli = LightningCLI(BoringModel) validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo") @@ -1372,19 +1351,17 @@ def test_lightning_cli_config_before_subcommand_two_configs(): config1 = {"validate": {"trainer": {"limit_val_batches": 1}, "verbose": False, "ckpt_path": "barfoo"}} config2 = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}} - with ( - mock.patch("sys.argv", ["any.py", f"--config={config1}", f"--config={config2}", "test"]), - mock.patch("lightning.pytorch.Trainer.test", autospec=True) as test_mock, - ): + with mock.patch("sys.argv", ["any.py", f"--config={config1}", f"--config={config2}", "test"]), mock.patch( + "lightning.pytorch.Trainer.test", autospec=True + ) as test_mock: cli = LightningCLI(BoringModel) test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar") assert cli.trainer.limit_test_batches == 1 - with ( - mock.patch("sys.argv", ["any.py", f"--config={config1}", f"--config={config2}", "validate"]), - mock.patch("lightning.pytorch.Trainer.validate", autospec=True) as validate_mock, - ): + with mock.patch("sys.argv", ["any.py", f"--config={config1}", f"--config={config2}", "validate"]), mock.patch( + "lightning.pytorch.Trainer.validate", autospec=True + ) as validate_mock: cli = LightningCLI(BoringModel) validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo") @@ -1393,10 +1370,9 @@ def test_lightning_cli_config_before_subcommand_two_configs(): def test_lightning_cli_config_after_subcommand(): config = {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"} - with ( - mock.patch("sys.argv", ["any.py", "test", f"--config={config}"]), - mock.patch("lightning.pytorch.Trainer.test", autospec=True) as test_mock, - ): + with mock.patch("sys.argv", ["any.py", "test", f"--config={config}"]), mock.patch( + "lightning.pytorch.Trainer.test", autospec=True + ) as test_mock: cli = LightningCLI(BoringModel) test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar") @@ -1406,10 +1382,9 @@ def test_lightning_cli_config_after_subcommand(): def test_lightning_cli_config_before_and_after_subcommand(): config1 = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}} config2 = {"trainer": {"fast_dev_run": 1}, "verbose": False, "ckpt_path": "foobar"} - with ( - mock.patch("sys.argv", ["any.py", f"--config={config1}", "test", f"--config={config2}"]), - mock.patch("lightning.pytorch.Trainer.test", autospec=True) as test_mock, - ): + with mock.patch("sys.argv", ["any.py", f"--config={config1}", "test", f"--config={config2}"]), mock.patch( + "lightning.pytorch.Trainer.test", autospec=True + ) as test_mock: cli = LightningCLI(BoringModel) test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=False, ckpt_path="foobar") @@ -1431,19 +1406,17 @@ def test_lightning_cli_parse_kwargs_with_subcommands(cleandir): "validate": {"default_config_files": [str(validate_config_path)]}, } - with ( - mock.patch("sys.argv", ["any.py", "fit"]), - mock.patch("lightning.pytorch.Trainer.fit", autospec=True) as fit_mock, - ): + with mock.patch("sys.argv", ["any.py", "fit"]), mock.patch( + "lightning.pytorch.Trainer.fit", autospec=True + ) as fit_mock: cli = LightningCLI(BoringModel, parser_kwargs=parser_kwargs) fit_mock.assert_called() assert cli.trainer.limit_train_batches == 2 assert cli.trainer.limit_val_batches == 1.0 - with ( - mock.patch("sys.argv", ["any.py", "validate"]), - mock.patch("lightning.pytorch.Trainer.validate", autospec=True) as validate_mock, - ): + with mock.patch("sys.argv", ["any.py", "validate"]), mock.patch( + "lightning.pytorch.Trainer.validate", autospec=True + ) as validate_mock: cli = LightningCLI(BoringModel, parser_kwargs=parser_kwargs) validate_mock.assert_called() assert cli.trainer.limit_train_batches == 1.0 @@ -1461,10 +1434,9 @@ def __init__(self, foo: int, *args, **kwargs): config_path.write_text(str(config)) parser_kwargs = {"default_config_files": [str(config_path)]} - with ( - mock.patch("sys.argv", ["any.py", "fit"]), - mock.patch("lightning.pytorch.Trainer.fit", autospec=True) as fit_mock, - ): + with mock.patch("sys.argv", ["any.py", "fit"]), mock.patch( + "lightning.pytorch.Trainer.fit", autospec=True + ) as fit_mock: cli = LightningCLI(Model, parser_kwargs=parser_kwargs) fit_mock.assert_called() assert cli.model.foo == 123 diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py index 9e947e0..65c5777 100644 --- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py @@ -15,7 +15,7 @@ import os import sys from contextlib import nullcontext -from typing import Any +from typing import Any, Dict from unittest import mock from unittest.mock import Mock @@ -178,7 +178,7 @@ class Accel(Accelerator): def setup_device(self, device: torch.device) -> None: pass - def get_device_stats(self, device: torch.device) -> dict[str, Any]: + def get_device_stats(self, device: torch.device) -> Dict[str, Any]: pass def teardown(self) -> None: diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py index ca5690e..a820a3d 100644 --- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Sized from re import escape +from typing import Sized from unittest import mock from unittest.mock import Mock diff --git a/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py b/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py index af7cecd..41d9301 100644 --- a/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import Any, Optional, Union +from typing import Any, Dict, Optional, Union from unittest.mock import Mock import lightning.pytorch as pl @@ -38,7 +38,7 @@ def __init__(self): def experiment(self) -> Any: return self.exp - def log_metrics(self, metrics: dict[str, float], step: Optional[int] = None): + def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): self.logs.update(metrics) def version(self) -> Union[int, str]: @@ -144,7 +144,7 @@ def __init__(self): self.buffer = {} self.logs = {} - def log_metrics(self, metrics: dict[str, float], step: Optional[int] = None) -> None: + def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: self.buffer.update(metrics) def finalize(self, status: str) -> None: diff --git a/tests/tests_pytorch/trainer/optimization/test_optimizers.py b/tests/tests_pytorch/trainer/optimization/test_optimizers.py index ac660b6..451557d 100644 --- a/tests/tests_pytorch/trainer/optimization/test_optimizers.py +++ b/tests/tests_pytorch/trainer/optimization/test_optimizers.py @@ -592,10 +592,9 @@ def configure_optimizers(self): limit_train_batches=limit_train_batches, limit_val_batches=0, ) - with ( - mock.patch.object(CustomEpochScheduler, "step") as mock_method_epoch, - mock.patch.object(torch.optim.lr_scheduler.StepLR, "step") as mock_method_step, - ): + with mock.patch.object(CustomEpochScheduler, "step") as mock_method_epoch, mock.patch.object( + torch.optim.lr_scheduler.StepLR, "step" + ) as mock_method_step: trainer.fit(model) assert mock_method_epoch.mock_calls == [call(epoch=e) for e in range(max_epochs)] diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index d66f3aa..8946fb4 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -1887,9 +1887,8 @@ def training_step(self, batch, batch_idx): model = NanModel() trainer = Trainer(default_root_dir=tmp_path, detect_anomaly=True) - with ( - pytest.raises(RuntimeError, match=r"returned nan values in its 0th output."), - pytest.warns(UserWarning, match=r".*Error detected in.* Traceback of forward call that caused the error.*"), + with pytest.raises(RuntimeError, match=r"returned nan values in its 0th output."), pytest.warns( + UserWarning, match=r".*Error detected in.* Traceback of forward call that caused the error.*" ): trainer.fit(model) @@ -2068,9 +2067,8 @@ def on_fit_start(self): raise exception trainer = Trainer(default_root_dir=tmp_path) - with ( - mock.patch("lightning.pytorch.strategies.strategy.Strategy.on_exception") as on_exception_mock, - suppress(Exception, SystemExit), + with mock.patch("lightning.pytorch.strategies.strategy.Strategy.on_exception") as on_exception_mock, suppress( + Exception, SystemExit ): trainer.fit(ExceptionModel()) on_exception_mock.assert_called_once_with(exception) diff --git a/tests/tests_pytorch/utilities/test_combined_loader.py b/tests/tests_pytorch/utilities/test_combined_loader.py index 43a146c..74f5c13 100644 --- a/tests/tests_pytorch/utilities/test_combined_loader.py +++ b/tests/tests_pytorch/utilities/test_combined_loader.py @@ -13,8 +13,7 @@ # limitations under the License. import math import pickle -from collections.abc import Sequence -from typing import Any, NamedTuple, get_args +from typing import Any, NamedTuple, Sequence, get_args from unittest.mock import Mock import pytest diff --git a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py index 2ef1ecd..1bdac61 100644 --- a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py +++ b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py @@ -33,9 +33,8 @@ def test_upgrade_checkpoint_file_missing(tmp_path, caplog): # path to non-empty directory, but no checkpoints with matching extension file.touch() - with ( - mock.patch("sys.argv", ["upgrade_checkpoint.py", str(tmp_path), "--extension", ".other"]), - caplog.at_level(logging.ERROR), + with mock.patch("sys.argv", ["upgrade_checkpoint.py", str(tmp_path), "--extension", ".other"]), caplog.at_level( + logging.ERROR ): with pytest.raises(SystemExit): upgrade_main()