Skip to content

Commit 5e14e1f

Browse files
author
Shallow Copy Bot
committed
FP8 + FSDP2 + torch.compile examples for PyTorch Lightning and Fabric
Original PR #20440 by lantiga Original: Lightning-AI/pytorch-lightning#20440
1 parent 04b95cf commit 5e14e1f

File tree

230 files changed

+1436
-1509
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

230 files changed

+1436
-1509
lines changed

.actions/assistant.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,10 @@
1818
import shutil
1919
import tempfile
2020
import urllib.request
21-
from collections.abc import Iterable, Iterator, Sequence
2221
from itertools import chain
2322
from os.path import dirname, isfile
2423
from pathlib import Path
25-
from typing import Any, Optional
24+
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple
2625

2726
from packaging.requirements import Requirement
2827
from packaging.version import Version
@@ -128,7 +127,7 @@ def _parse_requirements(lines: Iterable[str]) -> Iterator[_RequirementWithCommen
128127
pip_argument = None
129128

130129

131-
def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: str = "all") -> list[str]:
130+
def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: str = "all") -> List[str]:
132131
"""Loading requirements from a file.
133132
134133
>>> path_req = os.path.join(_PROJECT_ROOT, "requirements")
@@ -223,7 +222,7 @@ def _load_aggregate_requirements(req_dir: str = "requirements", freeze_requireme
223222
fp.writelines([ln + os.linesep for ln in requires] + [os.linesep])
224223

225224

226-
def _retrieve_files(directory: str, *ext: str) -> list[str]:
225+
def _retrieve_files(directory: str, *ext: str) -> List[str]:
227226
all_files = []
228227
for root, _, files in os.walk(directory):
229228
for fname in files:
@@ -233,7 +232,7 @@ def _retrieve_files(directory: str, *ext: str) -> list[str]:
233232
return all_files
234233

235234

236-
def _replace_imports(lines: list[str], mapping: list[tuple[str, str]], lightning_by: str = "") -> list[str]:
235+
def _replace_imports(lines: List[str], mapping: List[Tuple[str, str]], lightning_by: str = "") -> List[str]:
237236
"""Replace imports of standalone package to lightning.
238237
239238
>>> lns = [
@@ -321,7 +320,7 @@ def copy_replace_imports(
321320
fo.writelines(lines)
322321

323322

324-
def create_mirror_package(source_dir: str, package_mapping: dict[str, str]) -> None:
323+
def create_mirror_package(source_dir: str, package_mapping: Dict[str, str]) -> None:
325324
"""Create a mirror package with adjusted imports."""
326325
# replace imports and copy the code
327326
mapping = package_mapping.copy()

.github/workflows/_legacy-checkpoints.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ jobs:
6060
- uses: actions/setup-python@v5
6161
with:
6262
# Python version here needs to be supported by all PL versions listed in back-compatible-versions.txt.
63-
python-version: "3.9"
63+
python-version: 3.8
6464

6565
- name: Install PL from source
6666
env:

.github/workflows/call-clear-cache.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ on:
2323
jobs:
2424
cron-clear:
2525
if: github.event_name == 'schedule' || github.event_name == 'pull_request'
26-
uses: Lightning-AI/utilities/.github/workflows/[email protected].9
26+
uses: Lightning-AI/utilities/.github/workflows/[email protected].8
2727
with:
2828
scripts-ref: v0.11.8
2929
dry-run: ${{ github.event_name == 'pull_request' }}
@@ -32,7 +32,7 @@ jobs:
3232

3333
direct-clear:
3434
if: github.event_name == 'workflow_dispatch' || github.event_name == 'pull_request'
35-
uses: Lightning-AI/utilities/.github/workflows/[email protected].9
35+
uses: Lightning-AI/utilities/.github/workflows/[email protected].8
3636
with:
3737
scripts-ref: v0.11.8
3838
dry-run: ${{ github.event_name == 'pull_request' }}

.github/workflows/ci-check-md-links.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ on:
1414

1515
jobs:
1616
check-md-links:
17-
uses: Lightning-AI/utilities/.github/workflows/[email protected].9
17+
uses: Lightning-AI/utilities/.github/workflows/[email protected].8
1818
with:
1919
config-file: ".github/markdown-links-config.json"
2020
base-branch: "master"

.github/workflows/ci-schema.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ on:
88

99
jobs:
1010
check:
11-
uses: Lightning-AI/utilities/.github/workflows/[email protected].9
11+
uses: Lightning-AI/utilities/.github/workflows/[email protected].8
1212
with:
1313
# skip azure due to the wrong schema file by MSFT
1414
# https://github.com/Lightning-AI/lightning-flash/pull/1455#issuecomment-1244793607

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ repos:
7474
hooks:
7575
# try to fix what is possible
7676
- id: ruff
77-
args: ["--fix", "--unsafe-fixes"]
77+
args: ["--fix"]
7878
# perform formatting updates
7979
- id: ruff-format
8080
# validate if all is fine with preview mode

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ ignore-words-list = "te, compiletime"
4444

4545
[tool.ruff]
4646
line-length = 120
47-
target-version = "py39"
47+
target-version = "py38"
4848
# Exclude a variety of commonly ignored directories.
4949
exclude = [
5050
".git",

setup.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,9 @@
4545
import logging
4646
import os
4747
import tempfile
48-
from collections.abc import Generator, Mapping
4948
from importlib.util import module_from_spec, spec_from_file_location
5049
from types import ModuleType
51-
from typing import Optional
50+
from typing import Generator, Mapping, Optional
5251

5352
import setuptools
5453
import setuptools.command.egg_info

src/lightning/__setup__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from importlib.util import module_from_spec, spec_from_file_location
44
from pathlib import Path
55
from types import ModuleType
6-
from typing import Any
6+
from typing import Any, Dict
77

88
from setuptools import find_namespace_packages
99

@@ -26,7 +26,7 @@ def _load_py_module(name: str, location: str) -> ModuleType:
2626
_ASSISTANT = _load_py_module(name="assistant", location=os.path.join(_PROJECT_ROOT, ".actions", "assistant.py"))
2727

2828

29-
def _prepare_extras() -> dict[str, Any]:
29+
def _prepare_extras() -> Dict[str, Any]:
3030
# https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-extras
3131
# Define package extras. These are only installed if you specify them.
3232
# From remote, use like `pip install "lightning[dev, docs]"`
@@ -63,7 +63,7 @@ def _prepare_extras() -> dict[str, Any]:
6363
return extras
6464

6565

66-
def _setup_args() -> dict[str, Any]:
66+
def _setup_args() -> Dict[str, Any]:
6767
about = _load_py_module("about", os.path.join(_PACKAGE_ROOT, "__about__.py"))
6868
version = _load_py_module("version", os.path.join(_PACKAGE_ROOT, "__version__.py"))
6969
long_description = _ASSISTANT.load_readme_description(

src/lightning/fabric/accelerators/cpu.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import Union
14+
from typing import List, Union
1515

1616
import torch
1717
from typing_extensions import override
@@ -45,7 +45,7 @@ def parse_devices(devices: Union[int, str]) -> int:
4545

4646
@staticmethod
4747
@override
48-
def get_parallel_devices(devices: Union[int, str]) -> list[torch.device]:
48+
def get_parallel_devices(devices: Union[int, str]) -> List[torch.device]:
4949
"""Gets parallel devices for the Accelerator."""
5050
devices = _parse_cpu_cores(devices)
5151
return [torch.device("cpu")] * devices

0 commit comments

Comments
 (0)