 import contextlib
 import inspect
+import os
 from collections.abc import Iterable
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import numpy as np
 import onnx
 import torch
-from .helper import string_type
+from .helper import string_type, size_type
 from .cache_helper import (
     make_dynamic_cache,
     make_encoder_decoder_cache,
 
 
 def _forward_(
-    *args, _f=None, _fprint=string_type, _prefix="", _context=None, _storage=None, **kwargs
+    *args,
+    _f=None,
+    _fprint=string_type,
+    _prefix="",
+    _context=None,
+    _storage=None,
+    _storage_limit=2**27,
+    _verbose=0,
+    **kwargs,
 ):
     assert _f is not None, "_f cannot be None"
     assert _context is not None, "_context cannot be None"
@@ -42,7 +51,20 @@ def _forward_(
         print(f"{indent} -> {_fprint(res, **kws)}")
     print(f"{indent}-{_prefix}.")
     if _storage is not None:
-        _storage[(*key, "O")] = torch_deepcopy(res)
+        size = torch_tensor_size(res)
+        if size < _storage_limit:
+            if _verbose:
+                print(
+                    f"-- stores key={key}, size {size // 2**10} Kb -- "
+                    f"{string_type(res, with_shape=True)}"
+                )
+            _storage[(*key, "O")] = torch_deepcopy(res)
+        else:
+            if _verbose:
+                print(
+                    f"-- skips key={key}, size {size // 2**10} Kb -- "
+                    f"{string_type(res, with_shape=True)}"
+                )
     _context["iteration"] += 1
     return res
 
@@ -92,6 +114,8 @@ def steal_forward(
     fprint: Callable = string_type,
     dump_file: Optional[str] = None,
     submodules: bool = False,
+    verbose: int = 0,
+    storage_limit: int = 2**27,
     **kwargs,
 ):
     """
@@ -110,6 +134,8 @@ def steal_forward(
         <onnx_diagnostic.helpers.mini_onnx_builder.create_input_tensors_from_onnx_model>`
     :param submodules: if True and model is a module, the list is extended with all
         the submodules the module contains
+    :param verbose: verbosity level
+    :param storage_limit: do not store objects bigger than this size (in bytes)
 
     The following example shows how to steal and dump all the inputs / outputs
     for a module and its submodules, then restore them.
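For reference, a minimal usage sketch of the two new parameters; the model, inputs, and import path below are illustrative assumptions, not part of this commit. Any stolen output whose tensors exceed `storage_limit` bytes is skipped instead of being copied into storage.

```python
# Hypothetical usage of the new verbose/storage_limit parameters.
# The import path and model are assumptions for illustration.
import torch
from onnx_diagnostic.helpers.torch_helper import steal_forward


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return x * 2


model = TinyModel()
# Outputs bigger than 1 Mb (2**20 bytes) are skipped; verbose=1 logs stores/skips.
with steal_forward(model, dump_file="dump.onnx", verbose=1, storage_limit=2**20):
    model(torch.randn(4, 8))
```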
@@ -181,8 +207,16 @@ def forward(self, x, y):
         keep_model_forward[id(m)] = (m, m.forward)
         c = context.copy()
         c["class_name"] = m.__class__.__name__
-        m.forward = lambda *args, _f=m.forward, _fp=fprint, _c=c, _p=name, _s=storage, **kws: _forward_(  # noqa: E501
-            *args, _f=_f, _fprint=_fp, _context=_c, _prefix=_p, _storage=_s, **kws
+        m.forward = lambda *args, _f=m.forward, _fp=fprint, _c=c, _p=name, _s=storage, _v=verbose, _sl=storage_limit, **kws: _forward_(  # noqa: E501
+            *args,
+            _f=_f,
+            _fprint=_fp,
+            _context=_c,
+            _prefix=_p,
+            _storage=_s,
+            _verbose=_v,
+            _storage_limit=_sl,
+            **kws,
         )
     try:
         yield
@@ -196,13 +230,21 @@ def forward(self, x, y):
         storage.update(_additional_stolen_objects)
         # We clear the cache.
         _additional_stolen_objects.clear()
+        if verbose:
+            size = torch_tensor_size(storage)
+            print(f"-- gathered {len(storage)} stored objects, size={size // 2**20} Mb")
         proto = create_onnx_model_from_input_tensors(storage)
+        if verbose:
+            print("-- dumping stored objects")
         onnx.save(
             proto,
             dump_file,
             save_as_external_data=True,
             all_tensors_to_one_file=True,
+            location=f"{os.path.split(dump_file)[-1]}.weight",
         )
+        if verbose:
+            print("-- done dumping stored objects")
 
 
 @contextlib.contextmanager
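The new `location` argument writes all external tensor data next to the dump in a single `<dump_file>.weight` file. Restoring the stolen tensors then goes through the helper already referenced in the docstring; a sketch follows, with the file name assumed and the return type assumed to be a mapping keyed like `storage`.

```python
# Sketch: reload the dumped inputs/outputs; "dump.onnx" is an assumed file name.
# The external data is expected alongside it in "dump.onnx.weight".
from onnx_diagnostic.helpers.mini_onnx_builder import (
    create_input_tensors_from_onnx_model,
)

restored = create_input_tensors_from_onnx_model("dump.onnx")
for key, value in restored.items():
    print(key, type(value))
```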
@@ -552,6 +594,37 @@ def torch_deepcopy(value: Any) -> Any:
     raise NotImplementedError(f"torch_deepcopy not implemented for type {type(value)}")
 
 
+def torch_tensor_size(value: Any) -> int:
+    """Returns the number of bytes stored in tensors."""
+    if value is None:
+        return 0
+    if isinstance(value, (int, float, str)):
+        return 0
+    if isinstance(value, (tuple, list, set)):
+        return sum(torch_tensor_size(v) for v in value)
+    if isinstance(value, dict):
+        return sum(torch_tensor_size(v) for v in value.values())
+    if isinstance(value, np.ndarray):
+        return value.nbytes
+    if hasattr(value, "clone"):
+        return value.numel() * size_type(value.dtype)
+    if value.__class__.__name__ in {"DynamicCache", "SlidingWindowCache"}:
+        return torch_tensor_size(value.key_cache) + torch_tensor_size(value.value_cache)
+    if value.__class__.__name__ == "EncoderDecoderCache":
+        return torch_tensor_size(value.self_attention_cache) + torch_tensor_size(
+            value.cross_attention_cache
+        )
+    if value.__class__.__name__ == "MambaCache":
+        return torch_tensor_size(value.conv_states) + torch_tensor_size(value.ssm_states)
+    if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
+        flat, _spec = torch.utils._pytree.tree_flatten(value)
+        return sum(torch_tensor_size(a) for a in flat)
+
+    # A fallback based on serialization/deserialization could cover the remaining
+    # types, assuming a model cannot be exported without them.
+    raise NotImplementedError(f"torch_tensor_size not implemented for type {type(value)}")
+
+
 def model_statistics(model: torch.nn.Module):
     """Returns statistics on a model in a dictionary."""
     n_subs = len(list(model.modules()))
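A quick sanity check of `torch_tensor_size` on nested containers (a sketch, not part of the commit): byte counts are summed recursively over tensors, while `None`, numbers, and strings contribute zero.

```python
import torch

# 4 * 8 float32 elements -> 4 * 8 * 4 = 128 bytes per tensor.
t = torch.zeros((4, 8), dtype=torch.float32)
nested = {"a": t, "b": [t, t], "c": None, "d": "ignored"}
# Three tensors are reachable, so the expected total is 3 * 128 = 384 bytes.
assert torch_tensor_size(nested) == 384
```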