|
1 | 1 | import os
|
2 | 2 | from datetime import datetime
|
3 |
| -from typing import Optional |
| 3 | +from typing import TYPE_CHECKING, List, Optional, Union |
4 | 4 |
|
5 | 5 | from loguru import logger
|
6 | 6 | from torch.utils.data import DataLoader
|
7 |
| -from transformers import PreTrainedModel |
| 7 | +from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin |
8 | 8 |
|
9 | 9 | from llmcompressor.args import parse_args
|
10 | 10 | from llmcompressor.core.session_functions import active_session
|
|
14 | 14 |
|
15 | 15 | __all__ = ["Oneshot", "oneshot"]
|
16 | 16 |
|
| 17 | +if TYPE_CHECKING: |
| 18 | + from datasets import Dataset, DatasetDict |
| 19 | + |
17 | 20 |
|
18 | 21 | class Oneshot:
|
19 | 22 | """
|
@@ -102,7 +105,8 @@ def __init__(
|
102 | 105 | :param recipe_args: RecipeArguments parameters, responsible for containing
|
103 | 106 | recipe-related parameters
|
104 | 107 | :param output_dir: Path to save the output model after carrying out oneshot
|
105 |
| -
|
| 108 | + :param log_dir: Path to save logs during oneshot run. |
| 109 | + Nothing is logged to file if None. |
106 | 110 | """
|
107 | 111 | # Set up logging
|
108 | 112 | if log_dir:
|
@@ -191,8 +195,119 @@ def apply_recipe_modifiers(
|
191 | 195 | session.finalize()
|
192 | 196 |
|
193 | 197 |
|
194 |
def oneshot(
    # Model arguments
    model: Union[str, PreTrainedModel],
    distill_teacher: Optional[str] = None,
    config_name: Optional[str] = None,
    tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
    processor: Optional[Union[str, ProcessorMixin]] = None,
    cache_dir: Optional[str] = None,
    use_auth_token: bool = False,
    precision: str = "auto",
    tie_word_embeddings: bool = False,
    trust_remote_code_model: bool = False,
    save_compressed: bool = True,
    oneshot_device: str = "cuda:0",
    model_revision: str = "main",
    # Recipe arguments
    recipe: Optional[Union[str, List[str]]] = None,
    recipe_args: Optional[List[str]] = None,
    clear_sparse_session: bool = False,
    stage: Optional[str] = None,
    # Dataset arguments
    dataset: Optional[Union[str, "Dataset", "DatasetDict"]] = None,
    dataset_config_name: Optional[str] = None,
    dataset_path: Optional[str] = None,
    num_calibration_samples: int = 512,
    shuffle_calibration_samples: bool = True,
    max_seq_length: int = 384,
    pad_to_max_length: bool = True,
    text_column: str = "text",
    concatenate_data: bool = False,
    streaming: bool = False,
    overwrite_cache: bool = False,
    preprocessing_num_workers: Optional[int] = None,
    min_tokens_per_module: Optional[float] = None,
    trust_remote_code_data: bool = False,
    # Miscellaneous arguments
    output_dir: Optional[str] = None,
    log_dir: Optional[str] = "sparse_logs",
    **kwargs,
) -> PreTrainedModel:
    """
    Run a oneshot calibration pass on a model and return the resulting model.

    This is the functional entrypoint around :class:`Oneshot`. Every named
    argument below, together with any extra ``kwargs``, is forwarded unchanged
    to the ``Oneshot`` constructor; the constructed object is then called once
    and its ``model`` attribute is returned.

    # Model arguments
    :param model: Hugging Face model identifier (huggingface.co/models), path to
        a local model, or a ``PreTrainedModel`` instance. Required.
    :param distill_teacher: Teacher model (a trained text generation model) to
        use for distillation.
    :param config_name: Pretrained config name or path, when it differs from
        the model name.
    :param tokenizer: Pretrained tokenizer name or path, when it differs from
        the model name, or a ``PreTrainedTokenizerBase`` instance.
    :param processor: Pretrained processor name or path, when it differs from
        the model name, or a ``ProcessorMixin`` instance.
    :param cache_dir: Directory in which to store pretrained data downloaded
        from huggingface.co.
    :param use_auth_token: Whether to use a Hugging Face auth token for private
        models.
    :param precision: Precision to cast the model weights to. Defaults to
        "auto".
    :param tie_word_embeddings: Whether the model's input and output word
        embeddings should be tied.
    :param trust_remote_code_model: Whether custom models are allowed to
        execute their own modeling files.
    :param save_compressed: Whether to compress sparse models during save.
    :param oneshot_device: Device on which to run oneshot calibration.
    :param model_revision: Specific model version to use (branch name, tag, or
        commit id).

    # Recipe arguments
    :param recipe: Path to an LLM Compressor sparsification recipe, or a list
        of such paths.
    :param recipe_args: Recipe arguments to evaluate, given as a list in the
        format "key1=value1", "key2=value2".
    :param clear_sparse_session: Whether to clear CompressionSession/
        CompressionLifecycle data between runs.
    :param stage: Stage of the recipe to use for oneshot.

    # Dataset arguments
    :param dataset: Name of the dataset to use (via the datasets library), or
        a ``Dataset`` / ``DatasetDict`` object.
    :param dataset_config_name: Configuration name of the dataset to use.
    :param dataset_path: Path to a custom dataset. Supports json, csv, dvc.
    :param num_calibration_samples: Number of samples to use for one-shot
        calibration.
    :param shuffle_calibration_samples: Whether to shuffle the dataset before
        calibration.
    :param max_seq_length: Maximum total input sequence length after
        tokenization.
    :param pad_to_max_length: Whether to pad all samples to `max_seq_length`.
    :param text_column: Key to use as the `text` input to the
        tokenizer/processor.
    :param concatenate_data: Whether to concatenate datapoints to fill
        max_seq_length.
    :param streaming: True to stream data from a cloud dataset.
    :param overwrite_cache: Whether to overwrite the cached preprocessed
        datasets.
    :param preprocessing_num_workers: Number of processes to use for
        preprocessing.
    :param min_tokens_per_module: Minimum percentage of tokens per module,
        relevant for MoE models.
    :param trust_remote_code_data: Whether to allow datasets defined on the Hub
        through a dataset script.

    # Miscellaneous arguments
    :param output_dir: Path to save the output model after calibration.
        Nothing is saved if None.
    :param log_dir: Path to save logs during the oneshot run.
        Nothing is logged to file if None.
    :param kwargs: Additional keyword arguments, passed through to ``Oneshot``
        as-is.

    :return: The calibrated PreTrainedModel
    """

    # NOTE: this must stay the first statement of the body. locals() captures
    # every local bound so far, so any variable assigned above this line would
    # also be forwarded to Oneshot as a keyword argument.
    forwarded_args = dict(locals())

    # "kwargs" is itself a local; drop it from the snapshot so its contents are
    # expanded as individual keywords rather than passed as a single argument.
    del forwarded_args["kwargs"]

    calibrator = Oneshot(**forwarded_args, **kwargs)
    calibrator()

    return calibrator.model
|
0 commit comments