Commit 8ba0f2c

apply style

Signed-off-by: Kyle Sayers <[email protected]>

1 parent 7dd71b9

2 files changed, +11 -7 lines

examples/quantization_w4a16/llama3_example.py (1 addition, 1 deletion)

@@ -75,4 +75,4 @@ def tokenize(sample):
 # Save to disk compressed.
 SAVE_DIR = model_id.split("/")[-1] + "-W4A16-G128"
 model.save_pretrained(SAVE_DIR, save_compressed=True)
-tokenizer.save_pretrained(SAVE_DIR)
+tokenizer.save_pretrained(SAVE_DIR)
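
For context, the hunk above is the tail of the example script: it saves a compressed W4A16 checkpoint alongside its tokenizer. A minimal sketch of reloading that artifact, assuming a transformers install with compressed-tensors support; the directory name and generation settings are illustrative, not part of the commit:

# Sketch: reload the compressed checkpoint written by the example script.
from transformers import AutoModelForCausalLM, AutoTokenizer

SAVE_DIR = "Meta-Llama-3-8B-Instruct-W4A16-G128"  # hypothetical name, per the scheme above

model = AutoModelForCausalLM.from_pretrained(SAVE_DIR, device_map="auto", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(SAVE_DIR)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0]))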

src/llmcompressor/utils/dev.py (10 additions, 6 deletions)

@@ -2,20 +2,24 @@
 import logging
 import os
 import tempfile
-from typing import Type, Dict, Any, Union
+from typing import Type

 import torch
+from accelerate import dispatch_model, infer_auto_device_map
+from accelerate.utils import get_balanced_memory
 from huggingface_hub import snapshot_download
 from safetensors.torch import save_file
 from transformers import AutoModelForCausalLM, PreTrainedModel
 from transformers.modeling_utils import TORCH_INIT_FUNCTIONS
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME
-from accelerate import dispatch_model, infer_auto_device_map
-from accelerate.utils import get_balanced_memory

 from llmcompressor.utils.helpers import patch_attr

-__all__ = ["skip_weights_download", "patch_transformers_logger_level", "dispatch_for_generation"]
+__all__ = [
+    "skip_weights_download",
+    "patch_transformers_logger_level",
+    "dispatch_for_generation",
+]


 @contextlib.contextmanager
@@ -114,8 +118,8 @@ def dispatch_for_generation(model: PreTrainedModel) -> PreTrainedModel:
     max_memory = get_balanced_memory(
         model,
         dtype=model.dtype,
-        no_split_module_classes=model._get_no_split_modules("auto")
+        no_split_module_classes=model._get_no_split_modules("auto"),
     )
     device_map = infer_auto_device_map(model, dtype=model.dtype, max_memory=max_memory)

-    return dispatch_model(model, device_map=device_map)
+    return dispatch_model(model, device_map=device_map)
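
The second hunk touches dispatch_for_generation, which balances a model across available devices (get_balanced_memory, then infer_auto_device_map, then dispatch_model, all visible in the imports above). A minimal usage sketch, assuming the function is imported from llmcompressor.utils.dev; the model id is illustrative:

# Sketch: spread a model over available GPUs before calling generate().
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.utils.dev import dispatch_for_generation

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"  # illustrative
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = dispatch_for_generation(model)  # returns the dispatched model

inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=10)[0]))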
