
Commit 7dd71b9

add dispatch utility
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 6d942cc commit 7dd71b9

2 files changed, 25 insertions(+), 12 deletions(-)

examples/quantization_w4a16/llama3_example.py

Lines changed: 10 additions & 10 deletions
@@ -3,6 +3,7 @@
 
 from llmcompressor.modifiers.quantization import GPTQModifier
 from llmcompressor.transformers import oneshot
+from llmcompressor.utils.dev import dispatch_for_generation
 
 # Select model and load it.
 model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
@@ -61,18 +62,17 @@ def tokenize(sample):
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
 )
 
-# Save to disk compressed.
-SAVE_DIR = model_id.split("/")[1] + "-W4A16-G128"
-model.save_pretrained(SAVE_DIR, save_compressed=True)
-tokenizer.save_pretrained(SAVE_DIR)
-
-# Load model after saving
-model = AutoModelForCausalLM.from_pretrained(SAVE_DIR, device_map="auto")
-
 # Confirm generations of the quantized model look sane.
 print("\n\n")
 print("========== SAMPLE GENERATION ==============")
-input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
-output = model.generate(input_ids, max_new_tokens=100)
+dispatch_for_generation(model)
+sample = tokenizer("Hello my name is", return_tensors="pt")
+sample = {key: value.to("cuda") for key, value in sample.items()}
+output = model.generate(**sample, max_new_tokens=100)
 print(tokenizer.decode(output[0]))
 print("==========================================\n\n")
+
+# Save to disk compressed.
+SAVE_DIR = model_id.split("/")[-1] + "-W4A16-G128"
+model.save_pretrained(SAVE_DIR, save_compressed=True)
+tokenizer.save_pretrained(SAVE_DIR)
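
The example now saves the compressed model after the sanity-check generation, so the separate "Load model after saving" step is gone, and the generate call now passes the full tokenizer output (input_ids and attention_mask) via **sample instead of input_ids alone. If you still want to verify the round trip, a minimal sketch under the example's naming (SAVE_DIR resolves to "Meta-Llama-3-8B-Instruct-W4A16-G128"; this reload is not part of the commit):

    from transformers import AutoModelForCausalLM

    # Reload the compressed checkpoint written by the example above.
    model = AutoModelForCausalLM.from_pretrained(
        "Meta-Llama-3-8B-Instruct-W4A16-G128", device_map="auto"
    )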

src/llmcompressor/utils/dev.py

Lines changed: 15 additions & 2 deletions
@@ -2,18 +2,20 @@
 import logging
 import os
 import tempfile
-from typing import Type
+from typing import Type, Dict, Any, Union
 
 import torch
 from huggingface_hub import snapshot_download
 from safetensors.torch import save_file
 from transformers import AutoModelForCausalLM, PreTrainedModel
 from transformers.modeling_utils import TORCH_INIT_FUNCTIONS
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME
+from accelerate import dispatch_model, infer_auto_device_map
+from accelerate.utils import get_balanced_memory
 
 from llmcompressor.utils.helpers import patch_attr
 
-__all__ = ["skip_weights_download", "patch_transformers_logger_level"]
+__all__ = ["skip_weights_download", "patch_transformers_logger_level", "dispatch_for_generation"]
 
 
 @contextlib.contextmanager
@@ -106,3 +108,14 @@ def patch_transformers_logger_level(level: int = logging.ERROR):
     transformers_logger.setLevel(level=level)
     yield
     transformers_logger.setLevel(level=restore_log_level)
+
+
+def dispatch_for_generation(model: PreTrainedModel) -> PreTrainedModel:
+    max_memory = get_balanced_memory(
+        model,
+        dtype=model.dtype,
+        no_split_module_classes=model._get_no_split_modules("auto")
+    )
+    device_map = infer_auto_device_map(model, dtype=model.dtype, max_memory=max_memory)
+
+    return dispatch_model(model, device_map=device_map)
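
The new utility chains three accelerate calls: get_balanced_memory computes a per-device memory budget that spreads the model evenly over the visible devices, infer_auto_device_map assigns the model's modules to devices under that budget, and dispatch_model attaches the hooks that move activations between devices at runtime. A minimal usage sketch mirroring the example script (model id taken from the example; dtype and token count are illustrative):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from llmcompressor.utils.dev import dispatch_for_generation

    model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Spread the model across the available GPUs before generating.
    model = dispatch_for_generation(model)

    sample = tokenizer("Hello my name is", return_tensors="pt")
    sample = {key: value.to("cuda") for key, value in sample.items()}
    print(tokenizer.decode(model.generate(**sample, max_new_tokens=20)[0]))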
