1- from pathlib import PosixPath
21from typing import Optional
32
4- from loguru import logger
53from torch .utils .data import DataLoader
64from transformers import PreTrainedModel
75
86from llmcompressor .args import parse_args
97from llmcompressor .core .session_functions import active_session
108from llmcompressor .datasets import get_calibration_dataloader
11- from llmcompressor .transformers .finetune .text_generation import (
12- initialize_model_from_path ,
13- initialize_processor_from_path ,
14- )
15- from llmcompressor .transformers .sparsification .compressed_tensors_utils import (
16- modify_save_pretrained ,
17- patch_tied_tensors_bug ,
18- )
9+ from llmcompressor .entrypoints .utils import post_process , pre_process
1910
2011__all__ = ["Oneshot" , "oneshot" ]
2112
@@ -71,7 +62,7 @@ class Oneshot:
7162 Initializes the `Oneshot` object by parsing input arguments, performing
7263 preprocessing, and setting instance attributes.
7364
74- run (**kwargs):
65+ __call__ (**kwargs):
7566 Performs the one-shot calibration process by preparing a calibration
7667 dataloader, applying recipe modifiers to the model, and executing
7768 postprocessing steps.
@@ -86,17 +77,6 @@ class Oneshot:
8677 defined in the recipe. Each action is executed via the global
8778 `CompressionSession`.
8879
89- _pre_process():
90- Handles preprocessing steps, including model initialization,
91- tokenizer/processor setup, and resolving tied embedding issues.
92-
93- check_tied_embeddings():
94- Logs a warning if `tie_word_embeddings=True`, which may interfere with
95- saving in the one-shot workflow.
96-
97- _post_process():
98- Executes postprocessing steps such as saving the model and resetting
99- lifecycle actions, especially when a custom `output_dir` is specified.
10080 """
10181
10282 def __init__ (
@@ -151,7 +131,7 @@ def from_args(
151131
152132 # only run for the first oneshot call
153133 if do_preprocess :
154- instance . _pre_process ( )
134+ pre_process ( model_args )
155135
156136 # Set instance attributes
157137 instance .model = instance .model_args .model
@@ -172,7 +152,7 @@ def __call__(self):
172152 """
173153 # TODO: move back once stage runner is removed
174154 # Preprocess the model and tokenizer/processor
175- self ._pre_process ( )
155+ pre_process ( self .model_args )
176156 self .model = self .model_args .model
177157 self .recipe = self .recipe_args .recipe
178158 self .processor = self .model_args .processor
@@ -183,24 +163,7 @@ def __call__(self):
183163 self .apply_recipe_modifiers (
184164 calibration_dataloader = calibration_dataloader ,
185165 )
186- self ._post_process ()
187-
188- def save (self ):
189- """
190- Saves the model and tokenizer/processor to the output directory.
191-
192- The model is saved in a compressed format if specified in `model_args`.
193- The tokenizer or processor, if available, is also saved.
194-
195- Raises:
196- ValueError: If saving fails due to an invalid `output_dir` or other issues.
197- """
198- self .model .save_pretrained (
199- self .output_dir ,
200- save_compressed = self .model_args .save_compressed ,
201- )
202- if self .processor is not None :
203- self .processor .save_pretrained (self .output_dir )
166+ post_process (model_args = self .model_args , output_dir = self .output_dir )
204167
205168 def apply_recipe_modifiers (
206169 self ,
@@ -236,75 +199,6 @@ def apply_recipe_modifiers(
236199 session .initialize (** session_kwargs )
237200 session .finalize (** session_kwargs )
238201
239- def _pre_process (self ):
240- """
241- Prepares the model and tokenizer/processor for calibration.
242-
243- - Initializes the model if it's specified as a path or string.
244- - Applies patches to fix tied tensor issues and modifies `save_pretrained`
245- behavior.
246- - Initializes the processor if specified as a path or `None`.
247- - Sets the minimum tokens per module if `dataset_args` are provided.
248-
249- Raises:
250- FileNotFoundError: If the model or processor path is invalid.
251- """
252- self .check_tied_embeddings ()
253-
254- # Initialize model
255- if isinstance (self .model_args .model , (str , PosixPath )):
256- self .model_args .model , _ = initialize_model_from_path (self .model_args )
257-
258- patch_tied_tensors_bug (self .model_args .model )
259- modify_save_pretrained (self .model_args .model )
260-
261- # Initialize processor
262- if isinstance (self .model_args .processor , (str , type (None ))):
263- self .model_args .processor = initialize_processor_from_path (
264- self .model_args , self .model_args .model
265- )
266- # TODO: move to init once stage runner is removed
267- self .processor = self .model_args .processor
268-
269- # Set minimum tokens per module if data arguments are provided
270- if self .dataset_args :
271- self .min_tokens_per_module = self .dataset_args .min_tokens_per_module
272-
273- def check_tied_embeddings (self ):
274- """
275- Logs a warning if the model has tied word embeddings.
276-
277- The `tie_word_embeddings` flag may cause issues during saving in the one-shot
278- calibration workflow due to shared tensor addresses.
279- """
280- if self .model_args .tie_word_embeddings :
281- logger .debug (
282- "The tie_word_embeddings flag is by default set to False. "
283- "This guarantees that the one-shot algorithm saves the final "
284- "weights without errors. Detected tie_word_embeddings=True. "
285- "This may cause issues with the one-shot algorithm on save."
286- )
287-
288- def _post_process (self ):
289- """
290- Executes post-calibration steps.
291-
292- This method saves the model and resets lifecycle actions if the `output_dir`
293- is not the default directory.
294-
295- Raises:
296- ValueError: If saving fails due to invalid configurations.
297- """
298- if self .output_dir is not None :
299- self .save ()
300- return
301-
302- logger .warning (
303- "Optimized model not saved. To save, please provide" ,
304- "`output_dir` as input arg." ,
305- "Ex. `oneshot(..., output_dir=...)`" ,
306- )
307-
308202
309203def oneshot (** kwargs ) -> PreTrainedModel :
310204 one_shot = Oneshot (** kwargs )
0 commit comments