
Commit f27b7bc

Rename hf_custom_tp engine to tgis_native
And retain configuration backwards compatibility
1 parent f4c1f04 commit f27b7bc

File tree

10 files changed, +26 -17 lines changed

README.md

Lines changed: 1 addition & 1 deletion
@@ -114,7 +114,7 @@ The following model types can currently be run in sharded mode where the weights
 
 1. Ensure that the model weights are in `safetensors` format (see above)
 2. Ensure that the `CUDA_VISIBLE_DEVICES` environment variable is set appropriately (e.g. "0,1" to use the first two GPUs). The number of GPUs to use will be inferred from this or else can be set explicitly with the `NUM_GPUS` environment variable.
-3. Set the environment variable `DEPLOYMENT_FRAMEWORK=hf_custom_tp`
+3. Set the environment variable `DEPLOYMENT_FRAMEWORK=tgis_native`
 
 ### TLS configuration
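For readers following those steps, a hedged illustration of the resulting environment; the values are examples only, and the Python form is just for compactness (in practice these variables are set in the shell or in the deployment YAML patches below):

import os

# Example values only; normally set in the shell or a kustomize patch.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"          # shard across the first two GPUs
os.environ["NUM_GPUS"] = "2"                        # optional; inferred from CUDA_VISIBLE_DEVICES if unset
os.environ["DEPLOYMENT_FRAMEWORK"] = "tgis_native"  # previously "hf_custom_tp"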

deployment/base/patches/flash-attention.yaml

Lines changed: 1 addition & 1 deletion
@@ -9,6 +9,6 @@ spec:
       - name: server
         env:
           - name: DEPLOYMENT_FRAMEWORK
-            value: hf_custom_tp
+            value: tgis_native
           - name: FLASH_ATTENTION
             value: "true"

deployment/models/bloom/kustomization.yaml

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ patchesStrategicMerge:
             - name: MODEL_NAME
               value: bigscience/bloom
             - name: DEPLOYMENT_FRAMEWORK
-              value: hf_custom_tp
+              value: tgis_native
 
             - name: MAX_BATCH_SIZE
               value: "16"

deployment/models/bloomchat-v1/kustomization.yaml

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ patchesStrategicMerge:
             - name: MODEL_NAME
               value: sambanovasystems/BLOOMChat-176B-v1
             - name: DEPLOYMENT_FRAMEWORK
-              value: hf_custom_tp
+              value: tgis_native
             - name: DTYPE_STR
               value: float16
             - name: MAX_BATCH_SIZE

deployment/models/bloomz/kustomization.yaml

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ patchesStrategicMerge:
             - name: MODEL_NAME
               value: bigscience/bloomz
             - name: DEPLOYMENT_FRAMEWORK
-              value: hf_custom_tp
+              value: tgis_native
             - name: DTYPE_STR
               value: float16
             - name: MAX_BATCH_SIZE

deployment/models/flan-ul2-tp/kustomization.yaml

Lines changed: 1 addition & 1 deletion
@@ -28,5 +28,5 @@ patchesStrategicMerge:
             - name: MODEL_NAME
               value: google/flan-ul2
             - name: DEPLOYMENT_FRAMEWORK
-              value: hf_custom_tp
+              value: tgis_native

launcher/src/main.rs

Lines changed: 10 additions & 1 deletion
@@ -162,6 +162,14 @@ fn main() -> ExitCode {
         Err(VarError::NotUnicode(_)) => panic!("PYTORCH_CUDA_ALLOC_CONF set to non-unicode value"),
     };
 
+    // Backwards compatibility for "hf_custom_tp" deployment engine name
+    let deployment_framework = if args.deployment_framework == "hf_custom_tp" {
+        warn!("The \"hf_custom_tp\" deployment engine name is deprecated, please use \"tgis_native\"");
+        "tgis_native"
+    } else {
+        &args.deployment_framework
+    };
+
     // Signal handler
     let running = Arc::new(AtomicBool::new(true));
     let r = running.clone();
@@ -182,6 +190,7 @@ fn main() -> ExitCode {
     // Start shard processes
     for rank in 0..num_shard {
         let args = args.clone();
+        let deployment_framework = deployment_framework.to_string();
         let status_sender = status_sender.clone();
         let shutdown = shutdown.clone();
         let shutdown_sender = shutdown_sender.clone();
@@ -190,7 +199,7 @@ fn main() -> ExitCode {
             shard_manager(
                 args.model_name,
                 args.revision,
-                args.deployment_framework,
+                deployment_framework,
                 args.dtype.or(args.dtype_str),
                 args.quantize,
                 max_sequence_length,
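For readers who don't follow Rust, the shim above boils down to roughly the following Python sketch (the function name is invented for illustration; the real logic lives in launcher/src/main.rs):

def resolve_deployment_framework(name: str) -> str:
    # Map the deprecated engine name onto its replacement, with a warning.
    if name == "hf_custom_tp":
        print('WARNING: the "hf_custom_tp" deployment engine name is deprecated, please use "tgis_native"')
        return "tgis_native"
    return name

# e.g. resolve_deployment_framework("hf_custom_tp") -> "tgis_native" (with a warning);
#      resolve_deployment_framework("tgis_native") -> "tgis_native", unchanged.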

server/text_generation_server/inference_engine/hf_custom_tp.py renamed to server/text_generation_server/inference_engine/tgis_native.py

Lines changed: 2 additions & 2 deletions
@@ -53,7 +53,7 @@ def __init__(
                 f"Flash attention currently only supported by the following model types: {NONTP_FLASH_TYPES}"
             )
         elif model_type not in NONTP_NONFLASH_TYPES:
-            raise ValueError("hf_custom_tp engine must be used with FLASH_ATTENTION, num_shards > 1 and/or BLOOM or T5")
+            raise ValueError("tgis_native engine must be used with FLASH_ATTENTION, num_shards > 1 and/or BLOOM or T5")
 
         aliases = None
 
@@ -105,7 +105,7 @@ def __init__(
             torch.distributed.barrier(group=self.process_group)
         filenames = local_weight_files(model_path, extension=".safetensors")
         if not filenames:
-            raise ValueError("No safetensors weights found - required for hf_custom_tp engine")
+            raise ValueError("No safetensors weights found - required for tgis_native engine")
 
         weights = Weights(
             filenames, device=self.device, dtype=dtype, process_group=self.process_group, aliases=aliases

server/text_generation_server/models/__init__.py

Lines changed: 7 additions & 7 deletions
@@ -45,12 +45,12 @@ def get_model(
         import text_generation_server.utils.flash_attn as flash_attn
         print(f"Using Flash Attention V2: {flash_attn.HAS_FLASH_ATTN_V2}")
 
-        if deployment_framework != "hf_custom_tp":
+        if deployment_framework != "tgis_native":
             print_rank_n(
-                f"WARNING: Using deployment engine hf_custom_tp rather than {deployment_framework} "
+                f"WARNING: Using deployment engine tgis_native rather than {deployment_framework} "
                 "because FLASH_ATTENTION is enabled"
             )
-            deployment_framework = "hf_custom_tp"
+            deployment_framework = "tgis_native"
 
     if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
         # Custom config type for RW models
@@ -75,10 +75,10 @@ def get_model(
 
     elif deployment_framework == "hf_transformers" and int(os.getenv("WORLD_SIZE", "1")) > 1:
         print_rank_n(
-            f"WARNING: Using deployment engine hf_custom_tp rather than {deployment_framework} "
+            f"WARNING: Using deployment engine tgis_native rather than {deployment_framework} "
             "because more than one shard is configured"
         )
-        deployment_framework = "hf_custom_tp"
+        deployment_framework = "tgis_native"
 
     supports_causal_lm = model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES \
         or type(model_config) in AutoModelForCausalLM._model_mapping \
@@ -95,9 +95,9 @@ def get_model(
     if supports_seq2seq_lm and model_type == "bart":
         supports_causal_lm = False
 
-    if deployment_framework != "hf_custom_tp" and (model_type == "bloom" or model_type == "t5"):
+    if deployment_framework != "tgis_native" and (model_type == "bloom" or model_type == "t5"):
         print_rank_n(
-            "WARNING: It's recommended to use the hf_custom_tp engine with safetensors weights for T5 and BLOOM models"
+            "WARNING: It's recommended to use the tgis_native engine with safetensors weights for T5 and BLOOM models"
         )
 
     if supports_causal_lm:
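Taken together, the renamed checks in get_model amount to roughly the following fallback rule (a condensed sketch, not the verbatim file; the helper name is invented):

def choose_framework(requested: str, flash_attention: bool, world_size: int) -> str:
    # Flash attention and multi-shard runs both require the native TGIS engine.
    if flash_attention and requested != "tgis_native":
        print(f"WARNING: Using deployment engine tgis_native rather than {requested} "
              "because FLASH_ATTENTION is enabled")
        return "tgis_native"
    if requested == "hf_transformers" and world_size > 1:
        print(f"WARNING: Using deployment engine tgis_native rather than {requested} "
              "because more than one shard is configured")
        return "tgis_native"
    return requested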

server/text_generation_server/server.py

Lines changed: 1 addition & 1 deletion
@@ -276,7 +276,7 @@ async def serve_inner(
     print(f"Using device {device}, dtype {dtype_str}, quantize {quantize}")
     print(model.config.__str__())
 
-    if quantize == "gptq" and deployment_framework == "hf_custom_tp":
+    if quantize == "gptq" and deployment_framework == "tgis_native":
         from text_generation_server.utils.layers import HAS_EXLLAMA, EXLLAMA_VERSION
         if HAS_EXLLAMA:
             try:
