
Commit 2342f16

tjohnson31415 and njhill committed
feat: validate that prompts do not have nonfinite values
Signed-off-by: Travis Johnson <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
1 parent 36052c1 commit 2342f16
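Note on the motivation (informational, not part of the commit): the failure mode being guarded against is overflow during dtype narrowing, where a value that is finite in a wider dtype becomes inf once cast to the dtype used for inference. A minimal standalone sketch of the effect:

import torch

# Finite in float64...
wide = torch.tensor([torch.finfo(torch.float64).max], dtype=torch.float64)
print(wide.isfinite().all())    # tensor(True)

# ...but it overflows to inf when narrowed to float32.
narrow = wide.to(torch.float32)
print(narrow)                   # tensor([inf])
print(narrow.isfinite().all())  # tensor(False)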

File tree

server/tests/test_prompt_cache.py
server/text_generation_server/prompt_cache.py

2 files changed: +87 -9 lines changed

server/tests/test_prompt_cache.py

Lines changed: 70 additions & 1 deletion
@@ -20,7 +20,7 @@ def temp_prompt_cache():
     return prompt_cache.PrefixCache(
         device=DEVICE,
         dtype=torch.float32,
-        max_length=256,
+        max_length=8,
         encoder_decoder=False,
         decoder_start_tok_embedding=None
     )
@@ -264,3 +264,72 @@ def test_get_cache_len(mock_load_tensors, temp_prompt_cache):
     temp_prompt_cache.get("prompt1")
     temp_prompt_cache.get("prompt2")
     assert len(temp_prompt_cache) == 2
+
+### Test cases for invalid prompts
+@patch("pathlib.Path.is_file")
+def test_prompt_not_exist(mock_is_file, temp_prompt_cache):
+    mock_is_file.return_value = False
+    with pytest.raises(Exception):
+        temp_prompt_cache.get("bad_prompt")
+    assert len(temp_prompt_cache) == 0
+
+@patch("torch.load")
+@patch("pathlib.Path.is_file")
+def test_prompt_with_wrong_dims(mock_is_file, mock_torch_load, temp_prompt_cache):
+    mock_is_file.return_value = True
+
+    # one dimension is not enough
+    mock_torch_load.return_value = torch.ones((3))
+    with pytest.raises(Exception):
+        temp_prompt_cache.get("bad_prompt")
+    assert len(temp_prompt_cache) == 0
+
+    # three dimensions is too many
+    mock_torch_load.return_value = torch.ones((3, 3, 3))
+    with pytest.raises(Exception):
+        temp_prompt_cache.get("bad_prompt")
+    assert len(temp_prompt_cache) == 0
+
+@patch("torch.load")
+@patch("pathlib.Path.is_file")
+def test_prompt_too_many_virtual_tokens(mock_is_file, mock_torch_load, temp_prompt_cache):
+    mock_is_file.return_value = True
+
+    mock_torch_load.return_value = torch.ones((9,16))
+    with pytest.raises(Exception):
+        temp_prompt_cache.get("bad_prompt")
+    assert len(temp_prompt_cache) == 0
+
+@patch("torch.load")
+@patch("pathlib.Path.is_file")
+def test_prompt_wrong_embed_size(mock_is_file, mock_torch_load, temp_prompt_cache):
+    mock_is_file.return_value = True
+    # set embed_size to 16
+    temp_prompt_cache.embed_size = 16
+    mock_torch_load.return_value = torch.ones((1,15))
+    with pytest.raises(Exception):
+        temp_prompt_cache.get("bad_prompt")
+    assert len(temp_prompt_cache) == 0
+
+@patch("torch.load")
+@patch("pathlib.Path.is_file")
+def test_prompt_with_infinite_after_conversion(mock_is_file, mock_torch_load, temp_prompt_cache):
+    mock_is_file.return_value = True
+    bad_tensor = torch.ones((3,3), dtype=torch.float64)
+    bad_tensor[1, 1] = torch.finfo(torch.float64).max
+    mock_torch_load.return_value = bad_tensor
+    with pytest.raises(Exception) as e:
+        temp_prompt_cache.get("bad_prompt")
+    assert e.match("torch.float64 to torch.float32")
+    assert len(temp_prompt_cache) == 0
+
+@patch("torch.load")
+@patch("pathlib.Path.is_file")
+def test_prompt_with_nan(mock_is_file, mock_torch_load, temp_prompt_cache):
+    mock_is_file.return_value = True
+    bad_tensor = torch.ones((3,3), dtype=torch.float16)
+    bad_tensor[1, 1] = torch.nan
+    mock_torch_load.return_value = bad_tensor
+    with pytest.raises(Exception):
+        temp_prompt_cache.get("bad_prompt")
+    assert len(temp_prompt_cache) == 0
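A note on the tests above (informational, not part of the diff): stacked @patch decorators inject their mocks bottom-up, so the decorator nearest the function supplies the first mock argument; that is why mock_is_file precedes mock_torch_load in the signatures even though @patch("torch.load") appears first. A small standalone sketch of the ordering, assuming torch is installed; the helper name check_patch_order is ours:

import pathlib
import torch
from unittest.mock import patch

@patch("torch.load")              # outermost decorator -> second mock argument
@patch("pathlib.Path.is_file")    # innermost decorator -> first mock argument
def check_patch_order(mock_is_file, mock_torch_load):
    mock_is_file.return_value = False
    mock_torch_load.return_value = torch.ones((1, 1))
    # Each call is routed to the mock configured for it above.
    assert pathlib.Path("some/prompt").is_file() is False
    assert torch.load("some/prompt").shape == (1, 1)

check_patch_order()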

server/text_generation_server/prompt_cache.py

Lines changed: 17 additions & 8 deletions
@@ -207,20 +207,20 @@ def _load_embedding_tensors(self, prefix_id: str) -> Union[torch.Tensor, Tuple[t
             Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
                 Loaded encoder / decoder prompt tensor for the model under consideration.
         """
-        decoder_prefix = self._load_embedding_tensor(prefix_id, "decoder.pt")
+        decoder_prefix = self._load_embedding_tensor(prefix_id, "decoder.pt", dtype=self.dtype)
         # For encoder-decoder we store a tuple of (encoder_prefix, decoder_prefix),
         # at least one must be non-None
         if self.is_encoder_decoder:
-            encoder_prefix = self._load_embedding_tensor(prefix_id, "encoder.pt")
+            encoder_prefix = self._load_embedding_tensor(prefix_id, "encoder.pt", dtype=self.dtype)
             if decoder_prefix is None:
                 if encoder_prefix is None:
                     raise PrefixNotFound(f"Prefix id {prefix_id} not found")
             else:
                 # TODO confirm this cat is correct
                 decoder_prefix = torch.cat((decoder_prefix, self.decoder_start_tok_embedding))
-                decoder_prefix = decoder_prefix.to(self.dtype).to(self.device, non_blocking=True)
+                decoder_prefix = decoder_prefix.to(self.device, non_blocking=True)
             if encoder_prefix is not None:
-                encoder_prefix = encoder_prefix.to(self.dtype).to(self.device, non_blocking=True)
+                encoder_prefix = encoder_prefix.to(self.device, non_blocking=True)
             prefix = encoder_prefix, decoder_prefix
         # For decoder-only we store just the decoder prefix
         elif decoder_prefix is None:
@@ -229,7 +229,7 @@ def _load_embedding_tensors(self, prefix_id: str) -> Union[torch.Tensor, Tuple[t
             prefix = decoder_prefix.to(self.dtype).to(self.device, non_blocking=True)
         return prefix

-    def _load_embedding_tensor(self, prefix_id: str, filename: str) -> torch.Tensor:
+    def _load_embedding_tensor(self, prefix_id: str, filename: str, dtype: torch.dtype) -> torch.Tensor:
         """Load an embedding tensor from a single file.

         Args:
@@ -264,9 +264,18 @@ def _load_embedding_tensor(self, prefix_id: str, filename: str) -> torch.Tensor:
             raise Exception(
                 f"Prefix embedding tensor dim {prefix.shape[1]} does not match model ({self.embed_size})"
             )
-
-        prefix.requires_grad = False
-        return prefix
+        # convert to the desired dtype
+        converted_prefix = prefix.to(dtype)
+        # detect if we have non-finite elements after the conversion that will
+        # cause problems for inference
+        if not converted_prefix.isfinite().all():
+            # check if the problem was in the pre-converted tensor
+            if not prefix.isfinite().all():
+                raise Exception(f"Prefix contains non-finite elements")
+            raise Exception(f"Prefix contains non-finite elements after conversion from {prefix.dtype} to {dtype}")
+
+        converted_prefix.requires_grad = False
+        return converted_prefix

     def _add_prefix_id_to_cache(
         self,
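For reference (not part of the commit): the validation added above converts the loaded prefix to the target dtype, then rejects it if any element is non-finite, distinguishing tensors that were already bad from tensors that only became bad during conversion. A standalone sketch of that pattern, using a hypothetical helper name _check_prefix rather than the actual method:

import torch

def _check_prefix(prefix: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Mirrors the logic added in _load_embedding_tensor, for illustration only.
    converted = prefix.to(dtype)
    if not converted.isfinite().all():
        if not prefix.isfinite().all():
            raise ValueError("Prefix contains non-finite elements")
        raise ValueError(
            f"Prefix contains non-finite elements after conversion from {prefix.dtype} to {dtype}"
        )
    converted.requires_grad = False
    return converted

# A NaN is already non-finite before conversion, so the first error path fires.
bad = torch.ones((3, 3), dtype=torch.float16)
bad[1, 1] = torch.nan
try:
    _check_prefix(bad, torch.float32)
except ValueError as err:
    print(err)   # Prefix contains non-finite elements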
