@@ -20,7 +20,7 @@ def temp_prompt_cache():
20
20
return prompt_cache .PrefixCache (
21
21
device = DEVICE ,
22
22
dtype = torch .float32 ,
23
- max_length = 256 ,
23
+ max_length = 8 ,
24
24
encoder_decoder = False ,
25
25
decoder_start_tok_embedding = None
26
26
)
@@ -264,3 +264,72 @@ def test_get_cache_len(mock_load_tensors, temp_prompt_cache):
264
264
temp_prompt_cache .get ("prompt1" )
265
265
temp_prompt_cache .get ("prompt2" )
266
266
assert len (temp_prompt_cache ) == 2
267
+
268
### Test cases for invalid prompts
@patch("pathlib.Path.is_file")
def test_prompt_not_exist(is_file_mock, temp_prompt_cache):
    """A prompt whose file is missing raises on get() and is never cached."""
    is_file_mock.return_value = False
    with pytest.raises(Exception):
        temp_prompt_cache.get("bad_prompt")
    # The failed lookup must not leave a stale entry behind.
    assert len(temp_prompt_cache) == 0
275
+
276
@patch("torch.load")
@patch("pathlib.Path.is_file")
def test_prompt_with_wrong_dims(mock_is_file, mock_torch_load, temp_prompt_cache):
    """Prompt tensors that are not 2-D are rejected and never cached.

    Covers both failure directions: a 1-D tensor (too few dims) and a
    3-D tensor (too many dims).
    """
    mock_is_file.return_value = True

    # One dimension is not enough. NOTE: the original wrote torch.ones((3)),
    # where (3) is just the int 3, not a tuple — (3,) spells the 1-D shape
    # unambiguously (same tensor either way).
    mock_torch_load.return_value = torch.ones((3,))
    with pytest.raises(Exception):
        temp_prompt_cache.get("bad_prompt")
    assert len(temp_prompt_cache) == 0

    # Three dimensions is too many.
    mock_torch_load.return_value = torch.ones((3, 3, 3))
    with pytest.raises(Exception):
        temp_prompt_cache.get("bad_prompt")
    assert len(temp_prompt_cache) == 0
292
+
293
@patch("torch.load")
@patch("pathlib.Path.is_file")
def test_prompt_too_many_virtual_tokens(mock_is_file, mock_torch_load, temp_prompt_cache):
    """A prompt with more virtual tokens than the cache allows raises.

    9 rows exceeds the fixture's max_length of 8 — presumably the row axis is
    the virtual-token count; TODO(review): confirm against PrefixCache.get.
    """
    mock_is_file.return_value = True
    oversized = torch.ones((9, 16))
    mock_torch_load.return_value = oversized
    with pytest.raises(Exception):
        temp_prompt_cache.get("bad_prompt")
    # Nothing should have been cached for the rejected prompt.
    assert len(temp_prompt_cache) == 0
302
+
303
@patch("torch.load")
@patch("pathlib.Path.is_file")
def test_prompt_wrong_embed_size(mock_is_file, mock_torch_load, temp_prompt_cache):
    """A prompt whose embedding width disagrees with the cache's raises."""
    mock_is_file.return_value = True
    # Force the cache to expect width 16, then hand back a width-15 tensor.
    temp_prompt_cache.embed_size = 16
    mock_torch_load.return_value = torch.ones((1, 15))
    with pytest.raises(Exception):
        temp_prompt_cache.get("bad_prompt")
    # The mismatched prompt must not have been cached.
    assert len(temp_prompt_cache) == 0
313
+
314
@patch("torch.load")
@patch("pathlib.Path.is_file")
def test_prompt_with_infinite_after_conversion(mock_is_file, mock_torch_load, temp_prompt_cache):
    """A value that overflows to inf when narrowed to float32 is rejected."""
    mock_is_file.return_value = True
    # float64 max cannot be represented in float32, so the downcast overflows.
    overflowing = torch.ones((3, 3), dtype=torch.float64)
    overflowing[1, 1] = torch.finfo(torch.float64).max
    mock_torch_load.return_value = overflowing
    with pytest.raises(Exception) as exc_info:
        temp_prompt_cache.get("bad_prompt")
    # The error message should name the failing dtype conversion.
    assert exc_info.match("torch.float64 to torch.float32")
    assert len(temp_prompt_cache) == 0
325
+
326
@patch("torch.load")
@patch("pathlib.Path.is_file")
def test_prompt_with_nan(mock_is_file, mock_torch_load, temp_prompt_cache):
    """A prompt tensor containing NaN raises on get() and is not cached."""
    mock_is_file.return_value = True
    poisoned = torch.ones((3, 3), dtype=torch.float16)
    poisoned[1, 1] = torch.nan
    mock_torch_load.return_value = poisoned
    with pytest.raises(Exception):
        temp_prompt_cache.get("bad_prompt")
    # The NaN-bearing prompt must not have been cached.
    assert len(temp_prompt_cache) == 0
0 commit comments