Skip to content

Commit 613c337

Browse files
committed
fix: parallel_download_decorator compatibility with transformers >= 5; use `dtype` instead of `torch_dtype`
1 parent 9fafeb4 commit 613c337

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

qllm/modeling/base.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -128,7 +128,12 @@ def cached_file_func_in_thread(task_func, *args, **kwargs):
128128
return executor.submit(task_func, *args, **kwargs)
129129
transformers.utils.hub.cached_file = functools.partial(cached_file_func_in_thread, transformers.utils.hub.cached_file)
130130
result = task_func_shard(*args, **kwargs)
131-
result_0 = [future.result() for future in result[0]]
131+
result_0 = []
132+
for item in result[0]:
133+
if isinstance(item, str):
134+
result_0.append(item)
135+
else:
136+
result_0.append(item.result())
132137
return result_0, result[1]
133138

134139

@@ -229,7 +234,7 @@ def from_pretrained(
229234
)
230235
llm = AutoModelForCausalLM.from_pretrained(
231236
pretrained_model_name_or_path,
232-
torch_dtype=torch_dtype,
237+
dtype=torch_dtype,
233238
trust_remote_code=trust_remote_code,
234239
attn_implementation=attn_implementation,
235240
# device_map="auto",

0 commit comments

Comments (0)