@@ -125,6 +125,10 @@ def manual_linear(cls, safety_margin: int, max_seq_len: int, max_batch_size: int
         return cls(TOTAL_MEMORY, safety_margin, [linear_param], [0, 0], [linear_param, linear_param])
 
 
+def is_oom(e: BaseException):
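+    # OOM surfaces either as torch's dedicated exception or as a RuntimeError.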
+    return isinstance(e, torch.cuda.OutOfMemoryError) or (  # type: ignore
+        isinstance(e, RuntimeError) and "Failed to allocate memory" in str(e))
 
 
 class Estimator:
@@ -225,6 +229,7 @@ def _sort_samples(self):
     def needs_nt(self):
         return self.remaining_new_token_samples > 0 and self.enable_nt_sampling
 
+
     def _run_prefill_test(self, seq_length, min_max_tokens):
         try:
             gc.collect()
@@ -241,8 +246,10 @@ def _run_prefill_test(self, seq_length, min_max_tokens):
             batch = self.model.batch_type.concatenate([batch])
             mem_used = torch.cuda.max_memory_allocated(self.model.device)
             return False, mem_used, batch
-        except torch.cuda.OutOfMemoryError:  # type: ignore
-            return True, None, None
+        except BaseException as e:
+            if is_oom(e):
+                return True, None, None
+            raise  # not an OOM: propagate instead of masking it
 
     def _run_next_token_test(self, batch, out_seq):
         ret = []
@@ -254,9 +261,9 @@ def _run_next_token_test(self, batch, out_seq):
                 batch = self.model.batch_type.concatenate([batch])
                 mem_used = torch.cuda.max_memory_allocated(self.model.device)
                 ret.append(mem_used)
-            except torch.cuda.OutOfMemoryError:  # type: ignore
-                pass
-
+            except BaseException as e:
+                if not is_oom(e):
+                    raise  # only swallow genuine out-of-memory errors
         return ret
 
     def sample_next_token(self, batch, input_seq_len):
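
The pattern these hunks implement: catch broadly, classify the exception with
is_oom, and re-raise anything that is not an out-of-memory condition so real
bugs are not silently swallowed. A minimal sketch of the same idea, assuming
the is_oom helper added above (fits_in_memory and run_step are illustrative
names, not part of this patch):

    import gc
    import torch

    def fits_in_memory(run_step) -> bool:
        """Return True if run_step() completes, False if it runs out of memory."""
        gc.collect()
        torch.cuda.empty_cache()
        try:
            run_step()
            return True
        except BaseException as e:
            if is_oom(e):
                # An OOM here is a measurement result, not a failure.
                return False
            raise  # propagate real errors instead of masking them as OOM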