Skip to content

Commit 75fab6f

Browse files
authored
Merge branch 'main' into add-non-uniform-e2e-tests
2 parents 4baf690 + bd111bc commit 75fab6f

File tree

6 files changed: +92 additions, −6 deletions

experimental/attention/llama3_attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1+
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
12
from datasets import load_dataset
23
from transformers import AutoModelForCausalLM, AutoTokenizer
34

45
from llmcompressor import oneshot
56
from llmcompressor.modifiers.quantization import QuantizationModifier
67
from llmcompressor.utils import dispatch_for_generation
7-
from compressed_tensors.quantization import QuantizationScheme, QuantizationArgs
88

99
# Select model and load it.
1010
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

experimental/attention/llama3_attention_r3_nvfp4.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1+
from compressed_tensors.quantization import QuantizationScheme
2+
from compressed_tensors.quantization.quant_scheme import NVFP4
13
from datasets import load_dataset
24
from transformers import AutoModelForCausalLM, AutoTokenizer
35

46
from llmcompressor import oneshot
57
from llmcompressor.modifiers.quantization import QuantizationModifier
68
from llmcompressor.modifiers.transform import SpinQuantModifier
79
from llmcompressor.utils import dispatch_for_generation
8-
from compressed_tensors.quantization import QuantizationScheme
9-
from compressed_tensors.quantization.quant_scheme import NVFP4
1010

1111
# Select model and load it.
1212
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
from datasets import Dataset, DatasetDict
3232

3333

34+
TOKENIZERS_PARALLELISM_ENV = "TOKENIZERS_PARALLELISM"
35+
36+
3437
class Oneshot:
3538
"""
3639
Class responsible for carrying out one-shot calibration on a pretrained model.
@@ -121,6 +124,19 @@ def __init__(
121124
:param log_dir: Path to save logs during oneshot run.
122125
Nothing is logged to file if None.
123126
"""
127+
# Disable tokenizer parallelism to prevent warning when using
128+
# multiprocessing for dataset preprocessing. The warning occurs because
129+
# FastTokenizer's internal threading conflicts with dataset.map's num_proc.
130+
# See: https://github.com/vllm-project/llm-compressor/issues/2007
131+
if TOKENIZERS_PARALLELISM_ENV not in os.environ:
132+
os.environ[TOKENIZERS_PARALLELISM_ENV] = "false"
133+
logger.warning(
134+
"Disabling tokenizer parallelism due to threading conflict between "
135+
"FastTokenizer and Datasets. Set "
136+
f"{TOKENIZERS_PARALLELISM_ENV}=false to "
137+
"suppress this warning."
138+
)
139+
124140
# Set up file logging (no default files):
125141
# 1) If LLM_COMPRESSOR_LOG_FILE is set, log to that file.
126142
# 2) Else, if an explicit log_dir is provided, create a timestamped file there.
@@ -213,6 +229,7 @@ def apply_recipe_modifiers(
213229
recipe_stage=recipe_stage,
214230
recipe_args=self.recipe_args.recipe_args,
215231
calib_data=calibration_dataloader,
232+
sequential_targets=self.dataset_args.sequential_targets,
216233
)
217234
user_pipeline = self.dataset_args.pipeline
218235
pipeline = CalibrationPipeline.from_modifiers(

src/llmcompressor/modifiers/pruning/sparsegpt/sgpt_base.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,11 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
113113
dataloader: torch.utils.data.DataLoader = state.data.calib
114114

115115
# infer module and sequential targets
116-
self.sequential_targets = self._infer_sequential_targets(model)
116+
# Note: only pass sequential_targets from kwargs, not the full kwargs dict
117+
# which may contain 'model' and cause duplicate argument errors
118+
self.sequential_targets = self._infer_sequential_targets(
119+
model, sequential_targets=kwargs.get("sequential_targets")
120+
)
117121
layers = get_layers(self.sequential_targets, model)
118122
self._target_layers = get_layers(
119123
self.targets, model
@@ -192,9 +196,27 @@ def on_end(self, state: State, event: Event, **kwargs):
192196
self.ended_ = True
193197
self.remove_hooks()
194198

195-
def _infer_sequential_targets(self, model: torch.nn.Module) -> str | list[str]:
199+
def _infer_sequential_targets(
200+
self, model: torch.nn.Module, **kwargs
201+
) -> str | list[str]:
202+
targets_from_kwargs = kwargs.get("sequential_targets")
203+
204+
# Validate that sequential_targets is not provided from both sources
205+
if self.sequential_targets is not None and targets_from_kwargs is not None:
206+
raise ValueError(
207+
"sequential_targets was provided both in the modifier config and in "
208+
"oneshot() dataset_args. Please provide sequential_targets in only "
209+
"one location to avoid conflicts."
210+
)
211+
196212
match self.sequential_targets:
197213
case None:
214+
# Check if sequential_targets was passed via kwargs (from dataset_args)
215+
if targets_from_kwargs is not None:
216+
if isinstance(targets_from_kwargs, str):
217+
return [targets_from_kwargs]
218+
return targets_from_kwargs
219+
# Fall back to auto-inference
198220
return get_no_split_params(model)
199221
case str():
200222
return [self.sequential_targets]
(new file — filename missing from this capture; appears to be a test module for Oneshot tokenizer-parallelism handling, likely under tests/ — confirm against the repository)

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import os
2+
3+
import pytest
4+
5+
from llmcompressor.entrypoints.oneshot import (
6+
TOKENIZERS_PARALLELISM_ENV as _TOKENIZERS_PARALLELISM_ENV,
7+
)
8+
9+
10+
class TestTokenizerParallelism:
11+
"""Tests for tokenizer parallelism warning suppression (issue #2007)."""
12+
13+
def test_oneshot_sets_tokenizers_parallelism_when_not_set(self, monkeypatch):
14+
"""
15+
Test that Oneshot sets TOKENIZERS_PARALLELISM=false when not already set.
16+
17+
This prevents the warning:
18+
"huggingface/tokenizers: The current process just got forked, after
19+
parallelism has already been used. Disabling parallelism to avoid deadlocks..."
20+
21+
See: https://github.com/vllm-project/llm-compressor/issues/2007
22+
"""
23+
monkeypatch.delenv(_TOKENIZERS_PARALLELISM_ENV, raising=False)
24+
25+
from llmcompressor.entrypoints.oneshot import Oneshot
26+
27+
# Create a minimal Oneshot instance to trigger __init__
28+
# We expect it to fail due to missing model, but the env var should be set
29+
with pytest.raises(Exception):
30+
Oneshot(model="nonexistent-model")
31+
32+
assert os.environ[_TOKENIZERS_PARALLELISM_ENV] == "false"
33+
34+
def test_oneshot_respects_existing_tokenizers_parallelism(self, monkeypatch):
35+
"""
36+
Test that Oneshot respects user's existing TOKENIZERS_PARALLELISM setting.
37+
38+
If a user has explicitly set TOKENIZERS_PARALLELISM, we should not override it.
39+
"""
40+
monkeypatch.setenv(_TOKENIZERS_PARALLELISM_ENV, "true")
41+
42+
from llmcompressor.entrypoints.oneshot import Oneshot
43+
44+
with pytest.raises(Exception):
45+
Oneshot(model="nonexistent-model")
46+
47+
assert os.environ[_TOKENIZERS_PARALLELISM_ENV] == "true"

tools/collect_env.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
creating bug reports. See `.github/ISSUE_TEMPLATE/bug_report.md`
44
"""
55

6+
import importlib
67
import platform
78
import sys
8-
import importlib
99

1010

1111
def get_version(pkg_name):

0 commit comments