Commit d19245a

Include prompt prefix_id in per-request logs, log metadata when loading
Log the number of virtual tokens and size in memory when loading prompt prefixes.
1 parent 240a2da commit d19245a

2 files changed: +19 -10 lines

router/src/grpc_server.rs

Lines changed: 2 additions & 0 deletions

@@ -90,6 +90,7 @@ impl GenerationService for GenerationServicer {
         skip_all,
         fields(
             input=?request.get_ref().requests.iter().map(|r| truncate(&r.text, 32)).collect::<Vec<Cow<'_,str>>>(),
+            prefix_id=?request.get_ref().prefix_id,
             correlation_id=?request.metadata().get("x-correlation-id").map(|mv| mv.to_str().unwrap_or("<non-ascii>")).unwrap_or("<none>"),
             input_bytes=?request.get_ref().requests.iter().map(|r| r.text.len()).collect::<Vec<usize>>(),
             params=?request.get_ref().params,
@@ -171,6 +172,7 @@ impl GenerationService for GenerationServicer {
         skip_all,
         fields(
             input=?truncate(&request.get_ref().request.as_ref().map(|r| &*r.text).unwrap_or(""), 32),
+            prefix_id=?request.get_ref().prefix_id,
             correlation_id=?request.metadata().get("x-correlation-id").map(|mv| mv.to_str().unwrap_or("<non-ascii>")).unwrap_or("<none>"),
             input_bytes=?request.get_ref().request.as_ref().map(|r| r.text.len()).unwrap_or(0),
             params=?request.get_ref().params,

server/text_generation_server/prompt_cache.py

Lines changed: 17 additions & 10 deletions

@@ -45,34 +45,34 @@ def __init__(self,
     ) -> None:
         self.prefix_id = prefix_id
         self.prompt = prompt
-        self.prompt_size_mb = PromptCacheNode._get_prompt_size_mb(prompt)
+        self.prompt_virtual_tokens, self.prompt_size_mb = PromptCacheNode._get_prompt_stats(prompt)
         self.next = next
         self.prev = prev
 
     @staticmethod
-    def _get_prompt_size_mb(prompt: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) -> int:
-        """Get the memory size of a prompt. Note that we round up to the nearest
+    def _get_prompt_stats(prompt: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[int, int]:
+        """Get the number of virtual tokens and memory size of a prompt. Note that we round up to the nearest
         increment of 512.
 
         Args:
             prompt: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
                 Prompt tuple/tensor we want to take the size of.
 
         Return:
-            Prompt size in Mb.
+            (prompt virtual token count, prompt size in MiB)
         """
         # In some cases, we may have None, e.g., an encoder / decoder
         # where we don't have prompts to inject for both components
         if prompt is None:
-            return 0
+            return 0, 0
         # We either have a Tensor or an iterable of tensors; if it's not
        # a tensor, take the size of all contained tensor objects.
         elif not isinstance(prompt, torch.Tensor):
-            return sum(map(PromptCacheNode._get_prompt_size_mb, prompt))
+            return tuple(sum(x) for x in zip(*map(PromptCacheNode._get_prompt_stats, prompt)))
         # Otherwise it's a tensor; round up to nearest 512 increment & convert to mb.
         # See: https://discuss.pytorch.org/t/how-to-know-the-memory-allocated-for-a-tensor-on-gpu/28537/15
         raw_size = prompt.element_size() * prompt.nelement()
-        return (math.ceil(raw_size / 512) * 512) / (1024 ** 2)
+        return prompt.shape[0], (math.ceil(raw_size / 512) * 512) / (1024 ** 2)
 
 
 class DoublyLinkedList:
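
For illustration, the arithmetic in the new _get_prompt_stats can be checked in isolation. The sketch below is a standalone reimplementation of the same calculation, not the repository code itself; the tensor shape and dtype are hypothetical.

import math
import torch

def prompt_stats(prompt: torch.Tensor) -> tuple[int, float]:
    # Mirrors the diff's logic: virtual tokens = first dimension,
    # memory = raw bytes rounded up to a 512-byte block, reported in MiB.
    raw_size = prompt.element_size() * prompt.nelement()
    size_mb = (math.ceil(raw_size / 512) * 512) / (1024 ** 2)
    return prompt.shape[0], size_mb

# Hypothetical prompt: 20 virtual tokens, hidden size 4096, float16.
prompt = torch.zeros(20, 4096, dtype=torch.float16)
tokens, size_mb = prompt_stats(prompt)
print(tokens, size_mb)  # 20 0.15625  (20 * 4096 * 2 bytes = 163840 B, already a multiple of 512)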
@@ -299,6 +299,8 @@ def _add_prefix_id_to_cache(
         new_cache_node = PromptCacheNode(prompt=prefix, prefix_id=prefix_id)
         del_tensors = {}
 
+        new_prompt_virtual_tokens = new_cache_node.prompt_virtual_tokens
+        new_prompt_size_mb = new_cache_node.prompt_size_mb
         with self.requires_lock:
             # If we already have it, return the node in the cache.
             # This will release the tensor we just loaded.
@@ -309,7 +311,7 @@ def _add_prefix_id_to_cache(
             if new_cache_node.prompt_size_mb > PROMPT_CACHE_SIZE_MB:
                 raise ValueError(f"Prefix ID object {prefix_id} exceeds the allowed cache size")
 
-            while self.cache_size_mb + new_cache_node.prompt_size_mb > PROMPT_CACHE_SIZE_MB:
+            while self.cache_size_mb + new_prompt_size_mb > PROMPT_CACHE_SIZE_MB:
                 # Hold a reference to the set of things to be deallocated until we're out of the
                 # critical section; then, we can handle the cache keys in a thread safe way
                 # without deallocating our tensors in it.
@@ -322,10 +324,15 @@ def _add_prefix_id_to_cache(
                 self.cache_size_mb -= del_node.prompt_size_mb
             self.cache_dll.add_node_as_head(new_cache_node)
             self.cache_map[prefix_id] = new_cache_node
-            self.cache_size_mb += new_cache_node.prompt_size_mb
-            logger.info(f"Added prefix {prefix_id} to the prompt cache")
+            self.cache_size_mb += new_prompt_size_mb
+            cache_size_mb = self.cache_size_mb
         if del_tensors:
             logger.info(f"Deleted prefixes {list(del_tensors.keys())} from the prompt cache")
+
+        logger.info(
+            f"Added prefix {prefix_id} to the prompt cache, has {new_prompt_virtual_tokens} virtual tokens"
+            f", size {new_prompt_size_mb:.3f}MiB, total cache size is now {cache_size_mb:.2f}MiB"
+        )
         return new_cache_node
 
     def _get_from_cache(self, prefix_id: str) -> PromptCacheNode:
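
The later hunks snapshot the new node's statistics before the critical section and move the "Added prefix ..." log after it, so the summary line can report the per-prefix stats and the updated total without holding the lock. A minimal sketch of that evict-then-log pattern, using a hypothetical OrderedDict-based cache and print statements in place of the repository's DoublyLinkedList, cache_map, requires_lock, and logger:

import threading
from collections import OrderedDict

PROMPT_CACHE_SIZE_MB = 512  # hypothetical limit standing in for PROMPT_CACHE_SIZE_MB

class TinyPromptCache:
    """Illustrative LRU cache keyed by prefix_id; values are (virtual_tokens, size_mb)."""

    def __init__(self) -> None:
        self.lock = threading.Lock()
        self.entries: OrderedDict[str, tuple[int, float]] = OrderedDict()
        self.cache_size_mb = 0.0

    def add(self, prefix_id: str, virtual_tokens: int, size_mb: float) -> None:
        if size_mb > PROMPT_CACHE_SIZE_MB:
            raise ValueError(f"Prefix ID object {prefix_id} exceeds the allowed cache size")
        deleted = []
        with self.lock:
            # Evict least-recently-added entries until the new prompt fits.
            while self.cache_size_mb + size_mb > PROMPT_CACHE_SIZE_MB:
                old_id, (_, old_size) = self.entries.popitem(last=False)
                self.cache_size_mb -= old_size
                deleted.append(old_id)
            self.entries[prefix_id] = (virtual_tokens, size_mb)
            self.cache_size_mb += size_mb
            cache_size_mb = self.cache_size_mb
        # Log outside the critical section, as the commit does.
        if deleted:
            print(f"Deleted prefixes {deleted} from the prompt cache")
        print(
            f"Added prefix {prefix_id} to the prompt cache, has {virtual_tokens} virtual tokens"
            f", size {size_mb:.3f}MiB, total cache size is now {cache_size_mb:.2f}MiB"
        )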
