Skip to content

Commit d1d2bbd

Browse files
committed
Fix to recent batch-packing prefill refinements
Missing pow(2) in padded batch_initial_weight calculation
1 parent 309b19d commit d1d2bbd

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

router/src/batch_types.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -125,7 +125,7 @@ impl BatchType for PaddedBatch {
125125

126126
fn batch_initial_weight(max_in_out_lengths: &Self::Stats, batch_size: usize) -> usize {
127127
let (max_input_length, _) = max_in_out_lengths;
128-
batch_size * max_input_length
128+
batch_size * max_input_length.pow(2)
129129
}
130130

131131
fn prefill_weight(max_in_out_lengths: &Self::Stats, batch_size: usize) -> usize {

server/text_generation_server/models/causal_lm.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -249,7 +249,7 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
249249
# and to remove unused allocated space
250250
left_offset = max_sequence_length - batch.max_sequence_length
251251
batch_left_offset = (
252-
batch.attention_mask.shape[1] - batch.max_sequence_length - batch.padding_right_offset
252+
batch.attention_mask.shape[1] - batch.max_sequence_length - batch.padding_right_offset
253253
)
254254
attention_mask[
255255
start_index:end_index, left_offset:-padding_right_offset,

0 commit comments

Comments
 (0)