
Commit 36fb68f

[Doc] Chunked Prefill Documentation (#4580)
1 parent bc8ad68 commit 36fb68f

3 files changed: +42 −2 lines changed

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -87,6 +87,7 @@ Documentation
    models/adding_model
    models/engine_args
    models/lora
+   models/performance
 
 .. toctree::
    :maxdepth: 1

docs/source/models/performance.rst

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
.. _performance:

Performance and Tuning
======================

Chunked Prefill
---------------

vLLM supports an experimental feature, chunked prefill. Chunked prefill splits large prefills into smaller chunks and batches them together with decode requests.

You can enable the feature by specifying:

.. code-block:: python

    llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_chunked_prefill=True)
    # Set max_num_batched_tokens to tune performance.
    # NOTE: 512 is the default max_num_batched_tokens for chunked prefill.
    # llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_chunked_prefill=True, max_num_batched_tokens=512)
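
For context, here is a minimal end-to-end sketch using vLLM's offline inference API (`LLM`, `SamplingParams`, and `generate`); the prompt and sampling values are arbitrary placeholders:

.. code-block:: python

    from vllm import LLM, SamplingParams

    # Enable chunked prefill; max_num_batched_tokens keeps its chunked-prefill default (512).
    llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_chunked_prefill=True)

    sampling_params = SamplingParams(temperature=0.8, max_tokens=128)
    outputs = llm.generate(["Explain chunked prefill in one paragraph."], sampling_params)
    for output in outputs:
        print(output.outputs[0].text)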

By default, the vLLM scheduler prioritizes prefills and does not batch prefill and decode requests together. This policy optimizes TTFT (time to first token), but it incurs slower ITL (inter-token latency) and inefficient GPU utilization.

Once chunked prefill is enabled, the policy changes to:

- Prioritize decode requests. All pending decode requests are batched before any prefill is scheduled.
- When there is available token budget (`max_num_batched_tokens`), schedule pending prefills. If the last pending prefill request cannot fit into `max_num_batched_tokens`, chunk it (see the sketch after this list).
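
To make the budgeting concrete, here is a deliberately simplified sketch of a decode-first, token-budget scheduler. It is not vLLM's actual scheduler code; the `Request` record, the queues, and `schedule_step` are hypothetical names used only for illustration:

.. code-block:: python

    from dataclasses import dataclass

    @dataclass
    class Request:
        request_id: str
        # Prefill tokens still to be processed; 0 means the request is in the decode phase.
        remaining_prefill_tokens: int = 0

    def schedule_step(decode_queue, prefill_queue, max_num_batched_tokens=512):
        """Build one batch: all decodes first, then as much prefill as the budget allows."""
        batch = []
        budget = max_num_batched_tokens

        # 1. Decode requests are prioritized; each consumes one token of budget.
        for req in decode_queue:
            if budget == 0:
                break
            batch.append((req, 1))
            budget -= 1

        # 2. Spend the remaining budget on pending prefills, chunking the last one if needed.
        for req in prefill_queue:
            if budget == 0:
                break
            chunk = min(req.remaining_prefill_tokens, budget)
            batch.append((req, chunk))
            req.remaining_prefill_tokens -= chunk  # a partially prefilled request stays queued
            budget -= chunk

        return batch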

This policy has two benefits:

- It improves ITL (inter-token latency) and decode generation speed because decode requests are prioritized.
- It helps achieve better GPU utilization by placing compute-bound (prefill) and memory-bound (decode) requests in the same batch.

You can tune the performance by changing `max_num_batched_tokens`.
By default, it is set to 512, which had the best ITL on A100 in the initial benchmark.
A smaller value achieves better ITL because fewer prefills interrupt decodes.
A higher value achieves better TTFT because you can put more prefill tokens into each batch.
If `max_num_batched_tokens` is the same as `max_model_len`, the behavior is almost equivalent to the default scheduling policy (except that it still prioritizes decodes).
Note that the default value (512) is optimized for ITL and may yield lower throughput than the default scheduler. We recommend setting `max_num_batched_tokens > 2048` for throughput.
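
As an illustration of a throughput-oriented setting along the lines recommended above (the value 4096 is an arbitrary example above the 2048 threshold, not a vLLM default):

.. code-block:: python

    # A larger token budget lets more prefill tokens into each batch,
    # improving throughput and TTFT at some cost to ITL.
    llm = LLM(
        model="meta-llama/Llama-2-7b-hf",
        enable_chunked_prefill=True,
        max_num_batched_tokens=4096,
    )
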
See related papers for more details (https://arxiv.org/pdf/2401.08671 or https://arxiv.org/pdf/2308.16369).

vllm/config.py

Lines changed: 3 additions & 2 deletions
@@ -607,8 +607,9 @@ def __init__(
             self.max_num_batched_tokens = max_num_batched_tokens
         else:
             if enable_chunked_prefill:
-                # For chunked prefill, choose the well-tuned batch size.
-                self.max_num_batched_tokens = 768
+                # This is the value that has the best balance between ITL
+                # and TTFT on A100. Note it is not optimized for throughput.
+                self.max_num_batched_tokens = 512
             else:
                 # If max_model_len is too short, use 2048 as the default value
                 # for higher throughput.
