Skip to content

Commit 8fd3a58

Browse files
maxdebayser and njhill
authored and committed
Add except clause to catch ONNX OOM Exception
1 parent 8b75134 commit 8fd3a58

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

server/text_generation_server/utils/memory_characterizer.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,10 @@ def manual_linear(cls, safety_margin: int, max_seq_len: int, max_batch_size: int
125125
return cls(TOTAL_MEMORY, safety_margin, [linear_param], [0, 0], [linear_param, linear_param])
126126

127127

128+
def is_oom(e: BaseException) -> bool:
    """Return True if *e* signals an out-of-memory condition.

    Two shapes of OOM are recognized:
      * the CUDA caching allocator's native ``torch.cuda.OutOfMemoryError``;
      * a plain ``RuntimeError`` whose message contains
        "Failed to allocate memory" (raised by the ONNX runtime path).

    Args:
        e: The caught exception to classify.

    Returns:
        True when the exception represents an OOM, False otherwise.
    """
    # NOTE: ``type: ignore`` placed on the attribute access itself so mypy
    # suppression actually applies (a comment on its own line has no effect).
    if isinstance(e, torch.cuda.OutOfMemoryError):  # type: ignore[attr-defined]
        return True
    # Message match is intentionally restricted to RuntimeError so unrelated
    # exception types mentioning allocation are not swallowed as OOM.
    return isinstance(e, RuntimeError) and "Failed to allocate memory" in str(e)
128132

129133

130134
class Estimator:
@@ -225,6 +229,7 @@ def _sort_samples(self):
225229
def needs_nt(self):
226230
return self.remaining_new_token_samples > 0 and self.enable_nt_sampling
227231

232+
228233
def _run_prefill_test(self, seq_length, min_max_tokens):
229234
try:
230235
gc.collect()
@@ -241,8 +246,10 @@ def _run_prefill_test(self, seq_length, min_max_tokens):
241246
batch = self.model.batch_type.concatenate([batch])
242247
mem_used = torch.cuda.max_memory_allocated(self.model.device)
243248
return False, mem_used, batch
244-
except torch.cuda.OutOfMemoryError: # type: ignore
245-
return True, None, None
249+
except BaseException as e:
250+
if is_oom(e):
251+
return True, None, None
252+
raise
246253

247254
def _run_next_token_test(self, batch, out_seq):
248255
ret = []
@@ -254,9 +261,9 @@ def _run_next_token_test(self, batch, out_seq):
254261
batch = self.model.batch_type.concatenate([batch])
255262
mem_used = torch.cuda.max_memory_allocated(self.model.device)
256263
ret.append(mem_used)
257-
except torch.cuda.OutOfMemoryError: # type: ignore
258-
pass
259-
264+
except BaseException as e:
265+
if not is_oom(e):
266+
raise
260267
return ret
261268

262269
def sample_next_token(self, batch, input_seq_len):

0 commit comments

Comments (0)