
Commit ad3fb90

alex-jw-brooks authored and njhill committed

Tuned prompt cache eviction

Nothing super fancy here, and nothing releases the GIL; this focuses on correctness more than anything else, but it tries to minimize critical-section time where possible by making sure we never hold the lock while loading, creating, or deleting tensors, and instead lock only while updating the cache dict or the memory counter.
1 parent 783bf91 commit ad3fb90
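
For illustration, the locking pattern the commit message describes can be sketched roughly as follows. This is a simplified, hypothetical sketch rather than the code in this commit: the real implementation lives in text_generation_server/prompt_cache.py and is built around PrefixCache, PromptCacheNode, and a DoublyLinkedList (exercised by the tests below); the SimplePromptCache class, the _load_tensor helper, and the OrderedDict-based bookkeeping here are stand-ins.

from collections import OrderedDict
from threading import Lock

import torch

CACHE_SIZE_MB = 512  # hypothetical budget, analogous to PROMPT_CACHE_SIZE_MB


class SimplePromptCache:
    """Hypothetical, simplified LRU prompt cache showing the locking pattern."""

    def __init__(self):
        self._lock = Lock()
        self._cache = OrderedDict()  # prefix_id -> (tensor, size_mb), in LRU order
        self._used_mb = 0.0

    def _load_tensor(self, prefix_id):
        # Stand-in for the expensive tensor load/creation; deliberately done
        # without holding the lock.
        return torch.ones((3, 3))

    def get(self, prefix_id):
        with self._lock:
            entry = self._cache.get(prefix_id)
            if entry is not None:
                self._cache.move_to_end(prefix_id)  # refresh recency on a hit
                return entry[0]
        # Cache miss: load the tensor outside the critical section.
        tensor = self._load_tensor(prefix_id)
        size_mb = tensor.numel() * tensor.element_size() / (1024 ** 2)
        if size_mb > CACHE_SIZE_MB:
            raise ValueError(f"Prompt {prefix_id} is larger than the whole cache")
        evicted = []
        with self._lock:
            if prefix_id in self._cache:
                # Another thread raced us and already inserted this prompt.
                self._cache.move_to_end(prefix_id)
                return self._cache[prefix_id][0]
            # Hold the lock only to update the dict and the memory counter;
            # evicted tensors are collected and released after unlocking.
            while self._used_mb + size_mb > CACHE_SIZE_MB and self._cache:
                _, (old_tensor, old_size) = self._cache.popitem(last=False)
                self._used_mb -= old_size
                evicted.append(old_tensor)
            self._cache[prefix_id] = (tensor, size_mb)
            self._used_mb += size_mb
        del evicted  # actual tensor deletion/freeing happens outside the lock
        return tensor

The point mirrored from the message: dict and counter updates are the only work done under the lock, so slow tensor I/O never blocks other threads on the cache.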

File tree

2 files changed: +552 -37 lines changed


server/tests/test_prompt_cache.py

Lines changed: 266 additions & 0 deletions
@@ -0,0 +1,266 @@
"""Tests for evaluating the prompt cache, with particular focus on making sure
it does LRU eviction in a thread-safe way correctly.
"""
import gc
import pytest
from unittest.mock import patch
import torch
from threading import Lock
from text_generation_server import prompt_cache

if torch.cuda.is_available():
    DEVICE = "cuda"
    torch.set_default_device(DEVICE)
else:
    DEVICE = None

@pytest.fixture()
def temp_prompt_cache():
    """Build an empty prompt cache that we can test with."""
    return prompt_cache.PrefixCache(
        device=DEVICE,
        dtype=torch.float32,
        max_length=256,
        encoder_decoder=False,
        decoder_start_tok_embedding=None
    )

### Tests for linked list operations
## Adding new nodes to the list
def test_single_node_list_add_as_head():
    """Ensure that we can create a list with a single node correctly."""
    dll = prompt_cache.DoublyLinkedList()
    node = prompt_cache.PromptCacheNode((torch.ones((3, 3)), torch.ones((3, 3)),), prefix_id="1")
    dll.add_node_as_head(node)
    assert dll.head is node
    assert dll.tail is node
    assert dll.head.next is None
    assert dll.head.prev is None
    assert dll.tail.next is None
    assert dll.tail.prev is None

def test_multi_node_list_add_as_head():
    """Ensure that we can create a list with multiple nodes correctly."""
    dll = prompt_cache.DoublyLinkedList()
    node1 = prompt_cache.PromptCacheNode((torch.ones((3, 3)), torch.ones((3, 3)),), prefix_id="1")
    node2 = prompt_cache.PromptCacheNode((torch.ones((3, 3)), torch.ones((3, 3)),), prefix_id="2")
    node3 = prompt_cache.PromptCacheNode((torch.ones((3, 3)), torch.ones((3, 3)),), prefix_id="3")
    dll.add_node_as_head(node1)
    dll.add_node_as_head(node2)
    dll.add_node_as_head(node3)
    assert dll.head is node3
    assert dll.tail is node1
    assert node3.prev is None
    assert node3.next is node2
    assert node2.prev is node3
    assert node2.next is node1
    assert node1.next is None
    assert node1.prev is node2

## Removing nodes from the list
def test_remove_tail_from_list_with_one_node():
    """Ensure that we can remove a node from a list with one entry."""
    dll = prompt_cache.DoublyLinkedList()
    node = prompt_cache.PromptCacheNode((torch.ones((3, 3)), torch.ones((3, 3)),), prefix_id="1")
    dll.add_node_as_head(node)
    popped_node = dll.pop_tail_node()
    assert dll.head is None
    assert dll.tail is None
    assert popped_node is node

def test_remove_tail_from_multi_node_list():
    """Ensure we can correctly remove the tail from the DLL."""
    dll = prompt_cache.DoublyLinkedList()
    node1 = prompt_cache.PromptCacheNode((torch.ones((3, 3)), torch.ones((3, 3)),), prefix_id="1")
    node2 = prompt_cache.PromptCacheNode((torch.ones((3, 3)), torch.ones((3, 3)),), prefix_id="2")
    dll.add_node_as_head(node1)
    dll.add_node_as_head(node2)
    assert dll.tail is node1
    popped_node = dll.pop_tail_node()
    assert popped_node is node1
    assert dll.head is dll.tail
    assert dll.head is node2
    assert node2.next is None
    assert node2.prev is None

## Moving things within the list
def test_move_to_head_with_one_node():
    """Ensure that moving a node from a list with one entry is a no-op."""
    dll = prompt_cache.DoublyLinkedList()
    node = prompt_cache.PromptCacheNode((torch.ones((3, 3)), torch.ones((3, 3)),), prefix_id="1")
    dll.add_node_as_head(node)
    dll.move_node_to_head(node)
    assert dll.head is node
    assert dll.tail is node
    assert dll.head.next is None
    assert dll.head.prev is None
    assert dll.tail.next is None
    assert dll.tail.prev is None

def test_move_to_head_multi_node_list():
    """Ensure that moving the head to the front of a multi-node list is a no-op."""
    dll = prompt_cache.DoublyLinkedList()
    node1 = prompt_cache.PromptCacheNode((torch.ones((3, 3)), torch.ones((3, 3)),), prefix_id="1")
    node2 = prompt_cache.PromptCacheNode((torch.ones((3, 3)), torch.ones((3, 3)),), prefix_id="2")
    # 2 <-> 1
    dll.add_node_as_head(node1)
    dll.add_node_as_head(node2)
    # 2 <-> 1
    dll.move_node_to_head(node2)
    assert dll.head is node2
    assert dll.tail is node1
    assert node2.next is node1
    assert node2.prev is None
    assert node1.prev is node2
    assert node1.next is None

def test_move_to_head_from_tail_multi_node_list():
    """Ensure that we can move the tail of a multi-node DLL to the head correctly."""
    dll = prompt_cache.DoublyLinkedList()
    node1 = prompt_cache.PromptCacheNode((torch.ones((3, 3)), torch.ones((3, 3)),), prefix_id="1")
    node2 = prompt_cache.PromptCacheNode((torch.ones((3, 3)), torch.ones((3, 3)),), prefix_id="2")
    # 2 <-> 1
    dll.add_node_as_head(node1)
    dll.add_node_as_head(node2)
    # 1 <-> 2
    dll.move_node_to_head(node1)
    assert dll.head is node1
    assert dll.tail is node2
    assert node1.next is node2
    assert node1.prev is None
    assert node2.prev is node1
    assert node2.next is None

def test_move_to_head_from_middle_multi_node_list():
    """Ensure that we can move a node from the middle of a multi-node DLL to the head correctly."""
    dll = prompt_cache.DoublyLinkedList()
    node1 = prompt_cache.PromptCacheNode((torch.ones((3, 3)), torch.ones((3, 3)),), prefix_id="1")
    node2 = prompt_cache.PromptCacheNode((torch.ones((3, 3)), torch.ones((3, 3)),), prefix_id="2")
    node3 = prompt_cache.PromptCacheNode((torch.ones((3, 3)), torch.ones((3, 3)),), prefix_id="3")
    # 3 <-> 2 <-> 1
    dll.add_node_as_head(node1)
    dll.add_node_as_head(node2)
    dll.add_node_as_head(node3)
    # 2 <-> 3 <-> 1
    dll.move_node_to_head(node2)
    assert dll.head is node2
    assert dll.tail is node1
    assert node2.next is node3
    assert node2.prev is None
    assert node3.prev is node2
    assert node3.next is node1
    assert node1.prev is node3
    assert node1.next is None

### Tests for thread lock manager
def test_thread_lock_manager():
    """Ensure that when we enter/exit a lock manager, we correctly lock/unlock."""
    lock = Lock()
    lock_manager = prompt_cache.ThreadLockManager(lock)
    assert not lock.locked()
    with lock_manager:
        assert lock.locked()
    assert not lock.locked()

### Tests for prompt cache node objects
def test_prompt_cache_node_tensor():
    """Verify that our tensor size estimation is correct for a single tensor prompt."""
    initial_memory = torch.cuda.memory_allocated() if torch.cuda.is_available() else None
    node = prompt_cache.PromptCacheNode(torch.ones((3, 3)), prefix_id="1")
    expected_memory_allocation = 512  # measured in bytes
    assert node.prompt_size_mb * (1024 ** 2) == expected_memory_allocation
    # Compare to the newly allocated cuda memory if cuda is available
    if initial_memory is not None:
        assert torch.cuda.memory_allocated() - initial_memory == expected_memory_allocation

def test_prompt_cache_node_tuple_all_tensors():
    """Verify that our tensor size estimation is correct for a multi-tensor prompt."""
    initial_memory = torch.cuda.memory_allocated() if torch.cuda.is_available() else None
    node = prompt_cache.PromptCacheNode((torch.ones((3, 3)), torch.ones((3, 3)),), prefix_id="1")
    expected_memory_allocation = 1024  # measured in bytes
    assert node.prompt_size_mb * (1024 ** 2) == expected_memory_allocation
    # Compare to the newly allocated cuda memory if cuda is available
    if initial_memory is not None:
        assert torch.cuda.memory_allocated() - initial_memory == expected_memory_allocation

def test_prompt_cache_node_tuple_with_one_tensor():
    """Ensure our tensor size estimation is correct if we have a None in our prompt tuple."""
    initial_memory = torch.cuda.memory_allocated() if torch.cuda.is_available() else None
    node = prompt_cache.PromptCacheNode((torch.ones((3, 3)), None,), prefix_id="1")
    expected_memory_allocation = 512  # measured in bytes
    assert node.prompt_size_mb * (1024 ** 2) == expected_memory_allocation
    # Compare to the newly allocated cuda memory if cuda is available
    if initial_memory is not None:
        assert torch.cuda.memory_allocated() - initial_memory == expected_memory_allocation

### End to end tests for prompt cache interactions
@patch("text_generation_server.prompt_cache.PrefixCache._load_embedding_tensors")
def test_get_prompt_cache_no_eviction(mock_load_tensors, temp_prompt_cache):
    """Ensure that if we get a prompt cache hit, its timestamp updates."""
    mock_load_tensors.return_value = torch.ones((3, 3))
    dummy_prompt_id = "prompt1"
    # Prompt cache miss; add the dummy prompt ID to the cache
    t1 = temp_prompt_cache.get(dummy_prompt_id)
    assert len(temp_prompt_cache) == 1
    assert isinstance(t1, torch.Tensor)
    # Prompt cache hit; should retrieve the same tensor object
    t2 = temp_prompt_cache.get(dummy_prompt_id)
    assert len(temp_prompt_cache) == 1
    assert t1 is t2

@patch("text_generation_server.prompt_cache.PromptCacheNode._get_prompt_size_mb")
@patch("text_generation_server.prompt_cache.PrefixCache._load_embedding_tensors")
def test_get_prompt_cache_with_eviction(mock_load_tensors, mock_get_prompt_size, temp_prompt_cache):
    """Ensure that if we need to make space, we evict the least recently used tensor."""
    mock_load_tensors.return_value = torch.ones((3, 3))
    mock_get_prompt_size.return_value = (prompt_cache.PROMPT_CACHE_SIZE_MB / 2) - 1
    temp_prompt_cache.get("prompt1")
    temp_prompt_cache.get("prompt2")
    # Evicts the LRU prompt ID (prompt1)
    temp_prompt_cache.get("prompt3")
    assert len(temp_prompt_cache) == 2
    assert set(temp_prompt_cache.keys()) == set(["prompt2", "prompt3"])
    # Access our oldest node, updating its timestamp
    temp_prompt_cache.get("prompt2")
    # Then ensure that adding a new prompt ID evicts prompt3 instead of prompt2
    temp_prompt_cache.get("prompt4")
    assert len(temp_prompt_cache) == 2
    assert set(temp_prompt_cache.keys()) == set(["prompt2", "prompt4"])

@patch("text_generation_server.prompt_cache.PromptCacheNode._get_prompt_size_mb")
@patch("text_generation_server.prompt_cache.PrefixCache._load_embedding_tensors")
def test_get_prompt_cache_tensor_too_large(mock_load_tensors, mock_get_prompt_size, temp_prompt_cache):
    """Ensure that an error is raised if a tensor greater than the cache size is found."""
    mock_load_tensors.return_value = torch.ones((3, 3))
    mock_get_prompt_size.return_value = prompt_cache.PROMPT_CACHE_SIZE_MB + 1
    with pytest.raises(ValueError):
        temp_prompt_cache.get("prompt1")

@patch("text_generation_server.prompt_cache.PrefixCache._load_embedding_tensors")
def test_clear_cache(mock_load_tensors, temp_prompt_cache):
    """Ensure that we can clear the prompt cache correctly."""
    mock_load_tensors.return_value = torch.ones((3, 3))
    assert len(temp_prompt_cache) == 0
    temp_prompt_cache.get("prompt1")
    assert len(temp_prompt_cache) == 1
    temp_prompt_cache.clear()
    assert len(temp_prompt_cache) == 0

@patch("text_generation_server.prompt_cache.PrefixCache._load_embedding_tensors")
def test_get_cache_keys(mock_load_tensors, temp_prompt_cache):
    """Ensure that we can grab the keys of the prompt cache correctly."""
    mock_load_tensors.return_value = torch.ones((3, 3))
    prompt_ids = set(["prompt1", "prompt2"])
    assert len(temp_prompt_cache) == 0
    for prompt_id in prompt_ids:
        temp_prompt_cache.get(prompt_id)
    assert set(temp_prompt_cache.keys()) == set(prompt_ids)

@patch("text_generation_server.prompt_cache.PrefixCache._load_embedding_tensors")
def test_get_cache_len(mock_load_tensors, temp_prompt_cache):
    """Ensure that we can get the length of the prompt cache correctly."""
    mock_load_tensors.return_value = torch.ones((3, 3))
    assert len(temp_prompt_cache) == 0
    temp_prompt_cache.get("prompt1")
    temp_prompt_cache.get("prompt2")
    assert len(temp_prompt_cache) == 2
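
Not part of this commit, but since the module docstring emphasizes thread-safe eviction, a concurrency check along the following lines could be layered on the same fixture and mocks. This is a hypothetical sketch that assumes PrefixCache.get and PROMPT_CACHE_SIZE_MB behave as the sequential tests above show; the test name and the worker/prompt counts are made up.

@patch("text_generation_server.prompt_cache.PromptCacheNode._get_prompt_size_mb")
@patch("text_generation_server.prompt_cache.PrefixCache._load_embedding_tensors")
def test_get_prompt_cache_concurrent_access(mock_load_tensors, mock_get_prompt_size, temp_prompt_cache):
    """Hit the cache from several threads at once and check the size invariants hold."""
    from concurrent.futures import ThreadPoolExecutor

    mock_load_tensors.return_value = torch.ones((3, 3))
    # Each prompt is just under half the cache budget, so at most two can stay resident
    mock_get_prompt_size.return_value = (prompt_cache.PROMPT_CACHE_SIZE_MB / 2) - 1
    prompt_ids = [f"prompt{i}" for i in range(16)]
    with ThreadPoolExecutor(max_workers=4) as pool:
        results = list(pool.map(temp_prompt_cache.get, prompt_ids))
    # Every call should have produced a tensor, and eviction should have kept
    # the cache within its memory budget once all threads are done.
    assert all(isinstance(t, torch.Tensor) for t in results)
    assert len(temp_prompt_cache) <= 2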
