@@ -57,32 +57,45 @@ def as_pb(self):
57
57
nexttoken_linear_coef1 = self .next_token_params [1 ],
58
58
)
59
59
60
def prefill_memory_usage(self, batch_size, input_len):
    """Evaluate the fitted prefill model for a batch of equal-length inputs.

    Two fitted curves are evaluated and the larger value is returned:

    * linear:    ``batch_size * linear_fit_params[0] * input_len``
    * quadratic: ``batch_size * input_len
                   * (quadratic_fit_params[0]
                      + quadratic_fit_params[1] * input_len)``

    Args:
        batch_size: Number of sequences in the batch.
        input_len: Input length per sequence; a scalar or a NumPy array
            (the result is element-wise in that case).

    Returns:
        ``np.maximum`` of the two estimates, in whatever unit the fit
        parameters were calibrated in (presumably the same unit as
        ``self.weight_limit`` -- confirm against the fitting code).
    """
    linear_estimate = batch_size * self.linear_fit_params[0] * input_len
    bs_seq = input_len * batch_size
    quadratic_estimate = (
        self.quadratic_fit_params[0] * bs_seq
        + input_len * self.quadratic_fit_params[1] * bs_seq
    )
    # The model is the upper envelope of both fits.
    return np.maximum(linear_estimate, quadratic_estimate)
65
65
66
- def __nt_memory_usage (self , batch_size , input_len , output_len ):
66
def inverse_linear_prefill(self, batch, mem):
    """Invert the linear prefill fit: input length at which it equals ``mem``.

    Solves ``batch * linear_fit_params[0] * input_len == mem`` for
    ``input_len``.

    Args:
        batch: Batch size.
        mem: Budget to invert, in the unit of the linear fit.

    Returns:
        ``mem / (linear_fit_params[0] * batch)`` as a NumPy float (or
        array when the inputs are arrays). When the denominator is zero
        the result is a signed infinity (nan for ``0 / 0``) instead of a
        ``ZeroDivisionError`` or a floating-point warning.
    """
    denominator = self.linear_fit_params[0] * batch
    # A zero slope (or batch) means this fit never reaches the budget;
    # let IEEE division yield +/-inf silently rather than raising.
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.divide(mem, denominator)
68
+
69
def inverse_quadratic_prefill(self, batch, mem):
    """Invert the quadratic prefill fit: input length at which it equals ``mem``.

    With ``c0, c1 = quadratic_fit_params[0], quadratic_fit_params[1]`` the
    forward model is ``batch * (c0 * L + c1 * L**2)``, so the positive
    root of ``c1 * L**2 + c0 * L - mem / batch = 0`` is returned.

    The root is computed in the conjugate form
    ``2 * m / (c0 + sqrt(c0**2 + 4 * c1 * m))`` with ``m = mem / batch``.
    It is algebraically identical to
    ``(sqrt(c0**2 + 4 * c1 * m) - c0) / (2 * c1)`` but stays finite and
    accurate when ``c1`` is zero or tiny, where it reduces to ``m / c0``.

    Args:
        batch: Batch size (non-zero).
        mem: Budget to invert, in the unit of the quadratic fit.

    Returns:
        The input length as a NumPy float (or array). If both
        coefficients are zero the result is a signed infinity (nan for
        ``0 / 0``). A negative discriminant yields nan.
    """
    c0 = self.quadratic_fit_params[0]
    c1 = self.quadratic_fit_params[1]
    per_sequence = mem / batch
    root_term = np.sqrt(c0 ** 2 + 4 * c1 * per_sequence)
    # Conjugate form avoids dividing by 2*c1 and the cancellation in
    # (sqrt(...) - c0) for small c1.
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.divide(2 * per_sequence, c0 + root_term)
72
+
73
def inverse_prefill(self, batch, mem):
    """Invert the prefill model: largest input length whose estimate is ``mem``.

    The forward prefill model is the maximum of a linear and a quadratic
    fit, so staying within ``mem`` requires staying within both; the
    admissible length is therefore the smaller of the two inverses.

    Args:
        batch: Batch size.
        mem: Budget to invert.

    Returns:
        Element-wise minimum of ``inverse_linear_prefill`` and
        ``inverse_quadratic_prefill`` (a real-valued length, not yet
        floored). ``np.minimum`` is used so array inputs broadcast and a
        nan from either inverse propagates instead of being silently
        dropped depending on argument order.
    """
    linear_len = self.inverse_linear_prefill(batch, mem)
    quadratic_len = self.inverse_quadratic_prefill(batch, mem)
    return np.minimum(linear_len, quadratic_len)
77
+
78
def nt_memory_usage(self, batch_size, input_len, output_len):
    """Evaluate the fitted next-token model.

    The estimate is linear in both sequence lengths:
    ``batch_size * next_token_params[0] * input_len
    + batch_size * next_token_params[1] * output_len``.

    Args:
        batch_size: Number of sequences in the batch.
        input_len: Input length per sequence (scalar or array).
        output_len: Output length per sequence (scalar or array).

    Returns:
        The model value, in the unit the next-token parameters were
        fitted in.
    """
    input_term = batch_size * self.next_token_params[0] * input_len
    output_term = batch_size * self.next_token_params[1] * output_len
    return input_term + output_term
68
80
81
def inverse_next_token_output(self, batch, in_seq, mem):
    """Invert the next-token model for the output length.

    Solves ``batch * (next_token_params[0] * in_seq
    + next_token_params[1] * out_seq) == mem`` for ``out_seq``.

    Args:
        batch: Batch size.
        in_seq: Input length per sequence.
        mem: Budget to invert.

    Returns:
        The real-valued output length as a NumPy float (or array); it may
        be negative when the input term alone exceeds ``mem``. A zero
        denominator (zero output coefficient or zero batch) yields a
        signed infinity (nan for ``0 / 0``) instead of raising.
    """
    remaining = mem - self.next_token_params[0] * batch * in_seq
    denominator = batch * self.next_token_params[1]
    # Zero output coefficient: output length is unconstrained by this fit.
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.divide(remaining, denominator)
83
+
84
def inverse_next_token_input(self, batch, out_seq, mem):
    """Invert the next-token model for the input length.

    Solves ``batch * (next_token_params[0] * in_seq
    + next_token_params[1] * out_seq) == mem`` for ``in_seq``.

    Args:
        batch: Batch size.
        out_seq: Output length per sequence.
        mem: Budget to invert.

    Returns:
        The real-valued input length as a NumPy float (or array); it may
        be negative when the output term alone exceeds ``mem``. A zero
        denominator (zero input coefficient or zero batch) yields a
        signed infinity (nan for ``0 / 0``) instead of raising.
    """
    remaining = mem - self.next_token_params[1] * batch * out_seq
    denominator = batch * self.next_token_params[0]
    # Zero input coefficient: input length is unconstrained by this fit.
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.divide(remaining, denominator)
86
+
69
87
def max_input_len_for_prefill(self, batch_size, max_input_len):
    """Return the longest input length the prefill model allows.

    The prefill model is inverted in closed form at ``self.weight_limit``
    (via ``inverse_prefill``), floored to a whole length, and capped at
    ``max_input_len``.

    Args:
        batch_size: Number of sequences in the batch.
        max_input_len: Upper bound on the returned length.

    Returns:
        An ``int`` in ``[0, max_input_len]``. ``0`` means no length fits.
        The lower clamp mirrors the ``max(0, ...)`` used by the
        ``*_for_nt`` siblings, so a negative or nan inverse yields 0
        rather than a negative length or a ``ValueError`` from
        ``int(nan)``.
    """
    fitting_len = np.floor(self.inverse_prefill(batch_size, self.weight_limit))
    capped = min(fitting_len, max_input_len)
    # max(0, nan) evaluates to 0, so nan is treated as "nothing fits".
    return int(max(0, capped))
74
90
75
91
def max_input_len_for_nt(self, batch_size, output_len, max_input_len):
    """Return the longest input length allowed by both fitted models.

    The next-token model is inverted for the input length at
    ``self.weight_limit`` given ``output_len``. The result is floored,
    clamped at 0, and then limited by both
    ``max_input_len_for_prefill`` and ``max_input_len``.

    Args:
        batch_size: Number of sequences in the batch.
        output_len: Output length per sequence assumed for the estimate.
        max_input_len: Upper bound on the returned length.

    Returns:
        A non-negative ``int`` no larger than ``max_input_len``.
    """
    next_token_len = max(
        0,
        np.floor(
            self.inverse_next_token_input(batch_size, output_len, self.weight_limit)
        ),
    )
    prefill_len = self.max_input_len_for_prefill(batch_size, max_input_len)
    # Final clamp keeps the result non-negative even if the prefill bound
    # (or max_input_len itself) is negative.
    return int(max(0, min(next_token_len, prefill_len, max_input_len)))
80
95
81
96
def max_output_len_for_nt(self, batch_size, input_len, max_output_len):
    """Return the longest output length the next-token model allows.

    The next-token model is inverted for the output length at
    ``self.weight_limit`` given ``input_len``. The result is floored,
    clamped at 0, and capped at ``max_output_len``.

    Args:
        batch_size: Number of sequences in the batch.
        input_len: Input length per sequence assumed for the estimate.
        max_output_len: Upper bound on the returned length.

    Returns:
        A non-negative ``int`` no larger than ``max_output_len``.
    """
    next_token_len = max(
        0,
        np.floor(
            self.inverse_next_token_output(batch_size, input_len, self.weight_limit)
        ),
    )
    # Final clamp keeps the result non-negative for a negative cap too.
    return int(max(0, min(next_token_len, max_output_len)))
86
99
87
100
@classmethod
88
101
def disabled (cls ):
0 commit comments