- from pathlib import PosixPath
from typing import Optional

- from loguru import logger
from torch.utils.data import DataLoader
from transformers import PreTrainedModel

from llmcompressor.args import parse_args
from llmcompressor.core.session_functions import active_session
from llmcompressor.datasets import get_calibration_dataloader
- from llmcompressor.transformers.finetune.text_generation import (
- initialize_model_from_path,
- initialize_processor_from_path,
- )
- from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
- modify_save_pretrained,
- patch_tied_tensors_bug,
- )
+ from llmcompressor.entrypoints.utils import post_process, pre_process

__all__ = ["Oneshot", "oneshot"]

@@ -71,7 +62,7 @@ class Oneshot:
Initializes the `Oneshot` object by parsing input arguments, performing
preprocessing, and setting instance attributes.

- run(**kwargs):
+ __call__(**kwargs):
Performs the one-shot calibration process by preparing a calibration
dataloader, applying recipe modifiers to the model, and executing
postprocessing steps.
@@ -86,17 +77,6 @@ class Oneshot:
defined in the recipe. Each action is executed via the global
`CompressionSession`.

- _pre_process():
- Handles preprocessing steps, including model initialization,
- tokenizer/processor setup, and resolving tied embedding issues.
-
- check_tied_embeddings():
- Logs a warning if `tie_word_embeddings=True`, which may interfere with
- saving in the one-shot workflow.
-
- _post_process():
- Executes postprocessing steps such as saving the model and resetting
- lifecycle actions, especially when a custom `output_dir` is specified.
"""

def __init__(
@@ -151,7 +131,7 @@ def from_args(

# only run for the first oneshot call
if do_preprocess:
- instance._pre_process()
+ pre_process(model_args)

# Set instance attributes
instance.model = instance.model_args.model
@@ -172,7 +152,7 @@ def __call__(self):
"""
# TODO: move back once stage runner is removed
# Preprocess the model and tokenizer/processor
- self._pre_process()
+ pre_process(self.model_args)
self.model = self.model_args.model
self.recipe = self.recipe_args.recipe
self.processor = self.model_args.processor
@@ -183,24 +163,7 @@ def __call__(self):
self.apply_recipe_modifiers(
calibration_dataloader=calibration_dataloader,
)
- self._post_process()
-
- def save(self):
- """
- Saves the model and tokenizer/processor to the output directory.
-
- The model is saved in a compressed format if specified in `model_args`.
- The tokenizer or processor, if available, is also saved.
-
- Raises:
- ValueError: If saving fails due to an invalid `output_dir` or other issues.
- """
- self.model.save_pretrained(
- self.output_dir,
- save_compressed=self.model_args.save_compressed,
- )
- if self.processor is not None:
- self.processor.save_pretrained(self.output_dir)
+ post_process(model_args=self.model_args, output_dir=self.output_dir)

def apply_recipe_modifiers(
self,
@@ -236,75 +199,6 @@ def apply_recipe_modifiers(
session.initialize(**session_kwargs)
session.finalize(**session_kwargs)

- def _pre_process(self):
- """
- Prepares the model and tokenizer/processor for calibration.
-
- - Initializes the model if it's specified as a path or string.
- - Applies patches to fix tied tensor issues and modifies `save_pretrained`
- behavior.
- - Initializes the processor if specified as a path or `None`.
- - Sets the minimum tokens per module if `dataset_args` are provided.
-
- Raises:
- FileNotFoundError: If the model or processor path is invalid.
- """
- self.check_tied_embeddings()
-
- # Initialize model
- if isinstance(self.model_args.model, (str, PosixPath)):
- self.model_args.model, _ = initialize_model_from_path(self.model_args)
-
- patch_tied_tensors_bug(self.model_args.model)
- modify_save_pretrained(self.model_args.model)
-
- # Initialize processor
- if isinstance(self.model_args.processor, (str, type(None))):
- self.model_args.processor = initialize_processor_from_path(
- self.model_args, self.model_args.model
- )
- # TODO: move to init once stage runner is removed
- self.processor = self.model_args.processor
-
- # Set minimum tokens per module if data arguments are provided
- if self.dataset_args:
- self.min_tokens_per_module = self.dataset_args.min_tokens_per_module
-
- def check_tied_embeddings(self):
- """
- Logs a warning if the model has tied word embeddings.
-
- The `tie_word_embeddings` flag may cause issues during saving in the one-shot
- calibration workflow due to shared tensor addresses.
- """
- if self.model_args.tie_word_embeddings:
- logger.debug(
- "The tie_word_embeddings flag is by default set to False. "
- "This guarantees that the one-shot algorithm saves the final "
- "weights without errors. Detected tie_word_embeddings=True. "
- "This may cause issues with the one-shot algorithm on save."
- )
-
- def _post_process(self):
- """
- Executes post-calibration steps.
-
- This method saves the model and resets lifecycle actions if the `output_dir`
- is not the default directory.
-
- Raises:
- ValueError: If saving fails due to invalid configurations.
- """
- if self.output_dir is not None:
- self.save()
- return
-
- logger.warning(
- "Optimized model not saved. To save, please provide",
- "`output_dir` as input arg.",
- "Ex. `oneshot(..., output_dir=...)`",
- )
-

def oneshot(**kwargs) -> PreTrainedModel:
one_shot = Oneshot(**kwargs)
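For context, the refactored entrypoint is still driven through the `oneshot` helper shown above, with preprocessing and saving now delegated to `pre_process`/`post_process`. The sketch below is a minimal illustration of that call pattern, not a definitive API reference; the model id, dataset name, recipe path, and the `num_calibration_samples` keyword are assumed example values based on the parsed argument groups.

# Minimal usage sketch (assumed keyword names and example values).
from llmcompressor import oneshot

model = oneshot(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # model id or local path (assumed example)
    dataset="open_platypus",                     # calibration dataset (assumed example)
    recipe="recipe.yaml",                        # quantization/sparsity recipe (assumed path)
    output_dir="./oneshot-output",               # non-None output_dir triggers saving in post_process
    num_calibration_samples=512,                 # assumed calibration-size argument
)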