Merged
Commits (39)
5cdae11  Dataclass Arg refactor -- recipe_args (Jan 27, 2025)
a7cf946  refactor dataclass args (Jan 27, 2025)
6859adc  apply changes only for dataclass args - recipe, model, dataset, training (Jan 28, 2025)
7d19625  Merge branch 'main' into oneshot-refac-recipe_args (Jan 28, 2025)
6a1e4b0  fix tests (Jan 28, 2025)
44c67d7  Merge branch 'oneshot-refac-recipe_args' of github.com:vllm-project/l… (Jan 28, 2025)
7bb2e9a  pass cli tests (Jan 28, 2025)
0adb755  Merge branch 'main' into oneshot-refac-recipe_args (Jan 29, 2025)
12a2167  remove redundant code (Jan 29, 2025)
cacfbed  Merge branch 'main' into oneshot-refac-recipe_args (Jan 31, 2025)
77205c0  comments (Jan 31, 2025)
377a10b  Merge branch 'oneshot-refac-recipe_args' of github.com:vllm-project/l… (Jan 31, 2025)
84ddf07  add type annotations to private func (Jan 31, 2025)
3770e6c  fix tests (Feb 3, 2025)
1f3110b  Merge branch 'main' into oneshot-refac-recipe_args (Feb 4, 2025)
dbf7f8c  move to util (Feb 4, 2025)
656824e  fix tests (Feb 5, 2025)
48e382f  remove redudant code (Feb 5, 2025)
be31960  examples TrainingArguments movement (Feb 5, 2025)
1253435  Merge branch 'main' into oneshot-refac-recipe_args (Feb 5, 2025)
f4491be  revert to only refactor wrt to dataclass (Feb 6, 2025)
13ee157  remove unnec code (Feb 6, 2025)
48f531f  Merge branch 'main' into oneshot-refac-recipe_args (Feb 6, 2025)
ce42137  Merge branch 'main' into oneshot-refac-recipe_args (Feb 6, 2025)
2fb8212  change optional to required in session_mixin (Feb 6, 2025)
ef8fae0  Merge branch 'oneshot-refac-recipe_args' of github.com:vllm-project/l… (Feb 6, 2025)
f0fc214  fix (Feb 6, 2025)
6e4c8cc  fix test (Feb 7, 2025)
bcbfd35  comments (Feb 10, 2025)
049ddec  add (Feb 10, 2025)
e671817  Merge branch 'main' into oneshot-refac-recipe_args (Feb 10, 2025)
68dd3b4  consistency (Feb 10, 2025)
63862c3  Update src/llmcompressor/arg_parser/README.md (Feb 10, 2025)
afb8efa  comment (Feb 10, 2025)
d50baba  comments (Feb 10, 2025)
9e26589  rename data_arguments to dataset_arguments (Feb 10, 2025)
7f49448  comments (Feb 10, 2025)
a55a427  change directory name (Feb 10, 2025)
319d1bd  fix (Feb 11, 2025)
8 changes: 4 additions & 4 deletions examples/trl_mixin/ex_trl_distillation.py
@@ -1,9 +1,9 @@
 from sft_trainer import SFTTrainer
 from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator
 
-from llmcompressor.transformers import (
-    DataTrainingArguments,
-    TextGenerationDataset,
+from llmcompressor.transformers import TextGenerationDataset
+from llmcompressor.transformers.utils.arg_parser import (
+    DatasetArguments,
     TrainingArguments,
 )
 
@@ -21,7 +21,7 @@
 tokenizer = AutoTokenizer.from_pretrained(model_path)
 
 # Load gsm8k using SparseML dataset tools
-data_args = DataTrainingArguments(
+data_args = DatasetArguments(
     dataset="gsm8k", dataset_config_name="main", max_seq_length=512
 )
 dataset_manager = TextGenerationDataset.load_from_registry(
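Note: a minimal before/after sketch of the import migration this file illustrates, assuming the new arg_parser module exports the names exactly as shown in the diff; the argument values below are illustrative, not part of the PR:

# Before this PR:
# from llmcompressor.transformers import DataTrainingArguments
# data_args = DataTrainingArguments(dataset="gsm8k", max_seq_length=512)

# After this PR: DatasetArguments replaces DataTrainingArguments and moves
# to the utils.arg_parser package.
from llmcompressor.transformers.utils.arg_parser import DatasetArguments

data_args = DatasetArguments(
    dataset="gsm8k",             # dataset name, as in the example above
    dataset_config_name="main",  # dataset configuration
    max_seq_length=512,          # maximum tokenized sequence length
)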
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/base.py
@@ -8,12 +8,12 @@
 from datasets.formatting.formatting import LazyRow
 from loguru import logger
 
-from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments
 from llmcompressor.transformers.finetune.data.data_helpers import (
     LABELS_MASK_VALUE,
     get_custom_datasets_from_path,
     get_raw_dataset,
 )
+from llmcompressor.transformers.utils.arg_parser import DatasetArguments
 from llmcompressor.transformers.utils.preprocessing_functions import (
     PreprocessingFunctionRegistry,
 )
@@ -41,7 +41,7 @@ class TextGenerationDataset(RegistryMixin):
 
     def __init__(
         self,
-        data_args: DataTrainingArguments,
+        data_args: DatasetArguments,
         split: str,
         processor: Processor,
     ):
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/c4.py
@@ -5,7 +5,7 @@
 from llmcompressor.typing import Processor
 
 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments
 
 
 @TextGenerationDataset.register(name="c4")
@@ -18,7 +18,7 @@ class C4Dataset(TextGenerationDataset):
     :param processor: processor or tokenizer to use on dataset
     """
 
-    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "allenai/c4"
         data_args.text_column = "text"
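Note: every dataset module below repeats the same pattern: DatasetArguments is imported only under TYPE_CHECKING, so the annotation in __init__ must remain a quoted forward reference. A minimal sketch of the idiom (the load_dataset function is hypothetical, for illustration only):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Visible to type checkers only; skipped at runtime, which avoids
    # importing arg_parser (and its dependencies) at module import time.
    from llmcompressor.transformers.utils.arg_parser import DatasetArguments


def load_dataset(data_args: "DatasetArguments") -> None:
    # The quoted annotation is resolved lazily, so this module never
    # needs the real class at runtime.
    ...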
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/cnn_dailymail.py
@@ -5,7 +5,7 @@
 from llmcompressor.typing import Processor
 
 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments
 
 
 @TextGenerationDataset.register(name="cnn_dailymail")
@@ -20,7 +20,7 @@ class CNNDailyMailDataset(TextGenerationDataset):
 
     SAMPLE_TEMPLATE = "Article:\n{article}\n\n### Summarization:\n{highlights}\n"
 
-    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "cnn_dailymail"
         data_args.dataset_config_name = "3.0.0"
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/evolcodealpaca.py
@@ -5,7 +5,7 @@
 from llmcompressor.typing import Processor
 
 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments
 
 
 @TextGenerationDataset.register(name="evolcodealpaca")
@@ -25,7 +25,7 @@ class EvolCodeAlpacaDataset(TextGenerationDataset):
         "\n\n### Response:\n"
     )
 
-    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "theblackcat102/evol-codealpaca-v1"
         data_args.text_column = "text"
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/flickr_30k.py
@@ -7,7 +7,7 @@
 from llmcompressor.typing import Processor
 
 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments
 
 
 @TextGenerationDataset.register(name="flickr", alias="flickr30k")
@@ -31,7 +31,7 @@ class Flickr30K(TextGenerationDataset):
         "{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
     )
 
-    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "lmms-lab/flickr30k"
 
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/gsm8k.py
@@ -5,7 +5,7 @@
 from llmcompressor.typing import Processor
 
 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments
 
 
 @TextGenerationDataset.register(name="gsm8k")
@@ -20,7 +20,7 @@ class GSM8KDataset(TextGenerationDataset):
 
     GSM_TEMPLATE = "Question: {question}\nAnswer:"
 
-    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "gsm8k"
         data_args.text_column = "text"
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/open_platypus.py
@@ -5,7 +5,7 @@
 from llmcompressor.typing import Processor
 
 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments
 
 
 @TextGenerationDataset.register(name="open_platypus")
@@ -28,7 +28,7 @@ class OpenPlatypusDataset(TextGenerationDataset):
         "instruction}\n\n### Response:\n",
     }
 
-    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "garage-bAInd/Open-Platypus"
         data_args.text_column = "text"
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/ptb.py
@@ -5,7 +5,7 @@
 from llmcompressor.typing import Processor
 
 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments
 
 
 @TextGenerationDataset.register(name="ptb")
@@ -18,7 +18,7 @@ class PtbDataset(TextGenerationDataset):
     :param processor: processor or tokenizer to use on dataset
     """
 
-    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "ptb_text_only"
         data_args.text_column = "sentence"
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/ultrachat_200k.py
@@ -7,7 +7,7 @@
 from llmcompressor.typing import Processor
 
 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments
 
 
 @TextGenerationDataset.register(name="ultrachat_200k")
@@ -33,7 +33,7 @@ class UltraChatDataset(TextGenerationDataset):
         "{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
     )
 
-    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "HuggingFaceH4/ultrachat_200k"
         data_args.text_column = "messages"
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/finetune/data/wikitext.py
@@ -5,7 +5,7 @@
 from llmcompressor.typing import Processor
 
 if TYPE_CHECKING:
-    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+    from llmcompressor.transformers.utils.arg_parser import DatasetArguments
 
 
 @TextGenerationDataset.register(name="wikitext")
@@ -18,7 +18,7 @@ class WikiTextDataset(TextGenerationDataset):
     :param processor: processor or tokenizer to use on dataset
     """
 
-    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "Salesforce/wikitext"
         data_args.text_column = "text"
40 changes: 29 additions & 11 deletions src/llmcompressor/transformers/finetune/runner.py
@@ -16,13 +16,20 @@
 from llmcompressor.pytorch.utils import tensors_to_device
 from llmcompressor.recipe import Recipe, StageRunType
 from llmcompressor.transformers.finetune.data import TextGenerationDataset
-from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments
 from llmcompressor.transformers.finetune.data.data_helpers import (
     format_calibration_data,
     make_dataset_splits,
 )
-from llmcompressor.transformers.finetune.model_args import ModelArguments
-from llmcompressor.transformers.finetune.training_args import TrainingArguments
+from llmcompressor.transformers.utils.arg_parser import (
+    DatasetArguments,
+    ModelArguments,
+    RecipeArguments,
+    TrainingArguments,
+)
+from llmcompressor.transformers.utils.arg_parser.training_arguments import (
+    DEFAULT_OUTPUT_DIR,
+)
+from llmcompressor.transformers.utils.arg_parser.utils import get_dataclass_as_dict
 from llmcompressor.typing import Processor
 from llmcompressor.utils.fsdp.helpers import is_fsdp_model, save_model_and_recipe
 
@@ -46,13 +53,15 @@ class StageRunner:
 
     def __init__(
        self,
-        data_args: "DataTrainingArguments",
+        data_args: "DatasetArguments",
         model_args: "ModelArguments",
         training_args: "TrainingArguments",
+        recipe_args: "RecipeArguments",
     ):
         self._data_args = data_args
         self._model_args = model_args
         self._training_args = training_args
+        self._recipe_args = recipe_args
 
         self.datasets = {}
         self.trainer = None
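Note: since StageRunner.__init__ gains a required recipe_args parameter, every call site must now pass all four dataclasses. A hypothetical call-site sketch (variable names are illustrative, not from this diff):

# Hypothetical call site, updated for the new required argument:
runner = StageRunner(
    data_args=data_args,          # DatasetArguments
    model_args=model_args,        # ModelArguments
    training_args=training_args,  # TrainingArguments
    recipe_args=recipe_args,      # RecipeArguments (new in this PR)
)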
@@ -214,7 +223,7 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None):
         :param checkpoint: optional checkpoint to pick up a stage from
         """
 
-        recipe_obj = Recipe.create_instance(self._training_args.recipe)
+        recipe_obj = Recipe.create_instance(self._recipe_args.recipe)
         with self.trainer.accelerator.main_process_first():
             checkpoint_dir = self._model_args.model
             completed_stages = get_completed_stages(checkpoint_dir)
@@ -251,21 +260,30 @@
 
             # run stage
             if run_type is StageRunType.ONESHOT:
-                self.one_shot(stage=stage_name)
+                from llmcompressor.transformers.calibration import Oneshot
+
+                model = get_session_model()
+                self._model_args.model = model
+
+                oneshot = Oneshot(
+                    output_dir=self._training_args.output_dir,
+                    **get_dataclass_as_dict(self._model_args, ModelArguments),
+                    **get_dataclass_as_dict(self._data_args, DatasetArguments),
+                    **get_dataclass_as_dict(self._recipe_args, RecipeArguments),
+                )
+
+                oneshot.run(stage_name=stage_name)
             elif run_type is StageRunType.TRAIN:
                 self.train(checkpoint=checkpoint, stage=stage_name)
             checkpoint = None
 
-            if (
-                self._training_args.output_dir
-                != TrainingArguments.__dataclass_fields__["output_dir"].default
-            ):
+            if self._training_args.output_dir != DEFAULT_OUTPUT_DIR:
                 save_model_and_recipe(
                     model=self.trainer.model,
                     save_path=self._output_dir,
                     processor=self.processor,
                     save_safetensors=self._training_args.save_safetensors,
-                    save_compressed=self._training_args.save_compressed,
+                    save_compressed=self._model_args.save_compressed,
                 )
 
             # save stage to checkpoint dir
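Note: the diff calls get_dataclass_as_dict(instance, dataclass_type) but does not show its body. A plausible sketch consistent with the call sites above, assuming it collects the fields declared on the given dataclass into a dict so they can be splatted into Oneshot(**...); the real implementation lives in llmcompressor/transformers/utils/arg_parser/utils.py and may differ:

from dataclasses import fields, is_dataclass
from typing import Any, Dict


def get_dataclass_as_dict(instance: Any, dataclass_type: Any) -> Dict[str, Any]:
    # Collect only the fields declared on dataclass_type, reading their
    # values off the instance, to produce kwargs for Oneshot(**...).
    assert is_dataclass(dataclass_type)
    return {f.name: getattr(instance, f.name) for f in fields(dataclass_type)}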