2
2
import logging
3
3
import os
4
4
import tempfile
5
- from typing import Type , Dict , Any , Union
5
+ from typing import Type
6
6
7
7
import torch
8
+ from accelerate import dispatch_model , infer_auto_device_map
9
+ from accelerate .utils import get_balanced_memory
8
10
from huggingface_hub import snapshot_download
9
11
from safetensors .torch import save_file
10
12
from transformers import AutoModelForCausalLM , PreTrainedModel
11
13
from transformers .modeling_utils import TORCH_INIT_FUNCTIONS
12
14
from transformers .utils import SAFE_WEIGHTS_INDEX_NAME , WEIGHTS_INDEX_NAME
13
- from accelerate import dispatch_model , infer_auto_device_map
14
- from accelerate .utils import get_balanced_memory
15
15
16
16
from llmcompressor .utils .helpers import patch_attr
17
17
18
- __all__ = ["skip_weights_download" , "patch_transformers_logger_level" , "dispatch_for_generation" ]
18
+ __all__ = [
19
+ "skip_weights_download" ,
20
+ "patch_transformers_logger_level" ,
21
+ "dispatch_for_generation" ,
22
+ ]
19
23
20
24
21
25
@contextlib .contextmanager
@@ -114,8 +118,8 @@ def dispatch_for_generation(model: PreTrainedModel) -> PreTrainedModel:
114
118
max_memory = get_balanced_memory (
115
119
model ,
116
120
dtype = model .dtype ,
117
- no_split_module_classes = model ._get_no_split_modules ("auto" )
121
+ no_split_module_classes = model ._get_no_split_modules ("auto" ),
118
122
)
119
123
device_map = infer_auto_device_map (model , dtype = model .dtype , max_memory = max_memory )
120
124
121
- return dispatch_model (model , device_map = device_map )
125
+ return dispatch_model (model , device_map = device_map )
0 commit comments