
Commit cbeb36f

Refine batch packing logic related to prefill weight limit
The current batch packing logic does not take the space already used by an in-progress batch into account when evaluating the configured prefill weight limit for a candidate add-on batch. This matters most when the prefill memory cost is much higher than the cost of incremental new-token generation for a given batch size. This fix reduces the effective prefill limit for add-on batches in proportion to the weight of the existing in-progress batch.
1 parent d980f99 commit cbeb36f
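
The fix amounts to a small scaling rule. The standalone helper below is only an illustrative sketch (the actual change is the `effective_prefill_weight_limit` computation in `router/src/queue.rs` further down; the function and parameter names here are invented for the example):

```rust
/// Illustrative sketch: scale the configured prefill weight limit by the
/// fraction of the overall weight budget still free after accounting for
/// the in-progress batch. `current_batch_weight` is assumed to come from
/// `BatchType::batch_initial_weight` applied to the in-progress batch.
fn effective_prefill_limit(
    prefill_weight_limit: usize,
    weight_limit: usize,
    current_batch_weight: usize,
) -> usize {
    // Limit disabled, or nothing in progress: use the configured value as-is.
    if prefill_weight_limit == 0 || current_batch_weight == 0 {
        return prefill_weight_limit;
    }
    let pct_space_free = 1.0 - (current_batch_weight as f64 / weight_limit as f64);
    (pct_space_free * prefill_weight_limit as f64) as usize
}
```

For example, with a weight limit of 1000, a prefill weight limit of 400, and an in-progress batch of initial weight 250, the effective limit for the add-on batch becomes (1.0 − 0.25) × 400 = 300.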

3 files changed: 50 additions & 30 deletions


router/src/batch_types.rs

Lines changed: 24 additions & 9 deletions
```diff
@@ -9,8 +9,10 @@ pub(crate) trait BatchType: Send + Sync + Clone + 'static {
 
     /// Update batch statistics with an additional request
     fn update_stats(stats: &Self::Stats, input_length: usize, output_length: usize) -> Self::Stats;
-    /// Calculate batch weight given batch statistics
-    fn batch_weight(stats: &Self::Stats, batch_size: usize) -> usize;
+    /// Calculate worst-case max batch weight given batch statistics
+    fn batch_max_weight(stats: &Self::Stats, batch_size: usize) -> usize;
+    /// Calculate initial max batch weight given batch statistics (based on input lengths only)
+    fn batch_initial_weight(stats: &Self::Stats, batch_size: usize) -> usize;
     /// Calculate prefill batch weight given prefill batch statistics
     fn prefill_weight(prefill_stats: &Self::Stats, batch_size: usize) -> usize;
     /// Indicate whether a hypothetical batch will exceed the combined weight limit
@@ -44,21 +46,29 @@
 pub(crate) struct FlashBatch {}
 
 impl BatchType for FlashBatch {
-    /// Keep track of total number of tokens in the batch
-    type Stats = usize;
+    /// Keep track of total number of input and output tokens in the batch
+    type Stats = (usize, usize);
 
     fn update_stats(
         total_tokens: &Self::Stats, input_length: usize, output_length: usize
     ) -> Self::Stats {
-        total_tokens + input_length + output_length
+        let (total_in_tokens, total_out_tokens) = total_tokens;
+        (total_in_tokens + input_length, total_out_tokens + output_length)
+    }
+
+    fn batch_max_weight(total_tokens: &Self::Stats, _batch_size: usize) -> usize {
+        let (total_in_tokens, total_out_tokens) = total_tokens;
+        total_in_tokens + total_out_tokens
     }
 
-    fn batch_weight(total_tokens: &Self::Stats, _batch_size: usize) -> usize {
-        *total_tokens
+    fn batch_initial_weight(total_tokens: &Self::Stats, _batch_size: usize) -> usize {
+        let (total_in_tokens, _) = total_tokens;
+        *total_in_tokens
     }
 
     fn prefill_weight(total_tokens: &Self::Stats, _batch_size: usize) -> usize {
-        *total_tokens
+        let (total_in_tokens, _) = total_tokens;
+        *total_in_tokens
     }
 
     fn exceeds_weight(
@@ -106,13 +116,18 @@ impl BatchType for PaddedBatch {
         (max(*max_input_length, input_length), max(*max_output_length, output_length))
     }
 
-    fn batch_weight(max_in_out_lengths: &Self::Stats, batch_size: usize) -> usize {
+    fn batch_max_weight(max_in_out_lengths: &Self::Stats, batch_size: usize) -> usize {
         let (max_input_length, max_output_length) = max_in_out_lengths;
         let max_seq_len = max_input_length + max_output_length;
         // Memory requirement roughly proportional to batch_size * seq_len^2
         batch_size * max_seq_len.pow(2)
     }
 
+    fn batch_initial_weight(max_in_out_lengths: &Self::Stats, batch_size: usize) -> usize {
+        let (max_input_length, _) = max_in_out_lengths;
+        batch_size * max_input_length
+    }
+
     fn prefill_weight(max_in_out_lengths: &Self::Stats, batch_size: usize) -> usize {
         // Empirically, prefill latency is proportional to batch_size * seq_len^(3/2)
         let (max_input_length, _) = max_in_out_lengths;
```
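
As a quick illustration of the new `FlashBatch` statistics (not part of the commit; it assumes the `BatchType` trait and `FlashBatch` above are in scope, and the request lengths are invented):

```rust
// Two requests with input lengths 100 and 50 and max new tokens 20 and 30
// accumulate into a (total_input_tokens, total_output_tokens) tuple.
fn flash_stats_example() {
    let stats = FlashBatch::update_stats(&(0, 0), 100, 20); // (100, 20)
    let stats = FlashBatch::update_stats(&stats, 50, 30);   // (150, 50)

    // Worst case counts input plus maximum possible output tokens...
    assert_eq!(FlashBatch::batch_max_weight(&stats, 2), 200);
    // ...while the initial weight counts input tokens only, which is what the
    // new effective prefill limit calculation in queue.rs relies on.
    assert_eq!(FlashBatch::batch_initial_weight(&stats, 2), 150);
}
```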

router/src/queue.rs

Lines changed: 25 additions & 20 deletions
```diff
@@ -232,6 +232,19 @@ impl<B: BatchType> Queue<B> {
         let now = Instant::now();
         let mut batch_stats = <B>::compute_stats(entries);
         let mut prefill_stats = <B>::compute_stats(&self.empty_map);
+
+        // Compute the effective prefill weight limit, taking into account space already consumed
+        // by the in-progress batch
+        let effective_prefill_weight_limit = match self.config.prefill_weight_limit {
+            prefill_limit if prefill_limit == 0 || total_count == 0 => prefill_limit,
+            prefill_limit => {
+                let current_batch_weight = <B>::batch_initial_weight(&batch_stats, total_count);
+                let pct_space_free = 1.0 - (
+                    current_batch_weight as f64 / self.config.weight_limit as f64
+                );
+                (pct_space_free * prefill_limit as f64) as usize
+            },
+        };
         // We first do a read-only pass over the queue to allow skipping over large entries
         // that don't fit in the current batch to reach smaller entries that do
         for (index, entry) in self.buffer.iter().enumerate() {
@@ -247,7 +260,7 @@ impl<B: BatchType> Queue<B> {
             );
 
             // Avoid more granular analysis if possible
-            if <B>::batch_weight(&batch_stats, total_count + 1) > config.weight_limit {
+            if <B>::batch_max_weight(&batch_stats, total_count + 1) > config.weight_limit {
                 // We aren't sure whether this next request will fit, so populate
                 // a btree with the current batch of requests, the set of
                 // requests already evaluated, and this one, and perform more
@@ -274,9 +287,7 @@ impl<B: BatchType> Queue<B> {
                 tree.insert((output_len, input_len, tree.len()));
 
                 // Perform analysis
-                if <B>::exceeds_weight(
-                    tree, config.weight_limit, output_len,
-                ) {
+                if <B>::exceeds_weight(tree, config.weight_limit, output_len) {
                     if chosen_indices.len() + buffer_size < min_size + index + 1 {
                         // We don't have enough remaining to meet min_size
                         return None
@@ -296,28 +307,22 @@ impl<B: BatchType> Queue<B> {
                 metrics::increment_counter!("tgi_queue_jump");
             }
 
-            // Also check whether adding this request will make the batch of new requests
-            // too expensive latency-wise to perform in a single forward-pass.
-            let mut prefill_weight_exceeded = false;
-            if config.prefill_weight_limit > 0 {
+            // Also check whether adding this request will breach the prefill weight limit
+            if effective_prefill_weight_limit > 0 {
                 let next_prefill_stats = <B>::update_stats(
                     &prefill_stats, input_len, 0
                 );
                 let prefill_weight = <B>::prefill_weight(
                     &next_prefill_stats, chosen_indices.len() + 1
                 );
-                if prefill_weight > config.prefill_weight_limit {
-                    if chosen_indices.is_empty() {
-                        prefill_weight_exceeded = true;
-                    } else {
-                        if let Some(tree) = btree.as_mut() {
-                            // Remove our tuple from the set
-                            tree.remove(&(output_len, input_len, tree.len() - 1));
-                        }
-                        time_cutoff.get_or_insert_with(|| entry.queue_time.add(CUTOFF_DURATION));
-                        metrics::increment_counter!("tgi_prefill_weight_limit_exceeded");
-                        continue
+                if prefill_weight > effective_prefill_weight_limit {
+                    if let Some(tree) = btree.as_mut() {
+                        // Remove our tuple from the set
+                        tree.remove(&(output_len, input_len, tree.len() - 1));
                     }
+                    time_cutoff.get_or_insert_with(|| entry.queue_time.add(CUTOFF_DURATION));
+                    metrics::increment_counter!("tgi_prefill_weight_limit_exceeded");
+                    continue
                 }
                 prefill_stats = next_prefill_stats;
             }
@@ -326,7 +331,7 @@ impl<B: BatchType> Queue<B> {
 
             chosen_indices.push(index);
             total_count += 1;
-            if total_count >= config.size_limit || prefill_weight_exceeded {
+            if total_count >= config.size_limit {
                 break
             }
         }
```

router/src/server.rs

Lines changed: 1 addition & 1 deletion
```diff
@@ -190,7 +190,7 @@ impl<B: BatchType> BatchConfigValidator<B> {
         let single_request_stats = <B>::update_stats(
             &B::Stats::default(), max_sequence_length, 0
         );
-        let single_request_weight = <B>::batch_weight(
+        let single_request_weight = <B>::batch_initial_weight(
             &single_request_stats, 1
         );
         let weight_upper_bound = single_request_weight * max_batch_size;
```
