
Commit 240a2da

tjohnson31415 authored and njhill committed
fix: send tensor to device before concat with decoder prefix
If the model is on the GPU, decoder_start_tok_embedding will be a tensor on the GPU, so the decoder prefix needs to be sent to the GPU before the concatenation.

Signed-off-by: Travis Johnson [email protected]
1 parent 2342f16
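The failure mode behind this fix is PyTorch's refusal to concatenate tensors that live on different devices. Below is a minimal sketch of the problem and of the reordering applied here, assuming a CUDA device is available; the variable names are illustrative, not the repository's.

import torch

if torch.cuda.is_available():
    # Stand-ins: a cached decoder prefix loaded on the CPU, and the model's
    # decoder_start_tok_embedding, which lives on the GPU with the model weights.
    decoder_prefix = torch.ones((3, 8))                       # cpu
    start_tok_embedding = torch.rand((1, 8), device="cuda")   # cuda

    try:
        torch.cat((decoder_prefix, start_tok_embedding))
    except RuntimeError as err:
        print(f"cat across devices fails: {err}")

    # The fix: move the prefix to the embedding's device first, then concatenate.
    decoder_prefix = decoder_prefix.to(start_tok_embedding.device, non_blocking=True)
    combined = torch.cat((decoder_prefix, start_tok_embedding))
    assert combined.device.type == "cuda" and combined.shape == (4, 8)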

2 files changed: +26 −1 lines changed

server/tests/test_prompt_cache.py

Lines changed: 25 additions & 0 deletions

@@ -14,6 +14,19 @@
 else:
     DEVICE = None
 
+@pytest.fixture()
+def temp_prompt_cache_enc_dec_meta():
+    """Build an empty prompt cache for an encoder-decoder model with the 'meta'
+    device."""
+    dtype = torch.float16
+    return prompt_cache.PrefixCache(
+        device='meta',
+        dtype=dtype,
+        max_length=16,
+        encoder_decoder=True,
+        decoder_start_tok_embedding=torch.rand((1, 8), dtype=dtype, device='meta')
+    )
+
 @pytest.fixture()
 def temp_prompt_cache():
     """Build an empty prompt cache that we can test with."""

@@ -165,6 +178,7 @@ def test_thread_lock_manager():
 ### Tests for prompt cache node objects
 def test_prompt_cache_node_tensor():
     """Verify that our tensor size estimation is correct for a single tensor prompt."""
+    gc.collect()
     initial_memory = torch.cuda.memory_allocated() if torch.cuda.is_available() else None
     node = prompt_cache.PromptCacheNode(torch.ones((3, 3)), prefix_id="1")
     expected_memory_allocation = 512 # measured in bytes

@@ -175,6 +189,7 @@ def test_prompt_cache_node_tensor():
 
 def test_prompt_cache_node_tuple_all_tensors():
     """Verify that our tensor size estimation is correct for a multitensor prompt."""
+    gc.collect()
     initial_memory = torch.cuda.memory_allocated() if torch.cuda.is_available() else None
     node = prompt_cache.PromptCacheNode((torch.ones((3, 3)), torch.ones((3, 3)),), prefix_id="1")
     expected_memory_allocation = 1024 # measured in bytes

@@ -185,6 +200,7 @@ def test_prompt_cache_node_tuple_all_tensors():
 
 def test_prompt_cache_node_tuple_with_one_tensor():
     """Ensure our tensor size estimation is correct if we have a None in our prompt tuple."""
+    gc.collect()
     initial_memory = torch.cuda.memory_allocated() if torch.cuda.is_available() else None
     node = prompt_cache.PromptCacheNode((torch.ones((3, 3)), None,), prefix_id="1")
     expected_memory_allocation = 512 # measured in bytes

@@ -265,6 +281,15 @@ def test_get_cache_len(mock_load_tensors, temp_prompt_cache):
     temp_prompt_cache.get("prompt2")
     assert len(temp_prompt_cache) == 2
 
+### Test code paths for encoder decoder model
+# TODO: add more tests here!
+@patch("text_generation_server.prompt_cache.PrefixCache._load_embedding_tensor")
+def test_prompt_model_device_diff(mock_load, temp_prompt_cache_enc_dec_meta):
+    # create prefix tensor on CPU which should be converted to the 'meta' device
+    # before the decoder_start_tok_embedding is added to it
+    mock_load.return_value = torch.ones((3,8), device='cpu')
+    temp_prompt_cache_enc_dec_meta.get("bad_prompt")
+
 ### Test cases for invalid prompts
 @patch("pathlib.Path.is_file")
 def test_prompt_not_exist(mock_is_file, temp_prompt_cache):
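As I read the new fixture, the 'meta' device is what makes this path testable without a GPU: meta tensors carry shape, dtype, and device metadata but no storage, so they still trigger PyTorch's device handling. The sketch below uses the same shapes as the fixture and the mocked _load_embedding_tensor, and follows the ordering the fixed code now uses: move the CPU-loaded prefix to the cache's device first, then concatenate.

import torch

# Same shape/dtype/device as the fixture's decoder_start_tok_embedding.
start_tok_embedding = torch.rand((1, 8), dtype=torch.float16, device="meta")

# Same shape as the mocked _load_embedding_tensor return value, but on the CPU.
decoder_prefix = torch.ones((3, 8), dtype=torch.float16, device="cpu")

# Move to the cache's device first (the fix), then concatenate.
combined = torch.cat((decoder_prefix.to("meta"), start_tok_embedding))
print(combined.device, combined.shape)  # meta torch.Size([4, 8])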

server/text_generation_server/prompt_cache.py

Lines changed: 1 addition & 1 deletion

@@ -216,9 +216,9 @@ def _load_embedding_tensors(self, prefix_id: str) -> Union[torch.Tensor, Tuple[t
                 if encoder_prefix is None:
                     raise PrefixNotFound(f"Prefix id {prefix_id} not found")
             else:
+                decoder_prefix = decoder_prefix.to(self.device, non_blocking=True)
                 # TODO confirm this cat is correct
                 decoder_prefix = torch.cat((decoder_prefix, self.decoder_start_tok_embedding))
-                decoder_prefix = decoder_prefix.to(self.device, non_blocking=True)
             if encoder_prefix is not None:
                 encoder_prefix = encoder_prefix.to(self.device, non_blocking=True)
             prefix = encoder_prefix, decoder_prefix
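The change is a reordering rather than an extra copy: the loaded decoder prefix is moved to self.device before the torch.cat, so both operands sit on the device where decoder_start_tok_embedding already lives (per the commit message), and the concatenated result needs no further transfer. One side note, based on general PyTorch behavior rather than anything stated in this commit: non_blocking=True only makes a host-to-GPU copy asynchronous when the source tensor is in pinned memory; otherwise it behaves like an ordinary blocking copy. A tiny illustration:

import torch

if torch.cuda.is_available():
    # Pinned (page-locked) host memory lets the copy overlap with host work;
    # without pin_memory() the same call simply blocks until the copy is done.
    prefix = torch.ones((3, 8)).pin_memory()
    prefix_gpu = prefix.to("cuda", non_blocking=True)
    torch.cuda.synchronize()  # wait for the async copy before using prefix_gpu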
