
Commit dc7ac1a

use untie_word_embeddings

Signed-off-by: Kyle Sayers <[email protected]>

1 parent 4b4257f

File tree

3 files changed: 43 additions & 56 deletions

src/llmcompressor/entrypoints/utils.py

Lines changed: 3 additions & 4 deletions
@@ -20,7 +20,7 @@
 from llmcompressor.pytorch.model_load.helpers import parse_dtype
 from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
     modify_save_pretrained,
-    patch_tied_tensors_bug,
+    untie_word_embeddings,
 )
 from llmcompressor.transformers.utils.helpers import (
     detect_last_checkpoint,
@@ -61,7 +61,8 @@ def pre_process(model_args: "ModelArguments"):
     )

     # untie tie_word_embeddings weights
-    patch_tied_tensors_bug(model_args.model)
+    if not model_args.tie_word_embeddings:
+        untie_word_embeddings(model_args.model)

     # wrap model.save_pretrained
     modify_save_pretrained(model_args.model)
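Note: with this guard, pre_process only unties weights when the caller explicitly requests untied embeddings, rather than patching every model unconditionally. A minimal sketch of the resulting flow using the helper directly (model name borrowed from the tests below; the local tie_word_embeddings variable stands in for model_args.tie_word_embeddings):

from transformers import AutoModelForCausalLM

from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
    modify_save_pretrained,
    untie_word_embeddings,
)

model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M")

tie_word_embeddings = False  # stand-in for model_args.tie_word_embeddings
if not tie_word_embeddings:
    untie_word_embeddings(model)  # give each embedding its own storage

modify_save_pretrained(model)  # wrap save_pretrained for compressed saving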
@@ -143,7 +144,6 @@ def initialize_model_from_path(
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
         use_auth_token=True if model_args.use_auth_token else None,
-        tie_word_embeddings=model_args.tie_word_embeddings,
         trust_remote_code=model_args.trust_remote_code_model,
     )

@@ -156,7 +156,6 @@
         AutoConfig.from_pretrained(
             model_args.distill_teacher,
             use_auth_token=True if model_args.use_auth_token else None,
-            tie_word_embeddings=model_args.tie_word_embeddings,
             trust_remote_code=model_args.trust_remote_code_model,
         )
         if model_args.distill_teacher

src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py

Lines changed: 25 additions & 25 deletions
@@ -9,8 +9,9 @@
     CompressionFormat,
     ModelCompressor,
     SparsityCompressionConfig,
+    delete_offload_parameter,
     is_module_offloaded,
-    update_offload_parameter,
+    register_offload_parameter,
 )
 from loguru import logger
 from safetensors.torch import storage_ptr
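Note: the import swap above reflects the new untying strategy: update_offload_parameter writes into a parameter's existing storage, while deleting and re-registering the parameter replaces the storage outright. A torch-only sketch of why that distinction matters for tied weights (no llmcompressor APIs involved):

import torch
from safetensors.torch import storage_ptr

# two modules whose weights share one storage, as tied embeddings do
embed = torch.nn.Linear(4, 8, bias=False)
head = torch.nn.Linear(4, 8, bias=False)
head.weight = embed.weight
assert storage_ptr(embed.weight) == storage_ptr(head.weight)

# replacing the Parameter with a clone is what actually breaks the tie;
# in-place updates would keep writing into the shared storage
head.weight = torch.nn.Parameter(
    embed.weight.data.clone(), requires_grad=embed.weight.requires_grad
)
assert storage_ptr(embed.weight) != storage_ptr(head.weight)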
@@ -27,7 +28,7 @@
 from llmcompressor.transformers.utils import RECIPE_FILE_NAME
 from llmcompressor.transformers.utils.helpers import infer_recipe_from_model_path

-__all__ = ["modify_save_pretrained"]
+__all__ = ["modify_save_pretrained", "untie_word_embeddings"]


 def modify_save_pretrained(model: PreTrainedModel):
@@ -120,7 +121,7 @@ def save_pretrained_wrapper(
     model.save_pretrained = save_pretrained_compressed(model.save_pretrained)


-def patch_tied_tensors_bug(model: torch.nn.Module):
+def untie_word_embeddings(model: PreTrainedModel):
     """
     Patches bug where HF transformers will fail to untie weights under specific
     circumstances (https://github.com/huggingface/transformers/issues/33689).
@@ -129,28 +130,27 @@ def patch_tied_tensors_bug(model: torch.nn.Module):

     :param model: model to fix
     """
-    if (
-        hasattr(model.config, "tie_word_embeddings")
-        and not model.config.tie_word_embeddings
-    ):
-        input_embed = model.get_input_embeddings()
-        output_embed = model.get_output_embeddings()
-
-        if input_embed is None or output_embed is None:
-            # some models fail to properly override the abstract methods
-            return
-
-        if storage_ptr(input_embed.weight) == storage_ptr(output_embed.weight):
-            for module in (input_embed, output_embed):
-                if not is_module_offloaded(module):
-                    # create new storage ptr for onloaded weight
-                    untied_data = module.weight.data.clone()
-                    module.weight.data = untied_data
-                else:
-                    # create new storage ptr for offloaded weight
-                    # note `update_offload_parameter` does not create a new storage ptr
-                    untied_data = module._hf_hook.weights_map["weight"].clone()
-                    update_offload_parameter(module, "weight", untied_data)
+    input_embed = model.get_input_embeddings()
+    output_embed = model.get_output_embeddings()
+
+    for module in (input_embed, output_embed):
+        if module is None or not hasattr(module, "weight"):
+            logger.warning(f"Cannot untie {module} which does not have weight param")
+            continue
+
+        # this could be replaced by a `get_offloaded_parameter` util
+        if not is_module_offloaded(module):
+            untied_data = module.weight.data.clone()
+        else:
+            untied_data = module._hf_hook.weights_map["weight"].clone()
+
+        requires_grad = module.weight.requires_grad
+        new_parameter = torch.nn.Parameter(untied_data, requires_grad=requires_grad)
+        delete_offload_parameter(module, "weight")
+        register_offload_parameter(module, "weight", new_parameter)
+
+    if hasattr(model.config, "tie_word_embeddings"):
+        model.config.tie_word_embeddings = False


 def get_model_compressor(
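Note: a short usage sketch of the newly exported helper (model name borrowed from the tests below). After the call, input and output embeddings own separate storage and the config records the untie:

from safetensors.torch import storage_ptr
from transformers import AutoModelForCausalLM

from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
    untie_word_embeddings,
)

model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M")
untie_word_embeddings(model)

in_w = model.get_input_embeddings().weight
out_w = model.get_output_embeddings().weight
assert storage_ptr(in_w) != storage_ptr(out_w)  # fresh storage per module
assert model.config.tie_word_embeddings is False  # config updated too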

tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py

Lines changed: 15 additions & 27 deletions
@@ -28,7 +28,7 @@
 from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
     get_model_compressor,
     modify_save_pretrained,
-    patch_tied_tensors_bug,
+    untie_word_embeddings,
 )
 from tests.testing_utils import requires_gpu

@@ -224,8 +224,6 @@ def test_quant_model_reload(format, dtype, tmp_path):
     shutil.rmtree(tmp_path)


-# technically only tie_word_embeddings=False is supported right now
-# setting to True is discouraged
 @pytest.mark.parametrize(
     "offload,torch_dtype,tie_word_embeddings,device",
     [
@@ -237,25 +235,23 @@ def test_quant_model_reload(format, dtype, tmp_path):
         # offloading
         (True, torch.float16, False, "cpu"),
         (True, torch.float32, False, "cpu"),
-        # (True, torch.float16, True, "cpu"),  # TODO: fails
-        # (True, torch.float32, True, "cpu"),  # TODO: fails
+        (True, torch.float16, True, "cpu"),
+        (True, torch.float32, True, "cpu"),
     ],
 )
 def test_model_reload(offload, torch_dtype, tie_word_embeddings, device, tmp_path):
     model_path = "nm-testing/llama2.c-stories15M"
     save_path = tmp_path / "save_path"

-    model = AutoModelForCausalLM.from_pretrained(
-        model_path,
-        tie_word_embeddings=tie_word_embeddings,
-        torch_dtype=torch_dtype,
-    )
+    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch_dtype)
     if offload:
         model = dispatch_model(model, {"": device}, force_hooks=True)
     else:
         model = model.to(device)

-    patch_tied_tensors_bug(model)
+    if not tie_word_embeddings:
+        untie_word_embeddings(model)
+
     modify_save_pretrained(model)
     model.save_pretrained(save_path, safe_serialization=True)

@@ -294,22 +290,18 @@ def test_model_reload_gpu(offload, torch_dtype, tie_word_embeddings, device, tmp
         (True, torch.float32, True, "cpu"),
     ],
 )
-def test_model_shared_tensors(
-    offload, torch_dtype, tie_word_embeddings, device, tmp_path
-):
+def test_model_shared_tensors(offload, torch_dtype, tie_word_embeddings, device):
     # load model
-    model = AutoModelForCausalLM.from_pretrained(
-        "nm-testing/llama2.c-stories15M",
-        torch_dtype=torch_dtype,
-        tie_word_embeddings=tie_word_embeddings,
-    )
-    patch_tied_tensors_bug(model)
-
+    model_path = "nm-testing/llama2.c-stories15M"
+    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch_dtype)
     if offload:
         model = dispatch_model(model, {"": device}, force_hooks=True)
     else:
         model = model.to(device)

+    if not tie_word_embeddings:
+        untie_word_embeddings(model)
+
     # modify lm head
     with torch.no_grad(), align_module_device(model.lm_head):
         update_offload_parameter(model.lm_head, "weight", model.lm_head.weight + 1)
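Note: the remainder of this test (outside the hunk) checks whether the lm_head edit leaks into the input embeddings. A torch-only sketch of the tied-versus-untied behavior being exercised, with assertions assumed for illustration:

import torch

# tied: one storage behind both modules, so the edit reaches both
embed = torch.nn.Embedding(32, 8)
head = torch.nn.Linear(8, 32, bias=False)
head.weight = embed.weight
with torch.no_grad():
    head.weight += 1
assert torch.equal(embed.weight, head.weight)

# untied: a cloned Parameter isolates subsequent edits
head.weight = torch.nn.Parameter(embed.weight.data.clone())
with torch.no_grad():
    head.weight += 1
assert not torch.equal(embed.weight, head.weight)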
@@ -332,12 +324,8 @@ def test_model_shared_tensors(
         (False, torch.float32, True, "cuda:0"),
     ],
 )
-def test_model_shared_tensors_gpu(
-    offload, torch_dtype, tie_word_embeddings, device, tmp_path
-):
-    test_model_shared_tensors(
-        offload, torch_dtype, tie_word_embeddings, device, tmp_path
-    )
+def test_model_shared_tensors_gpu(offload, torch_dtype, tie_word_embeddings, device):
+    test_model_shared_tensors(offload, torch_dtype, tie_word_embeddings, device)


 @requires_gpu
