Skip to content

Commit 6403b0c

Browse files
maxdebayser authored and njhill committed
Log the maximum sequence length that fits into memory at startup
Refactor functions to calculate these upper bounds to use inverse functions instead of allocating large numpy arrays and searching for the highest value that fits.
1 parent 352a2ed commit 6403b0c

File tree

2 files changed

+38
-15
lines changed

2 files changed

+38
-15
lines changed

server/text_generation_server/server.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
import logging
3-
import os
3+
import os, sys
44
import threading
55
import time
66
from datetime import datetime
@@ -351,6 +351,16 @@ def estimate_memory():
351351
compile()
352352
memory_scaling_model = estimate_memory()
353353
compile()
354+
355+
max_input = memory_scaling_model.max_input_len_for_nt(1, max_sequence_length-1, sys.maxsize)
356+
max_output = memory_scaling_model.max_output_len_for_nt(1, max_sequence_length-1, sys.maxsize)
357+
358+
if local_rank == 0:
359+
print(
360+
"Maximum possible sequence length given available memory (for batch size 1): "
361+
f"{min(max_input, max_output)}"
362+
)
363+
354364
elif ESTIMATE_MEMORY == "manual":
355365
batch_padding = not isinstance(model, FlashCausalLM)
356366
if batch_padding:

server/text_generation_server/utils/memory_characterizer.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -57,32 +57,45 @@ def as_pb(self):
5757
nexttoken_linear_coef1=self.next_token_params[1],
5858
)
5959

60-
def prefill_memory_usage(self, batch_size, input_len):
    """Estimate peak memory used by a prefill of `batch_size` sequences
    of length `input_len`.

    Two fitted cost models are evaluated — one linear and one quadratic
    in the input length — and the element-wise larger estimate wins.
    Works on scalars or numpy arrays (via np.maximum).
    """
    total_tokens = batch_size * input_len
    linear_estimate = self.linear_fit_params[0] * total_tokens
    quadratic_estimate = (
        self.quadratic_fit_params[0] * total_tokens
        + self.quadratic_fit_params[1] * input_len * total_tokens
    )
    return np.maximum(linear_estimate, quadratic_estimate)
6565

66-
def __nt_memory_usage(self, batch_size, input_len, output_len):
66+
def inverse_linear_prefill(self, batch, mem):
    """Input length at which the linear prefill model reaches `mem` bytes
    for the given batch size (inverse of the linear estimate)."""
    per_token_cost = self.linear_fit_params[0] * batch
    return mem / per_token_cost
68+
69+
def inverse_quadratic_prefill(self, batch, mem):
    """Input length at which the quadratic prefill model reaches `mem`.

    Solves c1*x**2 + c0*x = mem/batch for x and returns the positive
    root of the quadratic formula.
    """
    c0, c1 = self.quadratic_fit_params
    per_sequence_mem = mem / batch
    discriminant = c0 ** 2 + 4 * c1 * per_sequence_mem
    return (np.sqrt(discriminant) - c0) / (2 * c1)
72+
73+
def inverse_prefill(self, batch, mem):
    """Largest input length whose prefill estimate fits in `mem`.

    The forward estimate is the max of the linear and quadratic models,
    so its inverse is the min of the two individual inverses.
    """
    candidate_bounds = (
        self.inverse_linear_prefill(batch, mem),
        self.inverse_quadratic_prefill(batch, mem),
    )
    return min(candidate_bounds)
77+
78+
def nt_memory_usage(self, batch_size, input_len, output_len):
    """Estimate next-token (decode) memory: linear in both the input
    and output sequence lengths, scaled by batch size."""
    input_cost = batch_size * self.next_token_params[0] * input_len
    output_cost = batch_size * self.next_token_params[1] * output_len
    return input_cost + output_cost
6880

81+
def inverse_next_token_output(self, batch, in_seq, mem):
    """Output length at which the next-token model reaches `mem`,
    given a fixed batch size and input length (solve for out_seq)."""
    input_cost = self.next_token_params[0] * batch * in_seq
    return (mem - input_cost) / (batch * self.next_token_params[1])
83+
84+
def inverse_next_token_input(self, batch, out_seq, mem):
    """Input length at which the next-token model reaches `mem`,
    given a fixed batch size and output length (solve for in_seq)."""
    output_cost = self.next_token_params[1] * batch * out_seq
    return (mem - output_cost) / (batch * self.next_token_params[0])
86+
6987
def max_input_len_for_prefill(self, batch_size, max_input_len):
    """Longest prefill input length, capped at `max_input_len`, that fits
    inside the memory budget (`self.weight_limit`) for this batch size.

    Uses the closed-form inverse of the prefill cost model instead of
    scanning candidate lengths.
    """
    # Floor: we need a whole number of tokens that still fits.
    fitting_len = np.floor(self.inverse_prefill(batch_size, self.weight_limit))
    return int(min(fitting_len, max_input_len))
7490

7591
def max_input_len_for_nt(self, batch_size, output_len, max_input_len):
    """Longest input length, capped at `max_input_len`, that fits the
    memory budget under BOTH the next-token and the prefill cost models.

    Each model is inverted in closed form; the tighter (smaller) bound
    wins. The next-token bound is clamped at zero so a budget too small
    for even the output cost cannot yield a negative length.
    """
    nt_bound = np.floor(
        self.inverse_next_token_input(batch_size, output_len, self.weight_limit)
    )
    nt_bound = max(0, nt_bound)
    prefill_bound = self.max_input_len_for_prefill(batch_size, max_input_len)
    return int(min(nt_bound, prefill_bound, max_input_len))
8095

8196
def max_output_len_for_nt(self, batch_size, input_len, max_output_len):
    """Longest output length, capped at `max_output_len`, that the
    next-token cost model allows within the memory budget.

    The closed-form inverse is floored to a whole token count and
    clamped at zero so an over-tight budget never yields a negative
    length.
    """
    raw_bound = self.inverse_next_token_output(
        batch_size, input_len, self.weight_limit
    )
    nt_bound = max(0, np.floor(raw_bound))
    return int(min(nt_bound, max_output_len))
8699

87100
@classmethod
88101
def disabled(cls):

0 commit comments

Comments
 (0)