Commit a18c76a
oneshot entrypoint update (#1445)
SUMMARY: Updated the oneshot function to use explicit keyword arguments instead of relying solely on **kwargs.

TEST PLAN: Changes were tested using mock tests that verify:
- Correct handling of explicit arguments
- Backward compatibility with kwargs
- Proper behavior with mixed explicit args and kwargs
- Verification that kwargs correctly override explicit parameters when duplicates exist

fixes #1443

TODO:
- [x] run nightly tests on branch -- https://github.com/neuralmagic/llm-compressor-testing/actions/runs/15710803245 ✅

---------

Co-authored-by: Brian Dellabetta <[email protected]>
1 parent 1c4f639 commit a18c76a
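The TEST PLAN above refers to mock tests. A minimal sketch of that style of test follows; it is not the repository's actual test code. It assumes pytest-style execution, that `oneshot` is re-exported at the package top level, and that patching `Oneshot` in its defining module intercepts the call; `custom_flag` is a hypothetical passthrough key.

```python
from unittest.mock import patch

from llmcompressor import oneshot  # assumed top-level re-export


def test_explicit_args_and_kwargs_reach_oneshot():
    # Replace the Oneshot class in its defining module so no model loads.
    with patch("llmcompressor.entrypoints.oneshot.Oneshot") as mock_cls:
        mock_cls.return_value.model = "sentinel"

        result = oneshot(
            model="stub/model",           # explicit parameter
            num_calibration_samples=64,   # explicit parameter
            custom_flag=True,             # hypothetical key, flows through **kwargs
        )

    # All explicit parameters (including defaults) plus the passthrough
    # kwarg are forwarded to Oneshot as keyword arguments.
    _, forwarded = mock_cls.call_args
    assert forwarded["model"] == "stub/model"
    assert forwarded["num_calibration_samples"] == 64
    assert forwarded["custom_flag"] is True
    assert result == "sentinel"
```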

File tree

1 file changed: +120, -5 lines changed

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 120 additions & 5 deletions
@@ -1,10 +1,10 @@
 import os
 from datetime import datetime
-from typing import Optional
+from typing import TYPE_CHECKING, List, Optional, Union

 from loguru import logger
 from torch.utils.data import DataLoader
-from transformers import PreTrainedModel
+from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin

 from llmcompressor.args import parse_args
 from llmcompressor.core.session_functions import active_session
@@ -14,6 +14,9 @@

 __all__ = ["Oneshot", "oneshot"]

+if TYPE_CHECKING:
+    from datasets import Dataset, DatasetDict
+

 class Oneshot:
     """
@@ -102,7 +105,8 @@ def __init__(
         :param recipe_args: RecipeArguments parameters, responsible for containing
             recipe-related parameters
         :param output_dir: Path to save the output model after carrying out oneshot
-
+        :param log_dir: Path to save logs during oneshot run.
+            Nothing is logged to file if None.
         """
         # Set up logging
         if log_dir:
@@ -191,8 +195,119 @@ def apply_recipe_modifiers(
         session.finalize()


-def oneshot(**kwargs) -> PreTrainedModel:
-    one_shot = Oneshot(**kwargs)
+def oneshot(
+    # Model arguments
+    model: Union[str, PreTrainedModel],
+    distill_teacher: Optional[str] = None,
+    config_name: Optional[str] = None,
+    tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
+    processor: Optional[Union[str, ProcessorMixin]] = None,
+    cache_dir: Optional[str] = None,
+    use_auth_token: bool = False,
+    precision: str = "auto",
+    tie_word_embeddings: bool = False,
+    trust_remote_code_model: bool = False,
+    save_compressed: bool = True,
+    oneshot_device: str = "cuda:0",
+    model_revision: str = "main",
+    # Recipe arguments
+    recipe: Optional[Union[str, List[str]]] = None,
+    recipe_args: Optional[List[str]] = None,
+    clear_sparse_session: bool = False,
+    stage: Optional[str] = None,
+    # Dataset arguments
+    dataset: Optional[Union[str, "Dataset", "DatasetDict"]] = None,
+    dataset_config_name: Optional[str] = None,
+    dataset_path: Optional[str] = None,
+    num_calibration_samples: int = 512,
+    shuffle_calibration_samples: bool = True,
+    max_seq_length: int = 384,
+    pad_to_max_length: bool = True,
+    text_column: str = "text",
+    concatenate_data: bool = False,
+    streaming: bool = False,
+    overwrite_cache: bool = False,
+    preprocessing_num_workers: Optional[int] = None,
+    min_tokens_per_module: Optional[float] = None,
+    trust_remote_code_data: bool = False,
+    # Miscellaneous arguments
+    output_dir: Optional[str] = None,
+    log_dir: Optional[str] = "sparse_logs",
+    **kwargs,
+) -> PreTrainedModel:
+    """
+    Performs oneshot calibration on a model.
+
+    # Model arguments
+    :param model: A pretrained model identifier from huggingface.co/models or a path
+        to a local model. Required parameter.
+    :param distill_teacher: Teacher model (a trained text generation model)
+        for distillation.
+    :param config_name: Pretrained config name or path if not the same as
+        model_name.
+    :param tokenizer: Pretrained tokenizer name or path if not the same as
+        model_name.
+    :param processor: Pretrained processor name or path if not the same as
+        model_name.
+    :param cache_dir: Where to store the pretrained data from
+        huggingface.co.
+    :param use_auth_token: Whether to use Hugging Face auth token for private
+        models.
+    :param precision: Precision to cast model weights to, default to auto.
+    :param tie_word_embeddings: Whether the model's input and output word embeddings
+        should be tied.
+    :param trust_remote_code_model: Whether to allow for custom models to execute
+        their own modeling files.
+    :param save_compressed: Whether to compress sparse models during save.
+    :param oneshot_device: Device to run oneshot calibration on.
+    :param model_revision: The specific model version to use (can be branch name,
+        tag, or commit id).
+
+    # Recipe arguments
+    :param recipe: Path to a LLM Compressor sparsification recipe.
+    :param recipe_args: List of recipe arguments to evaluate, in the
+        format "key1=value1", "key2=value2".
+    :param clear_sparse_session: Whether to clear CompressionSession/
+        CompressionLifecycle data between runs.
+    :param stage: The stage of the recipe to use for oneshot.
+
+    # Dataset arguments
+    :param dataset: The name of the dataset to use (via the datasets
+        library).
+    :param dataset_config_name: The configuration name of the dataset
+        to use.
+    :param dataset_path: Path to a custom dataset. Supports json, csv, dvc.
+    :param num_calibration_samples: Number of samples to use for one-shot
+        calibration.
+    :param shuffle_calibration_samples: Whether to shuffle the dataset before
+        calibration.
+    :param max_seq_length: Maximum total input sequence length after tokenization.
+    :param pad_to_max_length: Whether to pad all samples to `max_seq_length`.
+    :param text_column: Key to use as the `text` input to tokenizer/processor.
+    :param concatenate_data: Whether to concatenate datapoints to fill
+        max_seq_length.
+    :param streaming: True to stream data from a cloud dataset.
+    :param overwrite_cache: Whether to overwrite the cached preprocessed datasets.
+    :param preprocessing_num_workers: Number of processes for
+        preprocessing.
+    :param min_tokens_per_module: Minimum percentage of tokens per
+        module, relevant for MoE models.
+    :param trust_remote_code_data: Whether to allow for datasets defined on the Hub
+        using a dataset script.
+
+    # Miscellaneous arguments
+    :param output_dir: Path to save the output model after calibration.
+        Nothing is saved if None.
+    :param log_dir: Path to save logs during oneshot run.
+        Nothing is logged to file if None.
+
+    :return: The calibrated PreTrainedModel
+    """
+
+    # pass all args directly into Oneshot
+    local_args = locals()
+    local_args.pop("kwargs")
+    one_shot = Oneshot(**local_args, **kwargs)
     one_shot()

     return one_shot.model
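With the explicit signature in place, call sites gain IDE and type-checker visibility into parameter names, types, and defaults, while unrecognized keys still flow through **kwargs for backward compatibility. A minimal usage sketch (the model identifier, recipe path, dataset name, and output directory are placeholders, not values from this commit):

```python
from llmcompressor import oneshot  # assumed top-level re-export

model = oneshot(
    model="org/model-name",        # placeholder HF model id or local path
    recipe="recipe.yaml",          # placeholder recipe path
    dataset="open_platypus",       # example dataset name
    num_calibration_samples=512,
    output_dir="oneshot-output",   # placeholder; nothing is saved if None
)
```

Internally, the new function snapshots locals() before doing any other work, pops the kwargs entry, and splats both dictionaries into Oneshot, so explicit and passthrough arguments take the same path into the existing class.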
