
Commit 79d38a1

declark1 authored and njhill committed
fix: add falcon to supported tp flash types, fix num_groups in FlashRWLargeAttention
This PR fixes two issues to enable falcon-180b:
1. Fixes an issue with num_groups that resulted in `NotImplementedError: Tensor Parallelism is not implemented for 14 not divisible by 8` when loading falcon-180b with 8 GPUs.
2. Adds "falcon" to the supported TP flash types.
1 parent addc714 commit 79d38a1
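
For context, here is a minimal sketch of the arithmetic behind the error above. The falcon-180b head counts used below (n_head=232, n_head_kv=8) are assumed from the published model config rather than taken from this commit, but they are consistent with the `14 not divisible by 8` message:

```python
# Assumed falcon-180b config values (not part of this commit):
n_head = 232     # query attention heads
n_head_kv = 8    # key/value head groups
world_size = 8   # tensor-parallel ranks (GPUs)

# Old computation in FlashRWLargeAttention:
old_num_groups = n_head // (n_head_kv * 2)   # 232 // 16 = 14
print(old_num_groups % world_size)           # 6 -> "14 not divisible by 8"

# After this commit, n_head_kv is used directly as the group count:
new_num_groups = n_head_kv                   # 8
print(new_num_groups % world_size)           # 0 -> shards evenly across 8 GPUs
```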

File tree

2 files changed: +5 lines, -6 lines


server/text_generation_server/inference_engine/hf_custom_tp.py
Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@
 from text_generation_server.utils.dist import initialize_torch_distributed
 from text_generation_server.utils.hub import local_weight_files
 
-NONTP_FLASH_TYPES = ["RefinedWeb", "RefinedWebModel", "gpt_neox", "gpt_bigcode", "llama"]
+NONTP_FLASH_TYPES = ["RefinedWeb", "RefinedWebModel", "gpt_neox", "gpt_bigcode", "llama", "falcon"]
 TP_NONFLASH_TYPES = ["bloom", "t5", "gpt_neox"]
 TP_FLASH_TYPES = NONTP_FLASH_TYPES  # All flash types currently support TP
 NONTP_NONFLASH_TYPES = ["bloom", "t5"]
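
For illustration, a hypothetical dispatch helper (not code from this module) shows how these lists might be consulted when deciding whether a model type such as "falcon" can run with tensor parallelism and flash attention:

```python
# Hypothetical helper, for illustration only -- not part of hf_custom_tp.py.
NONTP_FLASH_TYPES = ["RefinedWeb", "RefinedWebModel", "gpt_neox", "gpt_bigcode", "llama", "falcon"]
TP_NONFLASH_TYPES = ["bloom", "t5", "gpt_neox"]
TP_FLASH_TYPES = NONTP_FLASH_TYPES  # All flash types currently support TP
NONTP_NONFLASH_TYPES = ["bloom", "t5"]

def supported(model_type: str, sharded: bool, flash: bool) -> bool:
    """Return True if the (TP, flash) combination is supported for this model type."""
    if sharded and flash:
        return model_type in TP_FLASH_TYPES
    if sharded:
        return model_type in TP_NONFLASH_TYPES
    if flash:
        return model_type in NONTP_FLASH_TYPES
    return model_type in NONTP_NONFLASH_TYPES

print(supported("falcon", sharded=True, flash=True))  # True after this change
```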

server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
Lines changed: 4 additions & 5 deletions

@@ -228,24 +228,23 @@ def __init__(
 
         hidden_size = config.hidden_size
         num_heads = config.n_head
-        num_heads_kv = config.n_head_kv
+        num_groups = config.n_head_kv
 
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads
+        self.num_groups = num_groups
+        self.num_heads = num_heads // self.num_groups
 
         self.rotary_emb = PositionRotaryEmbedding.static(
             self.head_size, base=10000.0, device=weights.device
         )
         self.softmax_scale = self.head_size ** (-0.5)
 
-        self.num_groups = num_heads // (num_heads_kv * 2)
-        self.num_heads = num_heads // self.num_groups
-        self.num_heads_kv = num_heads_kv // self.num_groups
         process_group = weights.process_group
 
         if process_group.size() > self.num_groups:
             raise NotImplementedError(
-                f"Tensor Parallelism is not implemented for world_size > n groups"
+                "Tensor Parallelism is not implemented for world_size > n groups"
             )
         if self.num_groups % process_group.size() != 0:
             raise NotImplementedError(
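As a rough sketch of what these guards protect: once the checks pass, the key/value groups can be split evenly across tensor-parallel ranks. The values below reuse the falcon-180b numbers from the commit description; the per-rank split itself is illustrative, not code from this file:

```python
# Illustrative only: relation between the guards above and group sharding.
num_groups = 8     # e.g. falcon-180b's n_head_kv (assumed value)
world_size = 8     # process_group.size()

assert world_size <= num_groups         # mirrors the first NotImplementedError guard
assert num_groups % world_size == 0     # mirrors the second guard

groups_per_rank = num_groups // world_size
print(groups_per_rank)                  # 1 key/value group per GPU in this example
```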