Skip to content

Commit 9d1bfb3

Browse files
author
George Ohashi
committed
fix
1 parent ad2733f commit 9d1bfb3

File tree

2 files changed

+6
-40
lines changed

2 files changed

+6
-40
lines changed

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 3 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,12 @@
1+
from typing import Optional
12

2-
from typing import Optional, Tuple
3-
4-
from loguru import logger
53
from torch.utils.data import DataLoader
64
from transformers import PreTrainedModel
75

86
from llmcompressor.args import parse_args
97
from llmcompressor.core.session_functions import active_session
8+
from llmcompressor.datasets import get_calibration_dataloader
109
from llmcompressor.entrypoints.utils import post_process, pre_process
11-
from llmcompressor.transformers.finetune.data.data_helpers import (
12-
get_calibration_dataloader,
13-
)
14-
from llmcompressor.transformers.utils.helpers import resolve_processor_from_model_args
15-
1610

1711
__all__ = ["Oneshot", "oneshot"]
1812

@@ -68,7 +62,7 @@ class Oneshot:
6862
Initializes the `Oneshot` object by parsing input arguments, performing
6963
preprocessing, and setting instance attributes.
7064
71-
run(**kwargs):
65+
__call__(**kwargs):
7266
Performs the one-shot calibration process by preparing a calibration
7367
dataloader, applying recipe modifiers to the model, and executing
7468
postprocessing steps.
@@ -83,17 +77,6 @@ class Oneshot:
8377
defined in the recipe. Each action is executed via the global
8478
`CompressionSession`.
8579
86-
_pre_process():
87-
Handles preprocessing steps, including model initialization,
88-
tokenizer/processor setup, and resolving tied embedding issues.
89-
90-
check_tied_embeddings():
91-
Logs a warning if `tie_word_embeddings=True`, which may interfere with
92-
saving in the one-shot workflow.
93-
94-
_post_process():
95-
Executes postprocessing steps such as saving the model and resetting
96-
lifecycle actions, especially when a custom `output_dir` is specified.
9780
"""
9881

9982
def __init__(
@@ -182,23 +165,6 @@ def __call__(self):
182165
)
183166
post_process(model_args=self.model_args, output_dir=self.output_dir)
184167

185-
def save(self):
186-
"""
187-
Saves the model and tokenizer/processor to the output directory.
188-
189-
The model is saved in a compressed format if specified in `model_args`.
190-
The tokenizer or processor, if available, is also saved.
191-
192-
Raises:
193-
ValueError: If saving fails due to an invalid `output_dir` or other issues.
194-
"""
195-
self.model.save_pretrained(
196-
self.output_dir,
197-
save_compressed=self.model_args.save_compressed,
198-
)
199-
if self.processor is not None:
200-
self.processor.save_pretrained(self.output_dir)
201-
202168
def apply_recipe_modifiers(
203169
self,
204170
calibration_dataloader: Optional[DataLoader],

src/llmcompressor/transformers/finetune/text_generation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,11 @@ def train(**kwargs):
4646
"""
4747
CLI entrypoint for running training
4848
"""
49-
model_args, dataset_args, recipe_args, training_args, _ = parse_args(**kwargs)
49+
model_args, dataset_args, recipe_args, training_args = parse_args(**kwargs)
5050
training_args.do_train = True
5151
main(model_args, dataset_args, recipe_args, training_args)
5252

5353

54-
5554
@deprecated(
5655
message=(
5756
"`from llmcompressor.transformers import oneshot` is deprecated, "
@@ -69,10 +68,11 @@ def apply(**kwargs):
6968
CLI entrypoint for any of training, oneshot
7069
"""
7170
from llmcompressor.args import parse_args
71+
7272
model_args, dataset_args, recipe_args, training_args, _ = parse_args(
7373
include_training_args=True, **kwargs
7474
)
75-
75+
7676
training_args.run_stages = True
7777
report_to = kwargs.get("report_to", None)
7878
if report_to is None: # user didn't specify any reporters

0 commit comments

Comments
 (0)