Skip to content

Commit 714bc69

Browse files
authored
Merge branch 'main' into model.dump_fix
2 parents 7af9553 + 8fdd0a1 commit 714bc69

File tree

20 files changed

+292
-260
lines changed

20 files changed

+292
-260
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ llmcompressor/pruned
2323
tensorboard/*
2424
onnx/*
2525
repos/*
26+
# managed by setuptools-scm
27+
/src/llmcompressor/version.py
2628

2729
### Python template
2830
# Byte-compiled / optimized / DLL files

DEVELOPING.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Here are some details to get started.
1010
**Development Installation**
1111

1212
```bash
13-
git https://github.com/vllm-project/llm-compressor
13+
git clone https://github.com/vllm-project/llm-compressor
1414
cd llm-compressor
1515
python3 -m pip install -e "./[dev]"
1616
```
@@ -39,20 +39,20 @@ File any error found before changes as an Issue and fix any errors found after m
3939

4040
## GitHub Workflow
4141

42-
1. Fork the `llmcompressor` repository into your GitHub account: https://github.com/vllm-project/llm-compressor.
42+
1. Fork the `llm-compressor` repository into your GitHub account: https://github.com/vllm-project/llm-compressor.
4343

4444
2. Clone your fork of the GitHub repository, replacing `<username>` with your GitHub username.
4545

4646
Use ssh (recommended):
4747

4848
```bash
49-
git clone git@github.com:<username>/llmcompressor.git
49+
git clone git@github.com:<username>/llm-compressor.git
5050
```
5151

5252
Or https:
5353

5454
```bash
55-
git clone https://github.com/<username>/llmcompressor.git
55+
git clone https://github.com/<username>/llm-compressor.git
5656
```
5757

5858
3. Add a remote to keep up with upstream changes.
@@ -70,7 +70,7 @@ File any error found before changes as an Issue and fix any errors found after m
7070
4. Create a feature branch to work in.
7171

7272
```bash
73-
git checkout -b feature-xxx remotes/upstream/main
73+
git checkout -b feature-xxx upstream/main
7474
```
7575

7676
5. Work in your feature branch.
@@ -104,10 +104,10 @@ File any error found before changes as an Issue and fix any errors found after m
104104
Go to your fork main page
105105

106106
```bash
107-
https://github.com/<username>/llmcompressor
107+
https://github.com/<username>/llm-compressor
108108
```
109109

110-
If you recently pushed your changes GitHub will automatically pop up a `Compare & pull request` button for any branches you recently pushed to. If you click that button it will automatically offer you to submit your pull-request to the `TODO/llmcompressor` repository.
110+
If you recently pushed your changes GitHub will automatically pop up a `Compare & pull request` button for any branches you recently pushed to. If you click that button it will automatically offer you to submit your pull-request to the `vllm-project/llm-compressor` repository.
111111

112112
- Give your pull-request a meaningful title.
113113
You'll know your title is properly formatted once the `Semantic Pull Request` GitHub check

Makefile

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
BUILDDIR := $(PWD)
2-
CHECKDIRS := src tests utils examples setup.py
2+
CHECKDIRS := src tests examples setup.py
33
DOCDIR := docs
44

55
BUILD_ARGS := # set nightly to build nightly release
66

7+
# refer to setup.py for allowed values for BUILD_TYPE
8+
BUILD_TYPE?=dev
9+
export BUILD_TYPE
10+
711
TARGETS := "" # targets for running pytests: deepsparse,keras,onnx,pytorch,pytorch_models,export,pytorch_datasets,tensorflow_v1,tensorflow_v1_models,tensorflow_v1_datasets
812
PYTEST_ARGS ?= ""
913
ifneq ($(findstring transformers,$(TARGETS)),transformers)
@@ -37,6 +41,7 @@ test:
3741
pytest tests $(PYTEST_ARGS)
3842

3943
# creates wheel file
44+
.PHONY: build
4045
build:
4146
python3 setup.py sdist bdist_wheel $(BUILD_ARGS)
4247

pyproject.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
[build-system]
2-
requires = ["setuptools", "wheel"]
2+
requires = ["setuptools", "wheel", "setuptools_scm>8"]
33
build-backend = "setuptools.build_meta"
44

5+
[tool.setuptools_scm]
6+
version_file = "src/llmcompressor/version.py"
7+
58
[tool.black]
69
line-length = 88
710
target-version = ['py38']
811

912
[tool.isort]
1013
profile = "black"
11-
skip = ["src/llmcompressor/transformers/tracing/"]
14+
skip = ["src/llmcompressor/transformers/tracing/", "src/llmcompressor/version.py"]
1215

1316
[tool.mypy]
1417
files = "src/guidellm"

setup.py

Lines changed: 79 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,89 @@
22
import sys
33

44
from setuptools import find_packages, setup
5+
from setuptools_scm import ScmVersion
56

6-
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
7-
from utils.version_extractor import extract_version_info # noqa isort:skip
7+
# Set the build type using an environment variable to give us
8+
# different package names based on the reason for the build.
9+
VALID_BUILD_TYPES = {"release", "nightly", "dev"}
10+
BUILD_TYPE = os.environ.get("BUILD_TYPE", "dev")
11+
if BUILD_TYPE not in VALID_BUILD_TYPES:
12+
raise ValueError(
13+
f"Unsupported build type {BUILD_TYPE!r}, must be one of {VALID_BUILD_TYPES}"
14+
)
815

9-
# load version info for the package
10-
package_path = os.path.join(
11-
os.path.dirname(os.path.realpath(__file__)), "src", "llmcompressor"
12-
)
13-
version_info = extract_version_info(package_path)
1416

15-
if version_info.build_type == "release":
16-
package_name = "llmcompressor"
17-
elif version_info.build_type == "dev":
18-
package_name = "llmcompressor-dev"
19-
elif version_info.build_type == "nightly":
20-
package_name = "llmcompressor-nightly"
21-
else:
22-
raise ValueError(f"Unsupported build type {version_info.build_type}")
17+
def version_func(version: ScmVersion) -> str:
18+
from setuptools_scm.version import guess_next_version
19+
20+
print(
21+
f"computing version for {BUILD_TYPE} build with "
22+
f"{'dirty' if version.dirty else 'clean'} local repository"
23+
f"{' and exact version from tag' if version.exact else ''}",
24+
file=sys.stderr,
25+
)
26+
27+
if BUILD_TYPE == "nightly":
28+
# Nightly builds use alpha versions to ensure they are marked
29+
# as pre-releases on pypi.org.
30+
return version.format_next_version(
31+
guess_next=guess_next_version,
32+
fmt="{guessed}.a{node_date:%Y%m%d}",
33+
)
34+
35+
if (
36+
BUILD_TYPE == "release"
37+
and not version.dirty
38+
and (version.exact or version.node is None)
39+
):
40+
# When we have a tagged version, use that without modification.
41+
return version.format_with("{tag}")
42+
43+
# In development mode or when the local repository is dirty, treat
44+
# it is as local development version.
45+
return version.format_next_version(
46+
guess_next=guess_next_version,
47+
fmt="{guessed}.dev{distance}",
48+
)
49+
50+
51+
def localversion_func(version: ScmVersion) -> str:
52+
from setuptools_scm.version import get_local_node_and_date
53+
54+
print(
55+
f"computing local version for {BUILD_TYPE} build with "
56+
f"{'dirty' if version.dirty else 'clean'} local repository"
57+
"f{' and exact version from tag' if version.exact else ''}",
58+
file=sys.stderr,
59+
)
60+
61+
# When we are building nightly versions, we guess the next release
62+
# and add the date as an alpha version. We cannot publish packages
63+
# with local versions, so we do not add one.
64+
if BUILD_TYPE == "nightly":
65+
return ""
66+
67+
# When we have an exact tag, with no local changes, do not append
68+
# anything to the local version field.
69+
if (
70+
BUILD_TYPE == "release"
71+
and not version.dirty
72+
and (version.exact or version.node is None)
73+
):
74+
return ""
75+
76+
# In development mode or when the local repository is dirty,
77+
# return a string that includes the git SHA (node) and a date,
78+
# formatted as a local version tag.
79+
return get_local_node_and_date(version)
2380

2481

2582
setup(
26-
name=package_name,
27-
version=version_info.version,
83+
name="llmcompressor",
84+
use_scm_version={
85+
"version_scheme": version_func,
86+
"local_scheme": localversion_func,
87+
},
2888
author="Neuralmagic, Inc.",
2989
author_email="support@neuralmagic.com",
3090
description=(
@@ -62,8 +122,8 @@
62122
"pillow",
63123
(
64124
"compressed-tensors==0.9.3"
65-
if version_info.build_type == "release"
66-
else "compressed-tensors-nightly"
125+
if BUILD_TYPE == "release"
126+
else "compressed-tensors>=0.9.4a2"
67127
),
68128
],
69129
extras_require={

src/llmcompressor/__init__.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,11 @@
99
# flake8: noqa
1010

1111
from .logger import LoggerConfig, configure_logger, logger
12-
from .version import (
13-
__version__,
14-
build_type,
15-
version,
16-
version_base,
17-
version_build,
18-
version_major,
19-
version_minor,
20-
version_patch,
21-
)
12+
from .version import __version__, version
2213

2314
__all__ = [
2415
"__version__",
25-
"version_base",
26-
"build_type",
2716
"version",
28-
"version_major",
29-
"version_minor",
30-
"version_patch",
31-
"version_build",
3217
"configure_logger",
3318
"logger",
3419
"LoggerConfig",

src/llmcompressor/entrypoints/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,9 @@ def post_process(
9999

100100
else:
101101
logger.warning(
102-
"Optimized model is not saved. To save, please provide",
103-
"`output_dir` as input arg.",
104-
"Ex. `oneshot(..., output_dir=...)`",
102+
"Optimized model is not saved. To save, please provide"
103+
"`output_dir` as input arg."
104+
"Ex. `oneshot(..., output_dir=...)`"
105105
)
106106

107107
# Reset the one-time-use session upon completion

src/llmcompressor/logger.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def configure_logger(config: Optional[LoggerConfig] = None) -> None:
7676
if (log_file := os.getenv("LLM_COMPRESSOR_LOG_FILE")) is not None:
7777
logger_config.log_file = log_file
7878
if (log_file_level := os.getenv("LLM_COMPRESSOR_LOG_FILE_LEVEL")) is not None:
79-
logger_config.log_file_level = log_file_level
79+
logger_config.log_file_level = log_file_level.upper()
8080

8181
if logger_config.disabled:
8282
logger.disable("llmcompressor")

src/llmcompressor/pipelines/sequential/helpers.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from transformers.utils.fx import HFTracer
1414

1515
from llmcompressor.modifiers.utils.hooks import HooksMixin
16-
from llmcompressor.utils.helpers import calibration_forward_context, preserve_attr
16+
from llmcompressor.utils.helpers import calibration_forward_context, patch_attr
1717

1818
__all__ = ["trace_subgraphs", "Subgraph"]
1919

@@ -132,15 +132,14 @@ def is_leaf_module(self, module: Module, module_qualified_name: str) -> bool:
132132

133133
def trace(self, root: Union[Module, Callable], *args, **kwargs) -> Graph:
134134
if isinstance(root, Module):
135-
with preserve_attr(type(root), "forward"):
136-
# due to a bug in Tracer.create_args_for_root (_patch_function),
137-
# we must unwrap function wrappers prior to tracing, for example
138-
# the `deprecate_kwarg` by transformers which wraps forward
139-
140-
# we override the class method because the
141-
# class method is the one being traced
142-
type(root).forward = inspect.unwrap(type(root).forward)
143-
135+
# due to a bug in Tracer.create_args_for_root (_patch_function),
136+
# we must unwrap function wrappers prior to tracing, for example
137+
# the `deprecate_kwarg` by transformers which wraps forward
138+
unwrapped_forward = inspect.unwrap(type(root).forward)
139+
140+
# we override the class method because the
141+
# class method is the one being traced
142+
with patch_attr(type(root), "forward", unwrapped_forward):
144143
return super().trace(root, *args, **kwargs)
145144

146145
else:

0 commit comments

Comments
 (0)