Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
26bcbb0
oneshot refac
Jan 28, 2025
8c0a255
comments
Feb 4, 2025
468a714
Merge branch 'main' into oneshot-refac-main
Feb 4, 2025
428a2c7
merge main
Feb 11, 2025
58b3d6a
stashed changes
Feb 11, 2025
380d164
run examples pass
Feb 11, 2025
0ae9ade
pass tests
Feb 11, 2025
9be9174
Merge branch 'main' into oneshot-refac-main
Feb 11, 2025
95bfaa1
add entrypoints
Feb 11, 2025
c4dd9cc
Merge branch 'oneshot-refac-main' of github.com:vllm-project/llm-comp…
Feb 11, 2025
d2ffc4a
update read me on /finetune
Feb 11, 2025
fc4c42f
pass tests
Feb 11, 2025
09942b7
pass tests
Feb 11, 2025
00dd629
add readme and remove breakpoint
Feb 12, 2025
723b6b5
update read me
Feb 12, 2025
d140f11
comments
Feb 12, 2025
e208a69
GPUS to GPUs
Feb 12, 2025
d1855cf
add deprecation warning
Feb 12, 2025
af8515d
update readme
Feb 13, 2025
3f7c3ac
update sparse24 in readme
Feb 13, 2025
a8517b0
-s
Feb 13, 2025
0dd5987
update bf16
Feb 13, 2025
1d97841
comments
Feb 13, 2025
b5112cd
fix bug on processor
Feb 13, 2025
1929938
Merge branch 'main' into oneshot-refac-main
Feb 13, 2025
7402143
fix preprocess logic
Feb 14, 2025
0f86cf3
Merge branch 'oneshot-refac-main' of github.com:vllm-project/llm-comp…
Feb 14, 2025
58a7a5a
fix self attr population
Feb 14, 2025
ca7da03
fix test, get torch model not stub
Feb 14, 2025
103fd71
use non-gated model
Feb 14, 2025
7cf5f1a
Merge branch 'main' into oneshot-refac-main
Feb 14, 2025
333dc42
lint
Feb 17, 2025
e7e838f
Merge branch 'main' into oneshot-refac-main
Feb 17, 2025
5a1dccf
update stages
Feb 19, 2025
fb2af8d
Merge branch 'main' into oneshot-refac-main
Feb 19, 2025
ca9f295
revert output_dir name
Feb 19, 2025
a9e8597
Merge branch 'oneshot-refac-main' of github.com:vllm-project/llm-comp…
Feb 19, 2025
d2e6274
Merge branch 'main' into oneshot-refac-main
Feb 19, 2025
5f1d383
Merge branch 'main' into oneshot-refac-main
Feb 20, 2025
eb3094c
comments
Feb 20, 2025
b3a09a4
remove if condition
Feb 24, 2025
fe6b797
Merge branch 'main' into oneshot-refac-main
Feb 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 0 additions & 13 deletions src/llmcompressor/core/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,19 +200,6 @@ def finalize(self, **kwargs) -> ModifiedState:
modifier_data=mod_data,
)

def apply(self, **kwargs):
    """
    Run the recipe in a single one-shot pass: each modifier in the
    session's lifecycle is initialized and then immediately finalized,
    leaving the session in the finalized state.

    :param kwargs: additional kwargs forwarded to both the lifecycle's
        initialize and finalize methods
    :return: the result of finalizing the session
    """
    self.initialize(**kwargs)
    finalized_state = self.finalize(**kwargs)
    return finalized_state

def event(
self,
event_type: EventType,
Expand Down
57 changes: 0 additions & 57 deletions src/llmcompressor/core/session_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"pre_initialize_structure",
"initialize",
"finalize",
"apply",
"callbacks",
"LifecycleCallbacks",
]
Expand Down Expand Up @@ -143,62 +142,6 @@ def finalize(**kwargs) -> ModifiedState:
return active_session().finalize(**kwargs)


def apply(
    recipe: Union[str, List[str], "Recipe", List["Recipe"], None] = None,
    recipe_stage: Union[str, List[str], None] = None,
    recipe_args: Optional[Dict[str, Any]] = None,
    model: Optional[Any] = None,
    teacher_model: Optional[Any] = None,
    train_data: Optional[Any] = None,
    val_data: Optional[Any] = None,
    test_data: Optional[Any] = None,
    calib_data: Optional[Any] = None,
    copy_data: bool = True,
    start: Optional[float] = None,
    steps_per_epoch: Optional[int] = None,
    batches_per_step: Optional[int] = None,
    **kwargs,
) -> ModifiedState:
    """
    Apply the recipe to the active session in a single one-shot pass,
    invoking initialize and then finalize for every modifier in the
    session's lifecycle.

    :param recipe: recipe to use for sparsification — a path to a recipe
        file, a raw recipe string, a recipe object, or a list of recipe
        objects
    :param recipe_stage: stage of the recipe to target
    :param recipe_args: overrides for the recipe's default arguments
    :param model: the model to sparsify
    :param teacher_model: teacher model for knowledge distillation
    :param train_data: training data for the sparsification
    :param val_data: validation data for the sparsification
    :param test_data: testing data for the sparsification
    :param calib_data: calibration data for the sparsification
    :param copy_data: True to copy the data, False otherwise
    :param start: start epoch for the sparsification
    :param steps_per_epoch: number of steps per epoch for the sparsification
    :param batches_per_step: number of batches per step to use
    :param kwargs: additional kwargs passed through to the session's apply
    :return: the modified state of the active session after applying
    """
    session_kwargs = dict(
        recipe=recipe,
        recipe_stage=recipe_stage,
        recipe_args=recipe_args,
        model=model,
        teacher_model=teacher_model,
        train_data=train_data,
        val_data=val_data,
        test_data=test_data,
        calib_data=calib_data,
        copy_data=copy_data,
        start=start,
        steps_per_epoch=steps_per_epoch,
        batches_per_step=batches_per_step,
    )
    return active_session().apply(**session_kwargs, **kwargs)


class LifecycleCallbacks:
"""
A class for invoking lifecycle events for the active session
Expand Down
3 changes: 3 additions & 0 deletions src/llmcompressor/transformers/calibration/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# flake8: noqa

from .oneshot import Oneshot
263 changes: 263 additions & 0 deletions src/llmcompressor/transformers/calibration/oneshot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
from pathlib import PosixPath
from typing import Optional

from loguru import logger
from torch.utils.data import DataLoader

from llmcompressor.core.session_functions import active_session
from llmcompressor.transformers.finetune.data.data_helpers import (
get_calibration_dataloader,
)
from llmcompressor.transformers.finetune.text_generation import (
initialize_model_from_path,
initialize_processor_from_path,
parse_args,
)
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
modify_save_pretrained,
patch_tied_tensors_bug,
)
from llmcompressor.transformers.utils.arg_parser import DEFAULT_OUTPUT_DIR

__all__ = ["Oneshot"]


class Oneshot:
    """
    Carries out one-shot calibration on a pretrained model.

    The calibration lifecycle has three phases:

    1. **Preprocessing** — instantiate the pretrained model and its
       tokenizer/processor, untie input/output embedding layers that share
       tensors, and patch the model so it can be saved with quantization
       configurations.
    2. **Oneshot calibration** — optimize the model through the global
       ``CompressionSession``, applying the modifiers defined in the recipe
       (e.g. ``GPTQModifier``, ``SparseGPTModifier``).
    3. **Postprocessing** — save the model, tokenizer/processor, and
       configuration to the requested ``output_dir``.

    Input keyword arguments are parsed (parsers live in
    ``src/llmcompressor/transformers/utils/arg_parser``) into:

    - ``model_args``: loading/configuration of a pretrained model
      (e.g. ``AutoModelForCausalLM``)
    - ``data_args``: dataset-related options such as calibration dataloaders
    - ``recipe_args``: definition/configuration of the recipe describing the
      optimization actions

    The instructions for model optimization can be given as a recipe
    (fine-grained control) or as a scheme (e.g. W4A16, W8A8, W4A8).

    Usage::

        oneshot = Oneshot(model=model, recipe=recipe, dataset=dataset)
        oneshot.run()

        # Access the processed components
        model = oneshot.model
        tokenizer_or_processor = oneshot.tokenizer_or_processor
        recipe = oneshot.recipe
    """

    # Lifecycle hooks invoked, in order, on the active CompressionSession
    MODIFIER_LIFECYCLE_ACTIONS = (
        "initialize",
        "finalize",
    )

    def __init__(self, **kwargs):
        """
        Parse ``kwargs`` into model/data/recipe argument groups, run
        preprocessing (model and tokenizer/processor setup), and expose the
        resulting components as instance attributes.

        :param kwargs: arbitrary keyword arguments for model, data, and
            recipe configuration
        """
        self.model_args, self.data_args, self.recipe_args, _, self.output_dir = (
            parse_args(**kwargs)
        )

        # Load/patch the model and tokenizer/processor before anything else
        self._pre_process()

        # Convenience handles populated during preprocessing
        self.model = self.model_args.model
        self.tokenizer_or_processor = self.model_args.processor
        self.recipe = self.recipe_args.recipe

    def run(self, **kwargs):
        """
        Execute one-shot calibration: build a calibration dataloader from
        the dataset arguments, apply the recipe modifiers to the model, then
        run postprocessing (which may save the model).

        :param kwargs: additional keyword arguments forwarded to the recipe
            modifiers
        """
        loader = get_calibration_dataloader(
            self.data_args, self.tokenizer_or_processor
        )
        self._apply_recipe_modifiers(calibration_dataloader=loader, **kwargs)
        self._post_process()

    def save(self):
        """
        Save the model and tokenizer/processor to ``self.output_dir``.

        The model is written in a compressed format when ``model_args``
        request it; the tokenizer/processor is saved only when present.

        :raises ValueError: if saving fails, e.g. due to an invalid
            ``output_dir``
        """
        self.model.save_pretrained(
            self.output_dir,
            save_compressed=self.model_args.save_compressed,
        )
        # Truthiness check intentionally matches "a processor was provided"
        if self.tokenizer_or_processor:
            self.tokenizer_or_processor.save_pretrained(self.output_dir)

    def _apply_recipe_modifiers(
        self, calibration_dataloader: Optional[DataLoader], **kwargs
    ):
        """
        Drive the recipe-defined modifiers through their lifecycle.

        Each action in ``MODIFIER_LIFECYCLE_ACTIONS`` (``initialize`` then
        ``finalize``) is looked up on the global ``CompressionSession`` and
        invoked with the model, recipe, and calibration data.

        :param calibration_dataloader: dataloader supplying calibration
            data, or ``None``
        :param kwargs: additional arguments forwarded to each lifecycle
            action
        :raises RuntimeError: if a modifier fails during execution
        """
        for lifecycle_action in self.MODIFIER_LIFECYCLE_ACTIONS:
            run_action = getattr(active_session(), lifecycle_action)
            run_action(
                model=self.model,
                recipe=self.recipe,
                recipe_args=self.recipe_args.recipe_args,
                calib_data=calibration_dataloader,
                start=-1,  # oneshot-specific argument
                copy_data=False,
                # attribute only exists when data_args were provided
                min_tokens_per_module=getattr(self, "min_tokens_per_module", None),
                **kwargs,
            )

    def _pre_process(self):
        """
        Prepare the model and tokenizer/processor for calibration.

        - Loads the model when it is given as a path or string.
        - Applies the tied-tensor patch and the ``save_pretrained``
          modification.
        - Loads the processor when it is given as a path or left as
          ``None``.
        - Records ``min_tokens_per_module`` when data arguments exist.

        :raises FileNotFoundError: if the model or processor path is invalid
        """
        self._warn_tied_embeddings()

        # Resolve a model stub/path into an actual torch model
        if isinstance(self.model_args.model, (str, PosixPath)):
            self.model_args.model, _ = initialize_model_from_path(self.model_args)

        # Untie shared embedding tensors and enable compressed saving
        patch_tied_tensors_bug(self.model_args.model)
        modify_save_pretrained(self.model_args.model)

        # Resolve the processor when it is a path or was not supplied
        if isinstance(self.model_args.processor, (str, type(None))):
            self.model_args.processor = initialize_processor_from_path(
                self.model_args, self.model_args.model
            )

        if self.data_args:
            self.min_tokens_per_module = self.data_args.min_tokens_per_module

    def _warn_tied_embeddings(self):
        """
        Emit a debug note when the model has tied word embeddings.

        ``tie_word_embeddings=True`` can cause issues when saving in the
        one-shot workflow because of shared tensor addresses.
        """
        if self.model_args.tie_word_embeddings:
            logger.debug(
                "The tie_word_embeddings flag is by default set to False. "
                "This guarantees that the one-shot algorithm saves the final "
                "weights without errors. Detected tie_word_embeddings=True. "
                "This may cause issues with the one-shot algorithm on save."
            )

    def _post_process(self):
        """
        Execute post-calibration steps: save the model when it was supplied
        as a stub/path or when a non-default ``output_dir`` was requested.

        :raises ValueError: if saving fails due to invalid configurations

        NOTE(review): after ``_pre_process`` the ``model_args.model`` stub
        has already been replaced by a loaded model, so the ``isinstance``
        check below presumably only triggers when preprocessing was skipped
        — confirm against callers.
        """
        model_is_stub = isinstance(self.model_args.model, str)
        custom_output_dir = self.output_dir != DEFAULT_OUTPUT_DIR
        if model_is_stub or custom_output_dir:
            self.save()
Loading