26
26
from .hasher import MultiModalHasher
27
27
from .inputs import (MultiModalDataDict , MultiModalEncDecInputs ,
28
28
MultiModalFieldConfig , MultiModalInputs , MultiModalKwargs ,
29
- MultiModalKwargsItem , PlaceholderRange )
29
+ MultiModalKwargsItem , NestedTensors , PlaceholderRange )
30
30
from .parse import (DictEmbeddingItems , EmbeddingItems , MultiModalDataItems ,
31
31
MultiModalDataParser )
32
32
@@ -853,33 +853,62 @@ class ProcessingCache:
853
853
854
854
@staticmethod
def get_lru_cache(
    capacity_gb: float,
    value_type: type[_V],
    *,
    debug: bool = False,
) -> LRUCache[str, _V]:
    """Construct an ``LRUCache`` whose capacity is measured in bytes.

    The cache holds ``capacity_gb`` GiB worth of values; each entry's
    size is computed by recursively summing the byte size of its leaves
    (tensors via ``Tensor.nbytes``, everything else via
    ``sys.getsizeof``).

    Args:
        capacity_gb: Cache capacity in GiB (converted to bytes via
            ``GiB_bytes``).
        value_type: The value type stored in the cache. Only used to
            parameterize the return annotation; never read at runtime.
        debug: If True, log the computed size of every item that is
            measured for insertion.

    Returns:
        An ``LRUCache`` keyed by string, sized by item byte count.
    """

    def get_leaf_size(leaf: object) -> int:
        # MultiModalKwargs is not a subclass of dict
        if isinstance(leaf, MultiModalKwargs):
            return get_item_size(leaf.data)

        # MultiModalKwargsItem is not a subclass of dict
        if isinstance(leaf, MultiModalKwargsItem):
            leaf_data = {k: v.data for k, v in leaf.items()}
            return get_item_size(leaf_data)

        # sys.getsizeof doesn't work for tensors
        if isinstance(leaf, torch.Tensor):
            return leaf.nbytes

        # NOTE: sys.getsizeof is shallow; for non-tensor containers the
        # recursion below (json_map_leaves) is what walks the structure.
        return sys.getsizeof(leaf)

    def get_item_size(
        value: Union[MultiModalKwargs, MultiModalKwargsItem,
                     Mapping[str, NestedTensors]]
    ) -> int:
        # Sum the size of every leaf in the (possibly nested) value.
        size = json_reduce_leaves(
            lambda a, b: a + b,
            json_map_leaves(get_leaf_size, value),
        )

        if debug:
            logger.debug("Calculated size of %s to be %.2f GiB",
                         type(value), size / GiB_bytes)

        return size

    return LRUCache(GiB_bytes * capacity_gb, getsizeof=get_item_size)
895
def __init__(
    self,
    capacity_gb: float,
    *,
    debug_cache_hit_ratio_steps: Optional[int] = None,
) -> None:
    """Initialize the processing cache.

    Args:
        capacity_gb: Cache capacity in GiB.
        debug_cache_hit_ratio_steps: If set, log the cache hit ratio
            every this many lookups; ``None`` disables the logging.
    """
    super().__init__()

    self.debug_cache_hit_ratio_steps = debug_cache_hit_ratio_steps
    # Counters consumed by the hit-ratio debug logging.
    self.debug_cache_hits = 0
    self.debug_cache_total = 0

    # Per-item size logging in the cache is only enabled when the
    # hit-ratio debugging is also requested.
    self._cache = self.get_lru_cache(
        capacity_gb,
        MultiModalKwargsItem,
        debug=bool(debug_cache_hit_ratio_steps),
    )
913
def _maybe_log_cache_stats (self ) -> None :
885
914
steps = self .debug_cache_hit_ratio_steps
@@ -890,6 +919,9 @@ def _maybe_log_cache_stats(self) -> None:
890
919
if total > 0 and total % steps == 0 :
891
920
logger .debug ("ProcessingCache: hit_ratio = %.2f" ,
892
921
self .debug_cache_hits / total )
922
+ logger .debug ("ProcessingCache: size = %.2f / %.2f GiB" ,
923
+ self ._cache .currsize / GiB_bytes ,
924
+ self ._cache .maxsize / GiB_bytes )
893
925
894
926
def get (
895
927
self ,
0 commit comments