11 changes: 5 additions & 6 deletions .actions/assistant.py
@@ -18,11 +18,10 @@
 import shutil
 import tempfile
 import urllib.request
-from collections.abc import Iterable, Iterator, Sequence
 from itertools import chain
 from os.path import dirname, isfile
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple
 
 from packaging.requirements import Requirement
 from packaging.version import Version
@@ -128,7 +127,7 @@ def _parse_requirements(lines: Iterable[str]) -> Iterator[_RequirementWithCommen
     pip_argument = None
 
 
-def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: str = "all") -> list[str]:
+def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: str = "all") -> List[str]:
     """Loading requirements from a file.
 
     >>> path_req = os.path.join(_PROJECT_ROOT, "requirements")
@@ -223,7 +222,7 @@ def _load_aggregate_requirements(req_dir: str = "requirements", freeze_requireme
         fp.writelines([ln + os.linesep for ln in requires] + [os.linesep])
 
 
-def _retrieve_files(directory: str, *ext: str) -> list[str]:
+def _retrieve_files(directory: str, *ext: str) -> List[str]:
     all_files = []
     for root, _, files in os.walk(directory):
         for fname in files:
@@ -233,7 +232,7 @@ def _retrieve_files(directory: str, *ext: str) -> list[str]:
     return all_files
 
 
-def _replace_imports(lines: list[str], mapping: list[tuple[str, str]], lightning_by: str = "") -> list[str]:
+def _replace_imports(lines: List[str], mapping: List[Tuple[str, str]], lightning_by: str = "") -> List[str]:
     """Replace imports of standalone package to lightning.
 
     >>> lns = [
@@ -321,7 +320,7 @@ def copy_replace_imports(
         fo.writelines(lines)
 
 
-def create_mirror_package(source_dir: str, package_mapping: dict[str, str]) -> None:
+def create_mirror_package(source_dir: str, package_mapping: Dict[str, str]) -> None:
     """Create a mirror package with adjusted imports."""
     # replace imports and copy the code
    mapping = package_mapping.copy()
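The motivation for swapping built-in generics back to typing aliases: subscripting list/dict/tuple in annotations only works on Python 3.9+, while the typing spellings work on 3.8. A minimal sketch of the failure mode (read_lines is a hypothetical helper, not from this PR):

    from typing import List

    def read_lines(path: str) -> List[str]:  # fine on Python 3.8
        with open(path) as fh:
            return [line.rstrip() for line in fh]

    # def read_lines(path: str) -> list[str]:  # on Python 3.8 this raises at
    #     ...                                   # definition time: TypeError:
    #                                           # 'type' object is not subscriptable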
2 changes: 1 addition & 1 deletion .github/workflows/_legacy-checkpoints.yml
@@ -60,7 +60,7 @@ jobs:
       - uses: actions/setup-python@v5
         with:
           # Python version here needs to be supported by all PL versions listed in back-compatible-versions.txt.
-          python-version: "3.9"
+          python-version: 3.8
 
       - name: Install PL from source
         env:
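A note on the now-unquoted value: YAML parses bare numbers, which is harmless for 3.8 but would truncate a version like 3.10 to the float 3.1. An illustrative check, assuming PyYAML is installed:

    import yaml

    # bare 3.8 loads as a float but still renders as "3.8"
    assert yaml.safe_load("python-version: 3.8") == {"python-version": 3.8}
    # bare 3.10 silently becomes 3.1 -- quoting avoids the trap
    assert yaml.safe_load("python-version: 3.10") == {"python-version": 3.1}
    assert yaml.safe_load('python-version: "3.10"') == {"python-version": "3.10"}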
4 changes: 2 additions & 2 deletions .github/workflows/call-clear-cache.yml
@@ -23,7 +23,7 @@ on:
 jobs:
   cron-clear:
     if: github.event_name == 'schedule' || github.event_name == 'pull_request'
-    uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.11.9
+    uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.11.8
     with:
       scripts-ref: v0.11.8
       dry-run: ${{ github.event_name == 'pull_request' }}
@@ -32,7 +32,7 @@ jobs:
 
   direct-clear:
     if: github.event_name == 'workflow_dispatch' || github.event_name == 'pull_request'
-    uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.11.9
+    uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.11.8
     with:
       scripts-ref: v0.11.8
       dry-run: ${{ github.event_name == 'pull_request' }}
2 changes: 1 addition & 1 deletion .github/workflows/ci-check-md-links.yml
@@ -14,7 +14,7 @@ on:
 
 jobs:
   check-md-links:
-    uses: Lightning-AI/utilities/.github/workflows/check-md-links.yml@v0.11.9
+    uses: Lightning-AI/utilities/.github/workflows/check-md-links.yml@v0.11.8
     with:
       config-file: ".github/markdown-links-config.json"
       base-branch: "master"
2 changes: 1 addition & 1 deletion .github/workflows/ci-schema.yml
@@ -8,7 +8,7 @@ on:
 
 jobs:
   check:
-    uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.11.9
+    uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.11.8
     with:
       # skip azure due to the wrong schema file by MSFT
       # https://github.com/Lightning-AI/lightning-flash/pull/1455#issuecomment-1244793607
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -74,7 +74,7 @@ repos:
     hooks:
       # try to fix what is possible
       - id: ruff
-        args: ["--fix", "--unsafe-fixes"]
+        args: ["--fix"]
       # perform formatting updates
      - id: ruff-format
      # validate if all is fine with preview mode
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -44,7 +44,7 @@ ignore-words-list = "te, compiletime"
 
 [tool.ruff]
 line-length = 120
-target-version = "py39"
+target-version = "py38"
 # Exclude a variety of commonly ignored directories.
 exclude = [
     ".git",
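Lowering target-version keeps Ruff's pyupgrade rules from rewriting typing-module generics into PEP 585 built-ins, which pairs with dropping --unsafe-fixes in the pre-commit hook above. A hedged sketch (the function is made up):

    from typing import Dict, List

    # With target-version = "py39", pyupgrade rule UP006 would offer to rewrite
    # these annotations to plain list/dict; with "py38" they are left alone.
    def histogram(words: List[str]) -> Dict[str, int]:
        counts: Dict[str, int] = {}
        for word in words:
            counts[word] = counts.get(word, 0) + 1
        return counts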
3 changes: 1 addition & 2 deletions setup.py
@@ -45,10 +45,9 @@
 import logging
 import os
 import tempfile
-from collections.abc import Generator, Mapping
 from importlib.util import module_from_spec, spec_from_file_location
 from types import ModuleType
-from typing import Optional
+from typing import Generator, Mapping, Optional
 
 import setuptools
 import setuptools.command.egg_info
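Here the revert has to move off collections.abc entirely, not just rename imports: the ABCs in collections.abc only became subscriptable in Python 3.9. A small sketch under that assumption (count_up is illustrative):

    from typing import Generator

    def count_up(limit: int) -> Generator[int, None, None]:  # works on 3.8
        yield from range(limit)

    # from collections.abc import Generator   # importable on 3.8, but
    # Generator[int, None, None]              # raises TypeError before 3.9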
6 changes: 3 additions & 3 deletions src/lightning/__setup__.py
@@ -3,7 +3,7 @@
 from importlib.util import module_from_spec, spec_from_file_location
 from pathlib import Path
 from types import ModuleType
-from typing import Any
+from typing import Any, Dict
 
 from setuptools import find_namespace_packages
 
@@ -26,7 +26,7 @@ def _load_py_module(name: str, location: str) -> ModuleType:
 _ASSISTANT = _load_py_module(name="assistant", location=os.path.join(_PROJECT_ROOT, ".actions", "assistant.py"))
 
 
-def _prepare_extras() -> dict[str, Any]:
+def _prepare_extras() -> Dict[str, Any]:
     # https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-extras
     # Define package extras. These are only installed if you specify them.
     # From remote, use like `pip install "lightning[dev, docs]"`
@@ -63,7 +63,7 @@ def _prepare_extras() -> dict[str, Any]:
     return extras
 
 
-def _setup_args() -> dict[str, Any]:
+def _setup_args() -> Dict[str, Any]:
     about = _load_py_module("about", os.path.join(_PACKAGE_ROOT, "__about__.py"))
     version = _load_py_module("version", os.path.join(_PACKAGE_ROOT, "__version__.py"))
     long_description = _ASSISTANT.load_readme_description(
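For orientation, _prepare_extras() builds the mapping that setuptools consumes as extras_require; roughly this shape (the keys and pins below are invented, not the real extras):

    from typing import Any, Dict

    def _example_extras() -> Dict[str, Any]:
        return {
            "extra": ["hydra-core>=1.2"],
            "test": ["pytest"],
            "all": ["hydra-core>=1.2", "pytest"],  # union of the other extras
        }

    # setuptools.setup(..., extras_require=_example_extras()) then enables
    # pip install "lightning[all]"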
4 changes: 2 additions & 2 deletions src/lightning/fabric/accelerators/cpu.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Union
+from typing import List, Union
 
 import torch
 from typing_extensions import override
@@ -45,7 +45,7 @@ def parse_devices(devices: Union[int, str]) -> int:
 
     @staticmethod
     @override
-    def get_parallel_devices(devices: Union[int, str]) -> list[torch.device]:
+    def get_parallel_devices(devices: Union[int, str]) -> List[torch.device]:
         """Gets parallel devices for the Accelerator."""
         devices = _parse_cpu_cores(devices)
         return [torch.device("cpu")] * devices
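A quick usage sketch of the helper above (assumes lightning is installed; it is a plain staticmethod, so no instance is needed):

    import torch
    from lightning.fabric.accelerators.cpu import CPUAccelerator

    devices = CPUAccelerator.get_parallel_devices(2)
    assert devices == [torch.device("cpu"), torch.device("cpu")]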
10 changes: 5 additions & 5 deletions src/lightning/fabric/accelerators/cuda.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import lru_cache
-from typing import Optional, Union
+from typing import List, Optional, Union
 
 import torch
 from typing_extensions import override
@@ -43,15 +43,15 @@ def teardown(self) -> None:
 
     @staticmethod
     @override
-    def parse_devices(devices: Union[int, str, list[int]]) -> Optional[list[int]]:
+    def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]:
         """Accelerator device parsing logic."""
         from lightning.fabric.utilities.device_parser import _parse_gpu_ids
 
         return _parse_gpu_ids(devices, include_cuda=True)
 
     @staticmethod
     @override
-    def get_parallel_devices(devices: list[int]) -> list[torch.device]:
+    def get_parallel_devices(devices: List[int]) -> List[torch.device]:
         """Gets parallel devices for the Accelerator."""
         return [torch.device("cuda", i) for i in devices]
 
@@ -76,7 +76,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
     )
 
 
-def find_usable_cuda_devices(num_devices: int = -1) -> list[int]:
+def find_usable_cuda_devices(num_devices: int = -1) -> List[int]:
     """Returns a list of all available and usable CUDA GPU devices.
 
     A GPU is considered usable if we can successfully move a tensor to the device, and this is what this function
@@ -129,7 +129,7 @@ def find_usable_cuda_devices(num_devices: int = -1) -> list[int]:
     return available_devices
 
 
-def _get_all_visible_cuda_devices() -> list[int]:
+def _get_all_visible_cuda_devices() -> List[int]:
     """Returns a list of all visible CUDA GPU devices.
 
     Devices masked by the environment variabale ``CUDA_VISIBLE_DEVICES`` won't be returned here. For example, assume you
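find_usable_cuda_devices is the one public helper in this file; a usage sketch (per its docstring it probes each GPU by allocating a tensor, and raises if fewer usable GPUs exist than requested):

    from lightning.fabric import Fabric
    from lightning.fabric.accelerators.cuda import find_usable_cuda_devices

    # pick two GPUs that currently accept tensor allocations
    fabric = Fabric(accelerator="cuda", devices=find_usable_cuda_devices(2))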
8 changes: 4 additions & 4 deletions src/lightning/fabric/accelerators/mps.py
@@ -14,7 +14,7 @@
 import os
 import platform
 from functools import lru_cache
-from typing import Optional, Union
+from typing import List, Optional, Union
 
 import torch
 from typing_extensions import override
@@ -46,15 +46,15 @@ def teardown(self) -> None:
 
     @staticmethod
     @override
-    def parse_devices(devices: Union[int, str, list[int]]) -> Optional[list[int]]:
+    def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]:
         """Accelerator device parsing logic."""
         from lightning.fabric.utilities.device_parser import _parse_gpu_ids
 
         return _parse_gpu_ids(devices, include_mps=True)
 
     @staticmethod
     @override
-    def get_parallel_devices(devices: Union[int, str, list[int]]) -> list[torch.device]:
+    def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
         """Gets parallel devices for the Accelerator."""
         parsed_devices = MPSAccelerator.parse_devices(devices)
         assert parsed_devices is not None
@@ -84,7 +84,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
     )
 
 
-def _get_all_available_mps_gpus() -> list[int]:
+def _get_all_available_mps_gpus() -> List[int]:
     """
     Returns:
         A list of all available MPS GPUs
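MPS exposes a single logical device; a hedged availability check (assumes a torch build with MPS support on Apple Silicon):

    from lightning.fabric.accelerators.mps import MPSAccelerator

    if MPSAccelerator.is_available():
        print(MPSAccelerator.parse_devices(1))  # -> [0], the only MPS device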
6 changes: 3 additions & 3 deletions src/lightning/fabric/accelerators/registry.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Dict, List, Optional
 
 from typing_extensions import override
 
@@ -68,7 +68,7 @@ def register(
         if name in self and not override:
             raise MisconfigurationException(f"'{name}' is already present in the registry. HINT: Use `override=True`.")
 
-        data: dict[str, Any] = {}
+        data: Dict[str, Any] = {}
 
         data["description"] = description
         data["init_params"] = init_params
@@ -107,7 +107,7 @@ def remove(self, name: str) -> None:
         """Removes the registered accelerator by name."""
         self.pop(name)
 
-    def available_accelerators(self) -> list[str]:
+    def available_accelerators(self) -> List[str]:
         """Returns a list of registered accelerators."""
         return list(self.keys())
 
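The registry is essentially a dict subclass keyed by accelerator name; a simplified sketch of the pattern (the real _AcceleratorRegistry also stores descriptions and handles override semantics):

    from typing import Any, Callable, Dict, List

    class SimpleRegistry(dict):
        def register(self, name: str, cls: Callable, **init_params: Any) -> None:
            # remember how to construct the object later
            self[name] = {"cls": cls, "init_params": init_params}

        def get(self, name: str) -> Any:
            data: Dict[str, Any] = self[name]
            return data["cls"](**data["init_params"])

        def available(self) -> List[str]:
            return list(self.keys())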
17 changes: 5 additions & 12 deletions src/lightning/fabric/accelerators/xla.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import functools
-from typing import Any, Union
+from typing import Any, List, Union
 
 import torch
 from lightning_utilities.core.imports import RequirementCache
@@ -47,13 +47,13 @@ def teardown(self) -> None:
 
     @staticmethod
     @override
-    def parse_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]:
+    def parse_devices(devices: Union[int, str, List[int]]) -> Union[int, List[int]]:
         """Accelerator device parsing logic."""
         return _parse_tpu_devices(devices)
 
     @staticmethod
     @override
-    def get_parallel_devices(devices: Union[int, list[int]]) -> list[torch.device]:
+    def get_parallel_devices(devices: Union[int, List[int]]) -> List[torch.device]:
         """Gets parallel devices for the Accelerator."""
         devices = _parse_tpu_devices(devices)
         if isinstance(devices, int):
@@ -102,27 +102,20 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
 # PJRT support requires this minimum version
 _XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla")
 _XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")
-_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")
 
 
 def _using_pjrt() -> bool:
-    # `using_pjrt` is removed in torch_xla 2.5
-    if _XLA_GREATER_EQUAL_2_5:
-        from torch_xla import runtime as xr
-
-        return xr.device_type() is not None
-
     # delete me when torch_xla 2.2 is the min supported version, where XRT support has been dropped.
     if _XLA_GREATER_EQUAL_2_1:
         from torch_xla import runtime as xr
 
         return xr.using_pjrt()
 
     from torch_xla.experimental import pjrt
 
     return pjrt.using_pjrt()
 
 
-def _parse_tpu_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]:
+def _parse_tpu_devices(devices: Union[int, str, List[int]]) -> Union[int, List[int]]:
     """Parses the TPU devices given in the format as accepted by the
     :class:`~lightning.pytorch.trainer.trainer.Trainer` and :class:`~lightning.fabric.Fabric`.
 
@@ -159,7 +152,7 @@ def _check_tpu_devices_valid(devices: object) -> None:
     )
 
 
-def _parse_tpu_devices_str(devices: str) -> Union[int, list[int]]:
+def _parse_tpu_devices_str(devices: str) -> Union[int, List[int]]:
     devices = devices.strip()
     try:
         return int(devices)
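The _using_pjrt() chain above is the usual version-gating idiom: RequirementCache lazily evaluates a PEP 440 specifier against the installed distribution and is truthy when satisfied. A minimal sketch with a made-up gate:

    from lightning_utilities.core.imports import RequirementCache

    _TORCH_GREATER_EQUAL_2_1 = RequirementCache("torch>=2.1")

    def pick_code_path() -> str:
        if _TORCH_GREATER_EQUAL_2_1:
            return "new API"    # torch >= 2.1 installed
        return "legacy API"     # older torch, fall back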
8 changes: 4 additions & 4 deletions src/lightning/fabric/cli.py
@@ -17,7 +17,7 @@
 import subprocess
 import sys
 from argparse import Namespace
-from typing import Any, Optional
+from typing import Any, List, Optional
 
 import torch
 from lightning_utilities.core.imports import RequirementCache
@@ -39,7 +39,7 @@
 _SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu")
 
 
-def _get_supported_strategies() -> list[str]:
+def _get_supported_strategies() -> List[str]:
     """Returns strategy choices from the registry, with the ones removed that are incompatible to be launched from the
     CLI or ones that require further configuration by the user."""
     available_strategies = STRATEGY_REGISTRY.available_strategies()
@@ -221,7 +221,7 @@ def _get_num_processes(accelerator: str, devices: str) -> int:
     return len(parsed_devices) if parsed_devices is not None else 0
 
 
-def _torchrun_launch(args: Namespace, script_args: list[str]) -> None:
+def _torchrun_launch(args: Namespace, script_args: List[str]) -> None:
     """This will invoke `torchrun` programmatically to launch the given script in new processes."""
     import torch.distributed.run as torchrun
 
@@ -242,7 +242,7 @@ def _torchrun_launch(args: Namespace, script_args: list[str]) -> None:
     torchrun.main(torchrun_args)
 
 
-def main(args: Namespace, script_args: Optional[list[str]] = None) -> None:
+def main(args: Namespace, script_args: Optional[List[str]] = None) -> None:
     _set_env_variables(args)
     _torchrun_launch(args, script_args or [])
 
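_torchrun_launch drives torchrun in-process rather than shelling out, since torch.distributed.run exposes its entry point as main(). An illustrative call with hypothetical script arguments:

    import torch.distributed.run as torchrun

    torchrun_args = [
        "--nproc_per_node=2",   # two local worker processes
        "--nnodes=1",
        "train.py",             # hypothetical training script
        "--lr", "0.01",         # args forwarded to the script
    ]
    torchrun.main(torchrun_args)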