Skip to content

Commit 9d82f35

Browse files
George authored
[Training] Unifying Preprocess + Postprocessing logic for Train/Oneshot (#1212)
Order of reviews: #1206, #1207, #1209, #1212 (this PR), #1214

SUMMARY:
* Move the preprocessing and postprocessing logic out of `src/llmcompressor/transformers/finetune/text_generation.py` and into `src/llmcompressor/entrypoints/utils.py`

TEST PLAN:
* Pass tests
1 parent 14ac2e7 commit 9d82f35

File tree

4 files changed

+294
-290
lines changed

4 files changed

+294
-290
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
# flake8: noqa
22
from .oneshot import Oneshot, oneshot
3+
from .utils import post_process, pre_process

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 5 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,12 @@
1-
from pathlib import PosixPath
21
from typing import Optional
32

4-
from loguru import logger
53
from torch.utils.data import DataLoader
64
from transformers import PreTrainedModel
75

86
from llmcompressor.args import parse_args
97
from llmcompressor.core.session_functions import active_session
108
from llmcompressor.datasets import get_calibration_dataloader
11-
from llmcompressor.transformers.finetune.text_generation import (
12-
initialize_model_from_path,
13-
initialize_processor_from_path,
14-
)
15-
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
16-
modify_save_pretrained,
17-
patch_tied_tensors_bug,
18-
)
9+
from llmcompressor.entrypoints.utils import post_process, pre_process
1910

2011
__all__ = ["Oneshot", "oneshot"]
2112

@@ -71,7 +62,7 @@ class Oneshot:
7162
Initializes the `Oneshot` object by parsing input arguments, performing
7263
preprocessing, and setting instance attributes.
7364
74-
run(**kwargs):
65+
__call__(**kwargs):
7566
Performs the one-shot calibration process by preparing a calibration
7667
dataloader, applying recipe modifiers to the model, and executing
7768
postprocessing steps.
@@ -86,17 +77,6 @@ class Oneshot:
8677
defined in the recipe. Each action is executed via the global
8778
`CompressionSession`.
8879
89-
_pre_process():
90-
Handles preprocessing steps, including model initialization,
91-
tokenizer/processor setup, and resolving tied embedding issues.
92-
93-
check_tied_embeddings():
94-
Logs a warning if `tie_word_embeddings=True`, which may interfere with
95-
saving in the one-shot workflow.
96-
97-
_post_process():
98-
Executes postprocessing steps such as saving the model and resetting
99-
lifecycle actions, especially when a custom `output_dir` is specified.
10080
"""
10181

10282
def __init__(
@@ -151,7 +131,7 @@ def from_args(
151131

152132
# only run for the first oneshot call
153133
if do_preprocess:
154-
instance._pre_process()
134+
pre_process(model_args)
155135

156136
# Set instance attributes
157137
instance.model = instance.model_args.model
@@ -172,7 +152,7 @@ def __call__(self):
172152
"""
173153
# TODO: move back once stage runner is removed
174154
# Preprocess the model and tokenizer/processor
175-
self._pre_process()
155+
pre_process(self.model_args)
176156
self.model = self.model_args.model
177157
self.recipe = self.recipe_args.recipe
178158
self.processor = self.model_args.processor
@@ -183,24 +163,7 @@ def __call__(self):
183163
self.apply_recipe_modifiers(
184164
calibration_dataloader=calibration_dataloader,
185165
)
186-
self._post_process()
187-
188-
def save(self):
189-
"""
190-
Saves the model and tokenizer/processor to the output directory.
191-
192-
The model is saved in a compressed format if specified in `model_args`.
193-
The tokenizer or processor, if available, is also saved.
194-
195-
Raises:
196-
ValueError: If saving fails due to an invalid `output_dir` or other issues.
197-
"""
198-
self.model.save_pretrained(
199-
self.output_dir,
200-
save_compressed=self.model_args.save_compressed,
201-
)
202-
if self.processor is not None:
203-
self.processor.save_pretrained(self.output_dir)
166+
post_process(model_args=self.model_args, output_dir=self.output_dir)
204167

205168
def apply_recipe_modifiers(
206169
self,
@@ -236,75 +199,6 @@ def apply_recipe_modifiers(
236199
session.initialize(**session_kwargs)
237200
session.finalize(**session_kwargs)
238201

239-
def _pre_process(self):
240-
"""
241-
Prepares the model and tokenizer/processor for calibration.
242-
243-
- Initializes the model if it's specified as a path or string.
244-
- Applies patches to fix tied tensor issues and modifies `save_pretrained`
245-
behavior.
246-
- Initializes the processor if specified as a path or `None`.
247-
- Sets the minimum tokens per module if `dataset_args` are provided.
248-
249-
Raises:
250-
FileNotFoundError: If the model or processor path is invalid.
251-
"""
252-
self.check_tied_embeddings()
253-
254-
# Initialize model
255-
if isinstance(self.model_args.model, (str, PosixPath)):
256-
self.model_args.model, _ = initialize_model_from_path(self.model_args)
257-
258-
patch_tied_tensors_bug(self.model_args.model)
259-
modify_save_pretrained(self.model_args.model)
260-
261-
# Initialize processor
262-
if isinstance(self.model_args.processor, (str, type(None))):
263-
self.model_args.processor = initialize_processor_from_path(
264-
self.model_args, self.model_args.model
265-
)
266-
# TODO: move to init once stage runner is removed
267-
self.processor = self.model_args.processor
268-
269-
# Set minimum tokens per module if data arguments are provided
270-
if self.dataset_args:
271-
self.min_tokens_per_module = self.dataset_args.min_tokens_per_module
272-
273-
def check_tied_embeddings(self):
274-
"""
275-
Logs a warning if the model has tied word embeddings.
276-
277-
The `tie_word_embeddings` flag may cause issues during saving in the one-shot
278-
calibration workflow due to shared tensor addresses.
279-
"""
280-
if self.model_args.tie_word_embeddings:
281-
logger.debug(
282-
"The tie_word_embeddings flag is by default set to False. "
283-
"This guarantees that the one-shot algorithm saves the final "
284-
"weights without errors. Detected tie_word_embeddings=True. "
285-
"This may cause issues with the one-shot algorithm on save."
286-
)
287-
288-
def _post_process(self):
289-
"""
290-
Executes post-calibration steps.
291-
292-
This method saves the model and resets lifecycle actions if the `output_dir`
293-
is not the default directory.
294-
295-
Raises:
296-
ValueError: If saving fails due to invalid configurations.
297-
"""
298-
if self.output_dir is not None:
299-
self.save()
300-
return
301-
302-
logger.warning(
303-
"Optimized model not saved. To save, please provide",
304-
"`output_dir` as input arg.",
305-
"Ex. `oneshot(..., output_dir=...)`",
306-
)
307-
308202

309203
def oneshot(**kwargs) -> PreTrainedModel:
310204
one_shot = Oneshot(**kwargs)

0 commit comments

Comments
 (0)