Commit 024cbe2

joerunde and njhill committed

Use exllamav2 by default with autogptq; set seq length for exllama v1
Co-authored-by: Nick Hill <[email protected]>

1 parent c1c58f2 · commit 024cbe2

11 files changed: +61 -17 lines changed
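For context before the per-file diffs: the change keys everything off the transformers GPTQConfig. A minimal sketch of the three configurations the new code chooses between (the bits and sequence-length values here are illustrative, not taken from this repo):

from transformers import GPTQConfig

# Default after this commit: exllamav2 kernels for 4-bit GPTQ checkpoints.
exllama_v2 = GPTQConfig(bits=4, exllama_config={"version": 2})

# EXLLAMA_VERSION=1: exllama v1 kernels need a maximum input length up front
# (transformers uses it to size temporary buffers), hence max_sequence_length.
exllama_v1 = GPTQConfig(bits=4, exllama_config={"version": 1}, max_input_length=2048)

# DISABLE_EXLLAMA=true: fall back to the plain AutoGPTQ CUDA kernels.
no_exllama = GPTQConfig(bits=4, use_exllama=False)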

server/text_generation_server/inference_engine/ds_inference.py
Lines changed: 2 additions & 1 deletion

@@ -21,7 +21,8 @@ def __init__(
         model_class: type[_BaseAutoModelClass],
         dtype: torch.dtype,
         quantize: Optional[str],
-        model_config: Optional[Any]
+        model_config: Optional[Any],
+        max_sequence_length: Optional[int],
     ) -> None:
         super().__init__(model_path, model_config)

server/text_generation_server/inference_engine/hf_accelerate.py
Lines changed: 8 additions & 7 deletions

@@ -9,12 +9,13 @@

 class InferenceEngine(BaseInferenceEngine):
     def __init__(
-        self,
-        model_path: str,
-        model_class: type[_BaseAutoModelClass],
-        dtype: torch.dtype,
-        quantize: Optional[str],
-        model_config: Optional[Any]
+        self,
+        model_path: str,
+        model_class: type[_BaseAutoModelClass],
+        dtype: torch.dtype,
+        quantize: Optional[str],
+        model_config: Optional[Any],
+        max_sequence_length: Optional[int],
     ) -> None:
         super().__init__(model_path, model_config)

@@ -32,7 +33,7 @@ def __init__(
             # using LLM.int8()
             kwargs["load_in_8bit"] = True
         elif quantize is not None:
-            raise ValueError(f"{quantize} quantization not supported by hf_transformers engine")
+            raise ValueError(f"{quantize} quantization not supported by hf_accelerate engine")
         else:
             kwargs["torch_dtype"] = dtype

server/text_generation_server/inference_engine/hf_custom_tp.py
Lines changed: 1 addition & 0 deletions

@@ -28,6 +28,7 @@ def __init__(
         dtype: torch.dtype,
         quantize: Optional[str],
         model_config: Optional[Any],
+        max_sequence_length: Optional[int],
     ) -> None:
         super().__init__(model_path, model_config)

server/text_generation_server/inference_engine/hf_optimum_bt.py
Lines changed: 1 addition & 0 deletions

@@ -16,6 +16,7 @@ def __init__(
         dtype: torch.dtype,
         quantize: Optional[str],
         model_config: Optional[Any],
+        max_sequence_length: Optional[int],
     ) -> None:
         super().__init__(model_path, model_config)

server/text_generation_server/inference_engine/hf_optimum_ort.py
Lines changed: 1 addition & 0 deletions

@@ -19,6 +19,7 @@ def __init__(
         dtype: torch.dtype,
         quantize: Optional[str],
         model_config: Optional[Any],
+        max_sequence_length: Optional[int],
     ) -> None:
         super().__init__(model_path, model_config)
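The engine diffs above only widen their constructors; the new argument is accepted but not used there. A sketch of the signature every engine now has to expose (the class name is hypothetical; the imports and argument order come from the diffs in this commit):

from typing import Any, Optional

import torch
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from text_generation_server.inference_engine.engine import BaseInferenceEngine


class MyInferenceEngine(BaseInferenceEngine):
    def __init__(
        self,
        model_path: str,
        model_class: type[_BaseAutoModelClass],
        dtype: torch.dtype,
        quantize: Optional[str],
        model_config: Optional[Any],
        max_sequence_length: Optional[int],  # new in this commit; may go unused
    ) -> None:
        super().__init__(model_path, model_config)
        # engine-specific model loading would go here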

server/text_generation_server/inference_engine/hf_transformers.py
Lines changed: 23 additions & 1 deletion

@@ -1,5 +1,6 @@
 import os
 import torch
+from loguru import logger
 from transformers.models.auto.auto_factory import _BaseAutoModelClass

 from text_generation_server.inference_engine.engine import BaseInferenceEngine
@@ -14,7 +15,8 @@ def __init__(
         model_class: type[_BaseAutoModelClass],
         dtype: torch.dtype,
         quantize: Optional[str],
-        model_config: Optional[Any]
+        model_config: Optional[Any],
+        max_sequence_length: Optional[int] = None,
     ) -> None:
         super().__init__(model_path, model_config)

@@ -32,9 +34,29 @@ def __init__(
             model_config.init_device = str(self.device)
             kwargs["config"] = model_config

+        if quantize is None and hasattr(model_config, "quantization_config"):
+            quantize = model_config.quantization_config.get("quant_method")
+
         if quantize == "bitsandbytes":
             # using LLM.int8()
             kwargs["load_in_8bit"] = True
+
+        elif quantize == "gptq" and model_config.quantization_config.get("bits", 4) == 4:
+            from transformers import GPTQConfig
+
+            logger.info("Using AutoGPTQ to load 4-bit GPTQ model")
+            kwargs["device_map"] = "auto"
+            quantization_config = GPTQConfig(bits=4, max_input_length=max_sequence_length)
+            disable_exllama = os.getenv("DISABLE_EXLLAMA", "False").lower() == "true"
+            if disable_exllama:
+                logger.info("Exllama kernels disabled")
+                quantization_config.use_exllama = False
+            else:
+                exllama_version = int(os.getenv("EXLLAMA_VERSION", "2"))  # Use v2 as default
+                logger.info(f"Using exllama version {exllama_version}")
+                quantization_config.exllama_config = {"version": exllama_version}
+            kwargs["quantization_config"] = quantization_config
+
         elif quantize is not None:
             raise ValueError(f"{quantize} quantization not supported by hf_transformers engine")
         else:
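Condensing the branch added above, this is roughly what the hf_transformers engine now does for a 4-bit GPTQ checkpoint. A sketch with placeholder model path and sequence length; DISABLE_EXLLAMA and EXLLAMA_VERSION are the environment variables read in the diff:

import os

from transformers import AutoModelForCausalLM, GPTQConfig

model_path = "/path/to/gptq-model"  # placeholder
max_sequence_length = 4096          # placeholder; normally passed in by the caller

quantization_config = GPTQConfig(bits=4, max_input_length=max_sequence_length)
if os.getenv("DISABLE_EXLLAMA", "False").lower() == "true":
    quantization_config.use_exllama = False               # plain AutoGPTQ kernels
else:
    version = int(os.getenv("EXLLAMA_VERSION", "2"))      # exllamav2 by default
    quantization_config.exllama_config = {"version": version}

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    quantization_config=quantization_config,
)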

server/text_generation_server/models/__init__.py
Lines changed: 16 additions & 4 deletions

@@ -28,7 +28,12 @@


 def get_model(
-    model_name: str, revision: str, deployment_framework: str, dtype_str: str, quantize: Optional[str]
+    model_name: str,
+    revision: str,
+    deployment_framework: str,
+    dtype_str: str,
+    quantize: Optional[str],
+    max_sequence_length: Optional[int],
 ) -> Model:
     dtype = get_torch_dtype(dtype_str)
     model_path = get_model_path(model_name, revision)
@@ -59,7 +64,14 @@ def get_model(
         model_config = LlamaConfig.from_pretrained(model_path)

         from text_generation_server.models.flash_causal_lm import FlashCausalLM
-        return FlashCausalLM(model_name, revision, deployment_framework, dtype, quantize, model_config)
+        return FlashCausalLM(
+            model_name,
+            revision,
+            deployment_framework,
+            dtype, quantize,
+            model_config,
+            max_sequence_length=max_sequence_length,
+        )

     elif deployment_framework == "hf_transformers" and int(os.getenv("WORLD_SIZE", "1")) > 1:
         print_rank_n(
@@ -89,9 +101,9 @@ def get_model(
     )

     if supports_causal_lm:
-        return CausalLM(model_name, revision, deployment_framework, dtype, quantize, model_config)
+        return CausalLM(model_name, revision, deployment_framework, dtype, quantize, model_config, max_sequence_length)

     if supports_seq2seq_lm:
-        return Seq2SeqLM(model_name, revision, deployment_framework, dtype, quantize, model_config)
+        return Seq2SeqLM(model_name, revision, deployment_framework, dtype, quantize, model_config, max_sequence_length)

     raise NotImplementedError(f"Unsupported model type {model_type}")
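For callers of get_model, the visible change is the extra max_sequence_length argument. A hypothetical invocation (all values are illustrative; only the parameter names come from the new signature above):

from text_generation_server.models import get_model

model = get_model(
    model_name="my-org/my-gptq-model",      # placeholder model id
    revision="main",
    deployment_framework="hf_transformers",
    dtype_str="float16",
    quantize="gptq",
    max_sequence_length=2048,               # now forwarded down to the inference engine
)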

server/text_generation_server/models/causal_lm.py
Lines changed: 2 additions & 1 deletion

@@ -551,11 +551,12 @@ def __init__(
         dtype: torch.dtype,
         quantize: Optional[str],
         model_config: Union[Any] = None,
+        max_sequence_length: Optional[int] = None,
     ):
         model_path = get_model_path(model_name, revision)

         inference_engine = get_inference_engine_class(deployment_framework)(
-            model_path, AutoModelForCausalLM, dtype, quantize, model_config,
+            model_path, AutoModelForCausalLM, dtype, quantize, model_config, max_sequence_length
         )

         super(CausalLM, self).__init__(inference_engine, dtype)

server/text_generation_server/models/flash_causal_lm.py
Lines changed: 2 additions & 1 deletion

@@ -372,6 +372,7 @@ def __init__(
         quantize: Optional[str],
         model_config: Union[Any] = None,
         auto_model_class=None,
+        max_sequence_length: Optional[int] = None,
     ):
         if not torch.cuda.is_available():
             raise NotImplementedError("FlashCausalLM is only available on GPU")
@@ -381,7 +382,7 @@ def __init__(
         model_path = get_model_path(model_name, revision)

         inference_engine = get_inference_engine_class(deployment_framework)(
-            model_path, auto_model_class, dtype, quantize, model_config,
+            model_path, auto_model_class, dtype, quantize, model_config, max_sequence_length
         )

         super(FlashCausalLM, self).__init__(inference_engine, dtype)

server/text_generation_server/models/seq2seq_lm.py
Lines changed: 2 additions & 1 deletion

@@ -550,11 +550,12 @@ def __init__(
         dtype: torch.dtype,
         quantize: Optional[str],
         model_config: Union[Any] = None,
+        max_sequence_length: Optional[int] = None,
     ):
         model_path = get_model_path(model_name, revision)

         inference_engine = get_inference_engine_class(deployment_framework)(
-            model_path, AutoModelForSeq2SeqLM, dtype, quantize, model_config,
+            model_path, AutoModelForSeq2SeqLM, dtype, quantize, model_config, max_sequence_length
         )
         super(Seq2SeqLM, self).__init__(inference_engine, dtype)
