Commit 555c360
Update TP example
1 parent 4117daa


examples/inference_tp.py

Lines changed: 14 additions & 1 deletion
@@ -10,9 +10,22 @@
 config.arch_compat_overrides()
 config.no_graphs = True
 model = ExLlamaV2(config)
-model.load_tp(progress = True)
+
+# Load the model in tensor-parallel mode. With no gpu_split specified, the model will attempt to split across
+# all visible devices according to the currently available VRAM on each. expect_cache_tokens is necessary for
+# balancing the split, in case the GPUs are of uneven sizes, or if the number of GPUs doesn't divide the number
+# of KV heads in the model
+#
+# The cache type for a TP model is always ExLlamaV2Cache_TP and should be allocated after the model. To use a
+# quantized cache, add a `base = ExLlamaV2Cache_Q6` etc. argument to the cache constructor. It's advisable
+# to also add `expect_cache_base = ExLlamaV2Cache_Q6` to load_tp() as well so the size can be correctly
+# accounted for when splitting the model.
+
+model.load_tp(progress = True, expect_cache_tokens = 16384)
 cache = ExLlamaV2Cache_TP(model, max_seq_len = 16384)
 
+# After loading the model, all other functions should work the same
+
 print("Loading tokenizer...")
 tokenizer = ExLlamaV2Tokenizer(config)
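
For context, a minimal sketch of the quantized-cache variant that the new comments describe. The model directory is a placeholder, and the import list is an assumption based on the names used in examples/inference_tp.py; this is not the example file itself.

# Sketch only: TP loading with a quantized cache, following the comments in this commit.
from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache_TP,
    ExLlamaV2Cache_Q6,
    ExLlamaV2Tokenizer,
)

config = ExLlamaV2Config("/path/to/model")  # placeholder model directory
config.arch_compat_overrides()
config.no_graphs = True
model = ExLlamaV2(config)

# Let load_tp() account for a 16384-token, Q6-quantized cache when splitting
# the weights across the visible GPUs.
model.load_tp(
    progress = True,
    expect_cache_tokens = 16384,
    expect_cache_base = ExLlamaV2Cache_Q6,
)

# Allocate the TP cache after the model, with the quantized base.
cache = ExLlamaV2Cache_TP(model, max_seq_len = 16384, base = ExLlamaV2Cache_Q6)

tokenizer = ExLlamaV2Tokenizer(config)

Passing the same cache class to both load_tp() and the cache constructor keeps the VRAM estimate used for the split consistent with the cache that is actually allocated, which is the point of the expect_cache_base hint in the new comments.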
