@@ -45,34 +45,34 @@ def __init__(self,
     ) -> None:
         self.prefix_id = prefix_id
         self.prompt = prompt
-        self.prompt_size_mb = PromptCacheNode._get_prompt_size_mb(prompt)
+        self.prompt_virtual_tokens, self.prompt_size_mb = PromptCacheNode._get_prompt_stats(prompt)
         self.next = next
         self.prev = prev

     @staticmethod
-    def _get_prompt_size_mb(prompt: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) -> int:
-        """Get the memory size of a prompt. Note that we round up to the nearest
+    def _get_prompt_stats(prompt: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[int, int]:
+        """Get the number of virtual tokens and memory size of a prompt. Note that we round up to the nearest
         increment of 512.

         Args:
            prompt: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
                Prompt tuple/tensor we want to take the size of.

         Return:
-            Prompt size in Mb.
+            (prompt virtual token count, prompt size in MiB)
         """
         # In some cases, we may have None, e.g., an encoder / decoder
         # where we don't have prompts to inject for both components
         if prompt is None:
-            return 0
+            return 0, 0
         # We either have a Tensor or an iterable of tensors; if it's not
         # a tensor, take the size of all contained tensor objects.
         elif not isinstance(prompt, torch.Tensor):
-            return sum(map(PromptCacheNode._get_prompt_size_mb, prompt))
+            return tuple(sum(x) for x in zip(*map(PromptCacheNode._get_prompt_stats, prompt)))
         # Otherwise it's a tensor; round up to nearest 512 increment & convert to mb.
         # See: https://discuss.pytorch.org/t/how-to-know-the-memory-allocated-for-a-tensor-on-gpu/28537/15
         raw_size = prompt.element_size() * prompt.nelement()
-        return (math.ceil(raw_size / 512) * 512) / (1024 ** 2)
+        return prompt.shape[0], (math.ceil(raw_size / 512) * 512) / (1024 ** 2)


 class DoublyLinkedList:
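
Note on the tuple branch of the new _get_prompt_stats: zip(*map(...)) transposes the per-tensor (virtual_tokens, size_mb) pairs so each component can be summed independently. Below is a minimal, self-contained sketch of that behavior; the standalone function name, tensor shapes, and dtypes are illustrative assumptions, not values from this PR.

    import math
    from typing import Tuple

    import torch


    def get_prompt_stats(prompt) -> Tuple[int, float]:
        # Standalone copy of the logic in PromptCacheNode._get_prompt_stats, for illustration.
        if prompt is None:
            return 0, 0
        if not isinstance(prompt, torch.Tensor):
            # zip(*...) turns [(tokens_a, mb_a), (tokens_b, mb_b)] into
            # [(tokens_a, tokens_b), (mb_a, mb_b)]; summing each gives the totals.
            return tuple(sum(x) for x in zip(*map(get_prompt_stats, prompt)))
        # Round the raw byte count up to the nearest 512-byte increment, then convert to MiB.
        raw_size = prompt.element_size() * prompt.nelement()
        return prompt.shape[0], (math.ceil(raw_size / 512) * 512) / (1024 ** 2)


    encoder_prompt = torch.zeros(10, 768)  # 10 virtual tokens (assumed shape)
    decoder_prompt = torch.zeros(10, 768)
    print(get_prompt_stats((encoder_prompt, decoder_prompt)))  # (20, 0.05859375)
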
@@ -299,6 +299,8 @@ def _add_prefix_id_to_cache(
         new_cache_node = PromptCacheNode(prompt=prefix, prefix_id=prefix_id)
         del_tensors = {}

+        new_prompt_virtual_tokens = new_cache_node.prompt_virtual_tokens
+        new_prompt_size_mb = new_cache_node.prompt_size_mb
         with self.requires_lock:
             # If we already have it, return the node in the cache.
             # This will release the tensor we just loaded.
@@ -309,7 +311,7 @@ def _add_prefix_id_to_cache(
             if new_cache_node.prompt_size_mb > PROMPT_CACHE_SIZE_MB:
                 raise ValueError(f"Prefix ID object {prefix_id} exceeds the allowed cache size")

-            while self.cache_size_mb + new_cache_node.prompt_size_mb > PROMPT_CACHE_SIZE_MB:
+            while self.cache_size_mb + new_prompt_size_mb > PROMPT_CACHE_SIZE_MB:
                 # Hold a reference to the set of things to be deallocated until we're out of the
                 # critical section; then, we can handle the cache keys in a thread safe way
                 # without deallocating our tensors in it.
@@ -322,10 +324,15 @@ def _add_prefix_id_to_cache(
                 self.cache_size_mb -= del_node.prompt_size_mb
             self.cache_dll.add_node_as_head(new_cache_node)
             self.cache_map[prefix_id] = new_cache_node
-            self.cache_size_mb += new_cache_node.prompt_size_mb
-            logger.info(f"Added prefix {prefix_id} to the prompt cache")
+            self.cache_size_mb += new_prompt_size_mb
+            cache_size_mb = self.cache_size_mb
         if del_tensors:
             logger.info(f"Deleted prefixes {list(del_tensors.keys())} from the prompt cache")
+
+        logger.info(
+            f"Added prefix {prefix_id} to the prompt cache, has {new_prompt_virtual_tokens} virtual tokens"
+            f", size {new_prompt_size_mb:.3f} MiB, total cache size is now {cache_size_mb:.2f} MiB"
+        )
         return new_cache_node

     def _get_from_cache(self, prefix_id: str) -> PromptCacheNode:
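
The second half of the change snapshots self.cache_size_mb into a local while the lock is held and emits the "Added prefix" log line after leaving the critical section. A rough sketch of that pattern, under the assumption that requires_lock behaves like a standard mutex; the class and attribute names below are illustrative, not the PR's:

    import logging
    import threading

    logger = logging.getLogger(__name__)


    class TinyCache:
        """Illustrative stand-in, not the PR's cache class."""

        def __init__(self) -> None:
            self.lock = threading.Lock()
            self.size_mb = 0.0

        def add(self, prefix_id: str, size_mb: float) -> None:
            with self.lock:
                self.size_mb += size_mb
                total_mb = self.size_mb  # snapshot while the lock is held
            # Formatting and I/O happen outside the critical section, so a slow
            # log handler cannot block other threads waiting on the cache lock.
            logger.info(f"Added prefix {prefix_id}, total cache size is now {total_mb:.2f} MiB")
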