else:
    DEVICE = None

+@pytest.fixture()
+def temp_prompt_cache_enc_dec_meta():
+    """Build an empty prompt cache for an encoder-decoder model with the 'meta'
+    device."""
+    dtype = torch.float16
+    return prompt_cache.PrefixCache(
+        device='meta',
+        dtype=dtype,
+        max_length=16,
+        encoder_decoder=True,
+        decoder_start_tok_embedding=torch.rand((1, 8), dtype=dtype, device='meta')
+    )
+
@pytest.fixture()
def temp_prompt_cache():
    """Build an empty prompt cache that we can test with."""
@@ -165,6 +178,7 @@ def test_thread_lock_manager():
### Tests for prompt cache node objects
def test_prompt_cache_node_tensor():
    """Verify that our tensor size estimation is correct for a single tensor prompt."""
+    gc.collect()
    initial_memory = torch.cuda.memory_allocated() if torch.cuda.is_available() else None
    node = prompt_cache.PromptCacheNode(torch.ones((3, 3)), prefix_id="1")
    expected_memory_allocation = 512  # measured in bytes
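Why the gc.collect() calls: these tests diff torch.cuda.memory_allocated() before and after building a node, so tensors left unreachable by earlier tests must be collected up front or the baseline can shift mid-test. A sketch of the measurement pattern, assuming a CUDA device is present (the tests themselves fall back to None without one):

import gc
import torch

gc.collect()  # free unreachable tensors so the baseline is stable
before = torch.cuda.memory_allocated()
t = torch.ones((3, 3), device='cuda')  # 9 * 4 bytes of data...
after = torch.cuda.memory_allocated()
print(after - before)  # ...but 512: the caching allocator rounds small allocations up to 512-byte blocks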
@@ -175,6 +189,7 @@ def test_prompt_cache_node_tensor():

def test_prompt_cache_node_tuple_all_tensors():
    """Verify that our tensor size estimation is correct for a multitensor prompt."""
+    gc.collect()
    initial_memory = torch.cuda.memory_allocated() if torch.cuda.is_available() else None
    node = prompt_cache.PromptCacheNode((torch.ones((3, 3)), torch.ones((3, 3)),), prefix_id="1")
    expected_memory_allocation = 1024  # measured in bytes
@@ -185,6 +200,7 @@ def test_prompt_cache_node_tuple_all_tensors():

def test_prompt_cache_node_tuple_with_one_tensor():
    """Ensure our tensor size estimation is correct if we have a None in our prompt tuple."""
+    gc.collect()
    initial_memory = torch.cuda.memory_allocated() if torch.cuda.is_available() else None
    node = prompt_cache.PromptCacheNode((torch.ones((3, 3)), None,), prefix_id="1")
    expected_memory_allocation = 512  # measured in bytes
@@ -265,6 +281,15 @@ def test_get_cache_len(mock_load_tensors, temp_prompt_cache):
    temp_prompt_cache.get("prompt2")
    assert len(temp_prompt_cache) == 2

+### Test code paths for encoder decoder model
+# TODO: add more tests here!
+@patch("text_generation_server.prompt_cache.PrefixCache._load_embedding_tensor")
+def test_prompt_model_device_diff(mock_load, temp_prompt_cache_enc_dec_meta):
+    # create prefix tensor on CPU which should be converted to the 'meta' device
+    # before the decoder_start_tok_embedding is added to it
+    mock_load.return_value = torch.ones((3, 8), device='cpu')
+    temp_prompt_cache_enc_dec_meta.get("bad_prompt")
+
### Test cases for invalid prompts
@patch("pathlib.Path.is_file")
def test_prompt_not_exist(mock_is_file, temp_prompt_cache):
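What the new test pins down: the prefix returned by the mocked _load_embedding_tensor lives on a different device (CPU) than the cache ('meta'), and the cache must move it over before combining it with decoder_start_tok_embedding, since torch ops on mismatched devices raise a RuntimeError. A hypothetical sketch of that requirement; the torch.cat is only one plausible combination step, not the actual PrefixCache internals:

import torch

start = torch.rand((1, 8), dtype=torch.float16, device='meta')
loaded = torch.ones((3, 8), dtype=torch.float16)  # on CPU, as in the mocked loader

# Combining across devices fails:
# torch.cat([start, loaded]) -> RuntimeError: expected all tensors on the same device

# ...so the loaded prefix has to be converted to the cache's device first:
combined = torch.cat([start, loaded.to('meta')], dim=0)
print(combined.shape)  # torch.Size([4, 8])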