
Commit 8a4b195

feat: have hf_accelerate call to hf_transformers

tjohnson31415 authored and njhill committed
Signed-off-by: Travis Johnson <[email protected]>

1 parent 164f565 · commit 8a4b195
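
This commit collapses the hf_accelerate engine into a thin subclass of the shared hf_transformers engine: the duplicated from_pretrained setup is deleted, and accelerate-specific device placement is toggled through a new internal _use_accelerate flag. A minimal sketch of the pattern (class and attribute names here are illustrative, not the repo's):

from typing import Optional


class SharedEngine:
    # Stands in for the hf_transformers InferenceEngine, which owns all
    # model-construction logic after this commit.
    def __init__(self, model_path: str, _use_accelerate: bool = False) -> None:
        # The internal flag only toggles accelerate-style device placement.
        self.device_map: Optional[str] = "auto" if _use_accelerate else None
        self.model_path = model_path


class AccelerateEngine(SharedEngine):
    # Stands in for the hf_accelerate InferenceEngine: it forwards
    # everything and opts in to the device map.
    def __init__(self, model_path: str) -> None:
        super().__init__(model_path, _use_accelerate=True)


engine = AccelerateEngine("/models/example")  # illustrative path
assert engine.device_map == "auto"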

File tree

2 files changed (+18, -30 lines)
server/text_generation_server/inference_engine/hf_accelerate.py

Lines changed: 13 additions & 30 deletions
@@ -1,13 +1,12 @@
-import os
+from typing import Any, Optional
+
 import torch
 from transformers.models.auto.auto_factory import _BaseAutoModelClass
 
-from text_generation_server.inference_engine.engine import BaseInferenceEngine
-from text_generation_server.utils.hub import TRUST_REMOTE_CODE
-from typing import Any, Optional
+from text_generation_server.inference_engine.hf_transformers import InferenceEngine as HFTransformersInferenceEngine
 
 
-class InferenceEngine(BaseInferenceEngine):
+class InferenceEngine(HFTransformersInferenceEngine):
     def __init__(
         self,
         model_path: str,
@@ -17,28 +16,12 @@ def __init__(
         model_config: Optional[Any],
         max_sequence_length: Optional[int],
     ) -> None:
-        super().__init__(model_path, model_config)
-
-        kwargs = {
-            "pretrained_model_name_or_path": model_path,
-            "device_map": None,
-            "local_files_only": True,
-            "trust_remote_code": TRUST_REMOTE_CODE,
-        }
-
-        if self.device.type == "cuda":
-            kwargs["device_map"] = "balanced_low_0" if self.world_size > 1 else "auto"
-
-        if quantize == "bitsandbytes":
-            # using LLM.int8()
-            kwargs["load_in_8bit"] = True
-        elif quantize is not None:
-            raise ValueError(f"{quantize} quantization not supported by hf_accelerate engine")
-        else:
-            kwargs["torch_dtype"] = dtype
-
-        slow_but_exact = os.getenv('BLOOM_SLOW_BUT_EXACT', 'false').lower() == 'true'
-        if slow_but_exact:
-            kwargs["slow_but_exact"] = True
-
-        self.model = model_class.from_pretrained(**kwargs).requires_grad_(False).eval()
+        super().__init__(
+            model_path,
+            model_class,
+            dtype,
+            quantize,
+            model_config,
+            max_sequence_length,
+            _use_accelerate=True
+        )
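
Read end to end, the hf_accelerate engine after this commit is just the subclass and the forwarding call. The assembly below is reconstructed from the two hunks above; the three parameter lines between model_path and model_config are not shown in the diff and are inferred from the arguments forwarded to super().__init__(), so treat them as assumptions:

from typing import Any, Optional

import torch
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from text_generation_server.inference_engine.hf_transformers import InferenceEngine as HFTransformersInferenceEngine


class InferenceEngine(HFTransformersInferenceEngine):
    def __init__(
        self,
        model_path: str,
        model_class: type[_BaseAutoModelClass],  # inferred: not shown in the hunks
        dtype: torch.dtype,                      # inferred: not shown in the hunks
        quantize: Optional[str],                 # inferred: not shown in the hunks
        model_config: Optional[Any],
        max_sequence_length: Optional[int],
    ) -> None:
        # Everything is delegated to the shared engine; only the internal
        # flag distinguishes this engine from plain hf_transformers.
        super().__init__(
            model_path,
            model_class,
            dtype,
            quantize,
            model_config,
            max_sequence_length,
            _use_accelerate=True,
        )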

server/text_generation_server/inference_engine/hf_transformers.py

Lines changed: 5 additions & 0 deletions
@@ -17,6 +17,8 @@ def __init__(
         quantize: Optional[str],
         model_config: Optional[Any],
         max_sequence_length: Optional[int] = None,
+        # internal arg only for this engine
+        _use_accelerate: bool = False,
     ) -> None:
         super().__init__(model_path, model_config)
 
@@ -26,6 +28,9 @@ def __init__(
             "trust_remote_code": TRUST_REMOTE_CODE,
         }
 
+        if _use_accelerate and self.device.type == "cuda":
+            kwargs["device_map"] = "balanced_low_0" if self.world_size > 1 else "auto"
+
         # TODO: consider if Flash Attention should be enabled based on FLASH_ATTENTION=True
         if attn_impl := os.getenv("TRANSFORMERS_ATTN_IMPL"):
             logger.info(f"Setting attn_implementation to {attn_impl}")
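
The behavioral addition on the hf_transformers side is the device_map choice, now gated on the new flag. For context, a hedged illustration (not code from this repo) of what the two values mean when transformers hands placement to accelerate: "auto" spreads layers across all visible devices, while "balanced_low_0" balances them but keeps GPU 0 lightly loaded, leaving headroom on the device that typically gathers generation state when world_size > 1:

from transformers import AutoModelForCausalLM

MODEL_ID = "bigscience/bloom-560m"  # illustrative model id

# Single process / single GPU: let accelerate place everything.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto")

# Several GPUs visible: balance layers but keep GPU 0 lightly loaded,
# mirroring the world_size > 1 branch in the diff above.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="balanced_low_0")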
