Skip to content

Commit 907c413

Browse files
author
George
authored
Merge branch 'main' into datasetargs
2 parents 1a3dafe + 7bb517f commit 907c413

File tree

4 files changed

+82
-71
lines changed

4 files changed

+82
-71
lines changed

src/llmcompressor/args/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
from .model_arguments import ModelArguments
55
from .recipe_arguments import RecipeArguments
66
from .training_arguments import TrainingArguments
7+
from .utils import parse_args

src/llmcompressor/args/utils.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
from typing import Tuple
2+
3+
from loguru import logger
4+
from transformers import HfArgumentParser
5+
6+
from llmcompressor.args import (
7+
DatasetArguments,
8+
ModelArguments,
9+
RecipeArguments,
10+
TrainingArguments,
11+
)
12+
from llmcompressor.transformers.utils.helpers import resolve_processor_from_model_args
13+
14+
15+
def parse_args(
    include_training_args: bool = False, **kwargs
) -> Tuple[ModelArguments, DatasetArguments, RecipeArguments, TrainingArguments, str]:
    """
    Separate keyword arguments passed in from `oneshot` or `train` into
    the following dataclass argument groups:

    * ModelArguments in
        src/llmcompressor/args/model_arguments.py
    * DatasetArguments in
        src/llmcompressor/args/dataset_arguments.py
    * RecipeArguments in
        src/llmcompressor/args/recipe_arguments.py
    * TrainingArguments in
        src/llmcompressor/args/training_arguments.py

    ModelArguments, DatasetArguments, and RecipeArguments are used for both
    `oneshot` and `train`. TrainingArguments is only used for `train`.

    :param include_training_args: when True, also parse TrainingArguments
        (the `train` pathway); when False, the training_args slot of the
        returned tuple is None
    :param kwargs: flat keyword arguments distributed across the groups by
        HfArgumentParser
    :return: (model_args, dataset_args, recipe_args, training_args,
        output_dir) — training_args is None unless include_training_args=True
    """

    # pop output_dir, used as an attr in TrainingArguments, where oneshot is not used
    output_dir = kwargs.pop("output_dir", None)

    parser_args = (ModelArguments, DatasetArguments, RecipeArguments)
    if include_training_args:
        parser_args += (TrainingArguments,)

    parser = HfArgumentParser(parser_args)
    parsed_args = parser.parse_dict(kwargs)

    training_args = None
    if include_training_args:
        model_args, dataset_args, recipe_args, training_args = parsed_args
        if output_dir is not None:
            training_args.output_dir = output_dir
    else:
        model_args, dataset_args, recipe_args = parsed_args

    # normalize recipe overrides passed as "key=value" strings into a dict;
    # split on the first "=" only so values may themselves contain "="
    if recipe_args.recipe_args is not None:
        if not isinstance(recipe_args.recipe_args, dict):
            arg_dict = {}
            for recipe_arg in recipe_args.recipe_args:
                key, value = recipe_arg.split("=", 1)
                arg_dict[key] = value
            recipe_args.recipe_args = arg_dict

    # raise deprecation warnings
    if dataset_args.remove_columns is not None:
        # fix: loguru's logger has no `warn` method (that call raised
        # AttributeError), and it takes no warning-category argument --
        # log the message via `warning` only
        logger.warning(
            "`remove_columns` argument is deprecated. When tokenizing datasets, all "
            "columns which are invalid inputs to the tokenizer will be removed"
        )

    # silently assign tokenizer to processor
    resolve_processor_from_model_args(model_args)

    return model_args, dataset_args, recipe_args, training_args, output_dir

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 6 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from pathlib import PosixPath
2-
from typing import Optional, Tuple
2+
from typing import Optional
33

44
from loguru import logger
55
from torch.utils.data import DataLoader
6-
from transformers import HfArgumentParser, PreTrainedModel
6+
from transformers import PreTrainedModel
77

8-
from llmcompressor.args import DatasetArguments, ModelArguments, RecipeArguments
8+
from llmcompressor.args import parse_args
99
from llmcompressor.core.session_functions import active_session
1010
from llmcompressor.transformers.finetune.data.data_helpers import (
1111
get_calibration_dataloader,
@@ -18,9 +18,8 @@
1818
modify_save_pretrained,
1919
patch_tied_tensors_bug,
2020
)
21-
from llmcompressor.transformers.utils.helpers import resolve_processor_from_model_args
2221

23-
__all__ = ["Oneshot", "oneshot", "parse_oneshot_args"]
22+
__all__ = ["Oneshot", "oneshot"]
2423

2524

2625
class Oneshot:
@@ -122,11 +121,10 @@ def __init__(
122121
:param output_dir: Path to save the output model after carrying out oneshot
123122
124123
"""
125-
126-
model_args, dataset_args, recipe_args, output_dir = parse_oneshot_args(**kwargs)
124+
model_args, dataset_args, recipe_args, _, output_dir = parse_args(**kwargs)
127125

128126
self.model_args = model_args
129-
self.dataset_args = dataset_args
127+
self.data_args = dataset_args
130128
self.recipe_args = recipe_args
131129
self.output_dir = output_dir
132130

@@ -315,64 +313,3 @@ def oneshot(**kwargs) -> PreTrainedModel:
315313
one_shot()
316314

317315
return one_shot.model
318-
319-
320-
def parse_oneshot_args(
321-
**kwargs,
322-
) -> Tuple[ModelArguments, DatasetArguments, RecipeArguments, str]:
323-
"""
324-
Parses kwargs by grouping into model, data or training arg groups:
325-
* model_args in
326-
src/llmcompressor/transformers/utils/arg_parser/model_args.py
327-
* dataset_args in
328-
src/llmcompressor/transformers/utils/arg_parser/dataset_args.py
329-
* recipe_args in
330-
src/llmcompressor/transformers/utils/arg_parser/recipe_args.py
331-
* training_args in
332-
src/llmcompressor/transformers/utils/arg_parser/training_args.py
333-
"""
334-
output_dir = kwargs.pop("output_dir", None)
335-
336-
parser = HfArgumentParser((ModelArguments, DatasetArguments, RecipeArguments))
337-
338-
if not kwargs:
339-
340-
def _get_output_dir_from_argv() -> Optional[str]:
341-
import sys
342-
343-
output_dir = None
344-
if "--output_dir" in sys.argv:
345-
index = sys.argv.index("--output_dir")
346-
sys.argv.pop(index)
347-
if index < len(sys.argv): # Check if value exists afer the flag
348-
output_dir = sys.argv.pop(index)
349-
350-
return output_dir
351-
352-
output_dir = _get_output_dir_from_argv() or output_dir
353-
parsed_args = parser.parse_args_into_dataclasses()
354-
else:
355-
parsed_args = parser.parse_dict(kwargs)
356-
357-
model_args, dataset_args, recipe_args = parsed_args
358-
359-
if recipe_args.recipe_args is not None:
360-
if not isinstance(recipe_args.recipe_args, dict):
361-
arg_dict = {}
362-
for recipe_arg in recipe_args.recipe_args:
363-
key, value = recipe_arg.split("=")
364-
arg_dict[key] = value
365-
recipe_args.recipe_args = arg_dict
366-
367-
# raise depreciation warnings
368-
if dataset_args.remove_columns is not None:
369-
logger.warning(
370-
"`remove_columns` argument is depreciated. When tokenizing datasets, all "
371-
"columns which are invalid inputs the tokenizer will be removed",
372-
DeprecationWarning,
373-
)
374-
375-
# silently assign tokenizer to processor
376-
resolve_processor_from_model_args(model_args)
377-
378-
return model_args, dataset_args, recipe_args, output_dir

tests/llmcompressor/entrypoints/test_oneshot.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from transformers import AutoModelForCausalLM
22

33
from llmcompressor import Oneshot
4-
from llmcompressor.entrypoints.oneshot import parse_oneshot_args
4+
from llmcompressor.args import parse_args
55

66

77
def test_oneshot_from_args():
@@ -17,7 +17,7 @@ def test_oneshot_from_args():
1717

1818
output_dir = "bar_output_dir"
1919

20-
model_args, dataset_args, recipe_args, output_dir = parse_oneshot_args(
20+
model_args, data_args, recipe_args, _, output_dir = parse_args(
2121
model=model,
2222
dataset=dataset,
2323
recipe=recipe,

0 commit comments

Comments
 (0)