Skip to content

Commit 0348243

Browse files
committed
Merge remote-tracking branch 'origin' into kylesayrs/sequential-onloading
2 parents 929f678 + 421bd61 commit 0348243

File tree

7 files changed

+55
-46
lines changed

7 files changed

+55
-46
lines changed

examples/quantization_2of4_sparse_w4a16/llama7b_sparse_w4a16.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,15 +66,13 @@
6666
model=model,
6767
**oneshot_kwargs,
6868
stage="sparsity_stage",
69-
output_dir=output_dir,
7069
)
7170

7271
# Sparse finetune
7372
finetune_applied_model = train(
7473
model=oneshot_applied_model,
7574
**oneshot_kwargs,
7675
**training_kwargs,
77-
output_dir=output_dir,
7876
stage="finetuning_stage",
7977
)
8078

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def localversion_func(version: ScmVersion) -> str:
124124
(
125125
"compressed-tensors==0.9.4"
126126
if BUILD_TYPE == "release"
127-
else "compressed-tensors>=0.9.5a2"
127+
else "compressed-tensors>=0.10.1a2"
128128
),
129129
],
130130
extras_require={

src/llmcompressor/pytorch/model_load/helpers.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ def save_checkpoint(
4141
:param save_safetensors: save model checkpoint using safetensors file type
4242
:param save_compressed: save model checkpoint using compressed-tensors format
4343
"""
44+
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
45+
get_model_compressor, # avoid circular import
46+
)
47+
4448
# saving the model also saves the recipe
4549
model.save_pretrained(
4650
save_path,
@@ -51,6 +55,16 @@ def save_checkpoint(
5155
if processor is not None:
5256
processor.save_pretrained(save_path)
5357

58+
# saving the model modifies the model structure
59+
# as this is only a checkpoint, decompress model to enable future training/oneshot
60+
compressor = get_model_compressor(
61+
model=model,
62+
save_compressed=save_compressed,
63+
skip_sparsity_compression_stats=skip_sparsity_compression_stats,
64+
)
65+
if compressor is not None:
66+
compressor.decompress_model(model)
67+
5468

5569
def fallback_to_cpu(device: str) -> str:
5670
"""

src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py

Lines changed: 15 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import re
33
import weakref
44
from functools import wraps
5-
from typing import Dict, Optional
5+
from typing import Optional
66

77
import torch
88
import transformers
@@ -91,45 +91,27 @@ def save_pretrained_wrapper(
9191
# https://github.com/huggingface/transformers/pull/30488
9292
transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size
9393

94-
# state_dict gets passed in as a kwarg for FSDP models
95-
state_dict = kwargs.pop("state_dict", None)
96-
if state_dict is None:
97-
logger.info("Fetching state_dict - this may take some time")
98-
state_dict = get_state_dict_offloaded_model(model)
99-
100-
logger.info("Fetching compressor")
94+
# compress model using compressor
10195
compressor = get_model_compressor(
10296
model=model,
10397
sparsity_config=sparsity_config,
10498
quantization_format=quantization_format,
10599
save_compressed=save_compressed,
106100
skip_sparsity_compression_stats=skip_sparsity_compression_stats,
107-
state_dict=state_dict,
108101
disable_sparse_compression=disable_sparse_compression,
109102
)
103+
if compressor is not None:
104+
compressor.compress_model(model)
105+
106+
# save (compressed) model structure
107+
original_save_pretrained.__get__(model, model_class)(
108+
save_directory,
109+
safe_serialization=safe_serialization,
110+
**kwargs,
111+
)
110112

111-
if compressor is None:
112-
# model is not compressed or quantized, save as normal
113-
original_save_pretrained_func = original_save_pretrained.__get__(
114-
model, model_class
115-
)
116-
original_save_pretrained_func(
117-
save_directory, state_dict=state_dict, **kwargs
118-
)
119-
return
120-
121-
# make sure we're on the main process when saving
122-
if state_dict is not None and len(state_dict) > 0:
123-
compressed_state_dict = compressor.compress(
124-
model, state_dict, show_progress=True
125-
)
126-
logger.info("Saving compressed model to disk")
127-
original_save_pretrained.__get__(model, model_class)(
128-
save_directory,
129-
state_dict=compressed_state_dict,
130-
safe_serialization=safe_serialization,
131-
**kwargs,
132-
)
113+
# update config to reflect compression
114+
if compressor is not None:
133115
compressor.update_config(save_directory)
134116

135117
# update existing recipe
@@ -197,7 +179,6 @@ def get_model_compressor(
197179
quantization_format: Optional[str] = None,
198180
save_compressed: bool = True,
199181
skip_sparsity_compression_stats: bool = True,
200-
state_dict: Optional[Dict] = None,
201182
disable_sparse_compression: bool = False,
202183
):
203184
"""
@@ -211,12 +192,8 @@ def get_model_compressor(
211192
:param save_compressed: boolean representing to save in a compressed
212193
format
213194
:param skip_sparsity_compression_stats: bool allowing compression stats on std out
214-
:param state_dict: state_dict of the model
215195
:param disable_sparse_compression: bool to skip sparse compression
216196
"""
217-
# find offloaded state dict if none is provided
218-
if state_dict is None:
219-
state_dict = get_state_dict_offloaded_model(model)
220197

221198
if sparsity_config is None:
222199
"""
@@ -244,6 +221,8 @@ def get_model_compressor(
244221
)
245222
sparsity_config = None
246223
else:
224+
state_dict = get_state_dict_offloaded_model(model)
225+
247226
sparsity_config = SparsityConfigMetadata.from_pretrained(
248227
model,
249228
state_dict=state_dict,

tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
from parameterized import parameterized_class
88
from transformers import AutoConfig
99

10+
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
11+
get_model_compressor,
12+
)
1013
from tests.testing_utils import parse_params, requires_gpu
1114

1215
CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/finetune/finetune_oneshot_configs"
@@ -34,17 +37,21 @@ def _test_oneshot_and_finetune(self):
3437
output_dir=self.output,
3538
)
3639

37-
train_args = dict(
38-
num_train_epochs=self.num_train_epochs,
39-
precision="bfloat16",
40-
bf16=True,
41-
)
4240
oneshot_model = oneshot(
4341
model=self.model,
4442
**oneshot_args,
4543
stage="test_oneshot_stage",
4644
)
4745

46+
compressor = get_model_compressor(model=oneshot_model, save_compressed=True)
47+
if compressor is not None:
48+
compressor.decompress_model(oneshot_model)
49+
50+
train_args = dict(
51+
num_train_epochs=self.num_train_epochs,
52+
precision="bfloat16",
53+
bf16=True,
54+
)
4855
train(
4956
model=oneshot_model,
5057
**oneshot_args,

tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune_with_tokenizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ def test_oneshot_and_finetune_with_tokenizer(self):
5555
concatenate_data=concatenate_data,
5656
splits=splits,
5757
tokenizer=tokenizer,
58-
output_dir=self.output,
5958
)
6059

6160
oneshot_model = oneshot(
@@ -70,6 +69,7 @@ def test_oneshot_and_finetune_with_tokenizer(self):
7069
max_steps=max_steps,
7170
stage="test_train_stage",
7271
**model_and_data_kwargs,
72+
output_dir=self.output,
7373
)
7474

7575
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(

tests/llmcompressor/transformers/obcq/test_obcq_completion.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def labeled_dataloader(self, dataset_name, model_name):
3535
dataset_manager = TextGenerationDataset.load_from_registry(
3636
dataset_args.dataset,
3737
dataset_args=dataset_args,
38-
split="train",
38+
split=f"train[:{self.num_samples}]",
3939
processor=tokenizer,
4040
)
4141
calib_dataset = dataset_manager()
@@ -51,10 +51,14 @@ def _test_oneshot_completion(self, model_name: str = None):
5151
from llmcompressor import oneshot
5252
from llmcompressor.pytorch.model_load.helpers import get_session_model
5353
from llmcompressor.pytorch.utils import tensors_to_device
54+
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
55+
get_model_compressor, # avoid circular import
56+
)
5457

5558
oneshot(
5659
model=self.model,
5760
dataset=self.dataset,
61+
splits={"calibration": f"train[:{self.num_samples}]"},
5862
oneshot_device=self.device,
5963
recipe=self.recipe,
6064
max_seq_length=512,
@@ -65,6 +69,13 @@ def _test_oneshot_completion(self, model_name: str = None):
6569
)
6670

6771
first_tiny_model = get_session_model()
72+
compressor = get_model_compressor(
73+
model=first_tiny_model,
74+
save_compressed=True,
75+
skip_sparsity_compression_stats=False,
76+
)
77+
if compressor is not None:
78+
compressor.decompress_model(first_tiny_model)
6879

6980
dataset = "open_platypus"
7081

0 commit comments

Comments
 (0)