
Commit 21de943

[Entrypoints] initialize processor error handling (#1796)
SUMMARY:

Resolves #1795.

Currently we initialize a processor in the entrypoint `pre_process` even if one isn't provided, even though it isn't needed for data-free recipes like `FP8_DYNAMIC` or `W4A16`; this causes downstream user issues like #1795. This change updates pre-processing to:

- wrap processor initialization in a try/except
- error out if initialization fails and a processor is required (i.e. if a dataset is needed for training/calibration)
- otherwise, log a warning if an `output_dir` is provided, because the processor will not be saved with the trained/compressed model

TEST PLAN:

The example script in #1795 succeeds on this branch. Confirmed that a warning is logged if `output_dir` is set and an error is raised if `dataset` is set.

Signed-off-by: Brian Dellabetta <[email protected]>
1 parent 4c9ac83 commit 21de943
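
For context, a minimal sketch of the data-free flow this commit unblocks (illustrative only, assuming the public `oneshot` API with `QuantizationModifier`; this is not the reproducer script from #1795, and the model id is hypothetical): no dataset and no processor are passed, so a failed processor auto-initialization should no longer abort the run.

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

# Data-free recipe: FP8 dynamic quantization needs no calibration dataset,
# so no tokenizer/processor is required either.
recipe = QuantizationModifier(targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"])

# No dataset, no processor: with this change, a processor that cannot be
# auto-initialized only triggers a warning (and only when output_dir is set),
# instead of failing the whole run.
oneshot(
    model="meta-llama/Llama-3.1-8B-Instruct",  # hypothetical model id
    recipe=recipe,
)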

File tree: 4 files changed, +37 / -9 lines


src/llmcompressor/args/dataset_arguments.py

Lines changed: 3 additions & 0 deletions
@@ -217,3 +217,6 @@ class DatasetArguments(CustomDatasetArguments):
             "Default is set to True."
         },
     )
+
+    def is_dataset_provided(self) -> bool:
+        return self.dataset is not None or self.dataset_path is not None

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 1 addition & 1 deletion
@@ -125,7 +125,7 @@ def __init__(
         self.output_dir = output_dir

         # initialize the model and processor
-        pre_process(model_args)
+        pre_process(model_args, dataset_args, output_dir)

         # Set instance attributes
         self.model = self.model_args.model

src/llmcompressor/entrypoints/train.py

Lines changed: 2 additions & 2 deletions
@@ -59,11 +59,11 @@ def train(**kwargs) -> PreTrainedModel:
     ```

     """
-    model_args, dataset_args, recipe_args, training_args, _ = parse_args(
+    model_args, dataset_args, recipe_args, training_args, output_dir = parse_args(
         include_training_args=True, **kwargs
     )

-    pre_process(model_args)
+    pre_process(model_args, dataset_args, output_dir)
     dispatch_for_generation(model_args.model)  # train is dispatched same as generation

     processed_dataset = get_processed_dataset(

src/llmcompressor/entrypoints/utils.py

Lines changed: 31 additions & 6 deletions
@@ -15,7 +15,12 @@
 )
 from transformers.utils.quantization_config import CompressedTensorsConfig

-from llmcompressor.args import ModelArguments, RecipeArguments, TrainingArguments
+from llmcompressor.args import (
+    DatasetArguments,
+    ModelArguments,
+    RecipeArguments,
+    TrainingArguments,
+)
 from llmcompressor.core import reset_session
 from llmcompressor.pytorch.model_load.helpers import parse_dtype
 from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
@@ -30,7 +35,11 @@
 from llmcompressor.utils.fsdp.helpers import is_fsdp_model


-def pre_process(model_args: "ModelArguments"):
+def pre_process(
+    model_args: ModelArguments,
+    dataset_args: DatasetArguments,
+    output_dir: Optional[str],
+):
     """
     Prepares the model and tokenizer/processor for calibration.
     - Initializes the model if it's specified as a path or string.
@@ -54,11 +63,27 @@ def pre_process(model_args: "ModelArguments"):
     model_args.model = model
     model_args.distill_teacher = distill_teacher

-    # Initialize processor
+    # Initialize processor if dataset provided
     if isinstance(model_args.processor, (str, type(None))):
-        model_args.processor = initialize_processor_from_path(
-            model_args, model_args.model
-        )
+        try:
+            model_args.processor = initialize_processor_from_path(
+                model_args, model_args.model
+            )
+        except Exception as e:
+            if dataset_args.is_dataset_provided():
+                raise RuntimeError(
+                    "An error occurred when attempting to initialize "
+                    "model processor, which is required when a dataset "
+                    "is provided. To resolve, create and pass in a "
+                    "processor directly to `oneshot`/`train`."
+                ) from e
+            elif output_dir:
+                logger.warning(
+                    "Model processor could not be auto-initialized and "
+                    "will not be saved along with the model. To resolve, "
+                    "create and pass in a processor directly to "
+                    f"`oneshot`/`train`.\nInitialization Error: {e}"
+                )

     # untie tie_word_embeddings weights
     if not model_args.tie_word_embeddings:
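
The new error and warning messages both point users at the same workaround: create the processor yourself and pass it in. A sketch of that resolution path, under the assumption that `oneshot` accepts `processor` and `output_dir` keyword arguments (the model id is hypothetical):

from transformers import AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"  # hypothetical model id

# Build the processor (here a tokenizer) explicitly instead of relying on
# pre_process auto-initializing it from the model path.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

oneshot(
    model=MODEL_ID,
    processor=tokenizer,  # explicit processor avoids the failing auto-init
    recipe=QuantizationModifier(targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"]),
    output_dir="./model-FP8-dynamic",  # processor is saved alongside the compressed model
)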
