
Commit 8760e20 (parent 08b1cdf)

Add verbosity

2 files changed: 85 additions & 6 deletions

_unittests/ut_helpers/test_torch_test_helper.py

Lines changed: 7 additions & 1 deletion
@@ -15,6 +15,7 @@
     replace_string_by_dynamic,
     to_any,
     torch_deepcopy,
+    torch_tensor_size,
 )
 from onnx_diagnostic.helpers.cache_helper import (
     make_dynamic_cache,
@@ -204,7 +205,7 @@ def forward(self, x, y):
             else:
                 print("output", k, v)
         print(string_type(restored, with_shape=True))
-        l1, l2 = 182, 191
+        l1, l2 = 183, 192
         self.assertEqual(
             [
                 (f"-Model-{l2}", 0, "I"),
@@ -264,6 +265,7 @@ def test_torch_deepcopy_cache_dce(self):
         c1.key_cache[0] += 1000
         hash2 = string_type(at, with_shape=True, with_min_max=True)
         self.assertEqual(hash1, hash2)
+        self.assertGreater(torch_tensor_size(cc), 1)
 
     def test_torch_deepcopy_mamba_cache(self):
         cache = make_mamba_cache(
@@ -280,6 +282,7 @@ def test_torch_deepcopy_mamba_cache(self):
         cache.conv_states[0] += 1000
         hash2 = string_type(at, with_shape=True, with_min_max=True)
         self.assertEqual(hash1, hash2)
+        self.assertGreater(torch_tensor_size(cache), 1)
 
     def test_torch_deepcopy_base_model_outputs(self):
         bo = transformers.modeling_outputs.BaseModelOutput(
@@ -292,6 +295,7 @@ def test_torch_deepcopy_base_model_outputs(self):
         bo.last_hidden_state[0] += 1000
         hash2 = string_type(at, with_shape=True, with_min_max=True)
         self.assertEqual(hash1, hash2)
+        self.assertGreater(torch_tensor_size(bo), 1)
 
     def test_torch_deepcopy_sliding_windon_cache(self):
         cache = make_sliding_window_cache(
@@ -308,9 +312,11 @@ def test_torch_deepcopy_sliding_windon_cache(self):
         cache.key_cache[0] += 1000
         hash2 = string_type(at, with_shape=True, with_min_max=True)
         self.assertEqual(hash1, hash2)
+        self.assertGreater(torch_tensor_size(cache), 1)
 
     def test_torch_deepcopy_none(self):
         self.assertEmpty(torch_deepcopy(None))
+        self.assertEqual(torch_tensor_size(None), 0)
 
     def test_model_statistics(self):
         class Model(torch.nn.Module):

onnx_diagnostic/helpers/torch_test_helper.py

Lines changed: 78 additions & 5 deletions
@@ -1,11 +1,12 @@
 import contextlib
 import inspect
+import os
 from collections.abc import Iterable
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import numpy as np
 import onnx
 import torch
-from .helper import string_type
+from .helper import string_type, size_type
 from .cache_helper import (
     make_dynamic_cache,
     make_encoder_decoder_cache,
@@ -16,7 +17,15 @@
 
 
 def _forward_(
-    *args, _f=None, _fprint=string_type, _prefix="", _context=None, _storage=None, **kwargs
+    *args,
+    _f=None,
+    _fprint=string_type,
+    _prefix="",
+    _context=None,
+    _storage=None,
+    _storage_limit=2**27,
+    _verbose=0,
+    **kwargs,
 ):
     assert _f is not None, "_f cannot be None"
     assert _context is not None, "_context cannot be None"
@@ -42,7 +51,20 @@ def _forward_(
             print(f"{indent} -> {_fprint(res, **kws)}")
         print(f"{indent}-{_prefix}.")
     if _storage is not None:
-        _storage[(*key, "O")] = torch_deepcopy(res)
+        size = torch_tensor_size(res)
+        if size < _storage_limit:
+            if _verbose:
+                print(
+                    f"-- stores key={key}, size {size // 2**10}Kb -- "
+                    f"{string_type(res, with_shape=True)}"
+                )
+            _storage[(*key, "O")] = torch_deepcopy(res)
+        else:
+            if _verbose:
+                print(
+                    f"-- skips key={key}, size {size // 2**10}Kb -- "
+                    f"{string_type(res, with_shape=True)}"
+                )
     _context["iteration"] += 1
     return res
 
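Note: the default _storage_limit of 2**27 bytes is 128 MiB (2**27 = 128 * 2**20). Outputs at least that large are no longer deep-copied into _storage; when _verbose is set, both stored and skipped outputs are reported together with their size.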
@@ -92,6 +114,8 @@ def steal_forward(
     fprint: Callable = string_type,
     dump_file: Optional[str] = None,
     submodules: bool = False,
+    verbose: int = 0,
+    storage_limit: int = 2**27,
     **kwargs,
 ):
     """
@@ -110,6 +134,8 @@
         <onnx_diagnostic.helpers.mini_onnx_builder.create_input_tensors_from_onnx_model>`
     :param submodules: if True and model is a module, the list is extended with all the submodules
         the module contains
+    :param verbose: verbosity level
+    :param storage_limit: do not store objects bigger than this size (in bytes)
 
     The following example shows how to steal and dump all the inputs / outputs
     of a module and its submodules, then restore them.
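
A hypothetical invocation of the new arguments (the first positional argument is assumed to be the module to instrument, and the file name is illustrative):

import torch
from onnx_diagnostic.helpers.torch_test_helper import steal_forward

model = torch.nn.Linear(4, 4)
x = torch.randn(2, 4)
with steal_forward(
    model,
    dump_file="stolen_io.onnx",  # inputs/outputs are serialized into this file
    verbose=1,                   # report what is stored or skipped
    storage_limit=2**20,         # skip any output larger than 1 MiB
):
    model(x)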
@@ -181,8 +207,16 @@ def forward(self, x, y):
         keep_model_forward[id(m)] = (m, m.forward)
         c = context.copy()
         c["class_name"] = m.__class__.__name__
-        m.forward = lambda *args, _f=m.forward, _fp=fprint, _c=c, _p=name, _s=storage, **kws: _forward_(  # noqa: E501
-            *args, _f=_f, _fprint=_fp, _context=_c, _prefix=_p, _storage=_s, **kws
+        m.forward = lambda *args, _f=m.forward, _fp=fprint, _c=c, _p=name, _s=storage, _v=verbose, _sl=storage_limit, **kws: _forward_(  # noqa: E501
+            *args,
+            _f=_f,
+            _fprint=_fp,
+            _context=_c,
+            _prefix=_p,
+            _storage=_s,
+            _verbose=_v,
+            _storage_limit=_sl,
+            **kws,
         )
     try:
         yield
@@ -196,13 +230,21 @@ def forward(self, x, y):
         storage.update(_additional_stolen_objects)
         # We clear the cache.
         _additional_stolen_objects.clear()
+        if verbose:
+            size = torch_tensor_size(storage)
+            print(f"-- gather stored {len(storage)} objects, size={size // 2 ** 20} Mb")
         proto = create_onnx_model_from_input_tensors(storage)
+        if verbose:
+            print("-- dumps stored objects")
         onnx.save(
             proto,
             dump_file,
             save_as_external_data=True,
             all_tensors_to_one_file=True,
+            location=f"{os.path.split(dump_file)[-1]}.weight",
         )
+        if verbose:
+            print("-- done dump stored objects")
 
 
 @contextlib.contextmanager
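
Note: with save_as_external_data=True, onnx.save writes the tensor payloads to a side file. The new location argument pins that file's name to the dump file's base name plus a ".weight" suffix (os.path.split(dump_file)[-1] keeps only the base name), instead of letting onnx pick an auto-generated name.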
@@ -552,6 +594,37 @@ def torch_deepcopy(value: Any) -> Any:
     raise NotImplementedError(f"torch_deepcopy not implemented for type {type(value)}")
 
 
+def torch_tensor_size(value: Any) -> int:
+    """Returns the number of bytes stored in tensors."""
+    if value is None:
+        return 0
+    if isinstance(value, (int, float, str)):
+        return 0
+    if isinstance(value, (tuple, list, set)):
+        return sum(torch_tensor_size(v) for v in value)
+    if isinstance(value, dict):
+        return sum(torch_tensor_size(v) for v in value.values())
+    if isinstance(value, np.ndarray):
+        return value.size * value.dtype.itemsize
+    if hasattr(value, "clone"):
+        return value.numel() * size_type(value.dtype)
+    if value.__class__.__name__ in {"DynamicCache", "SlidingWindowCache"}:
+        return torch_tensor_size(value.key_cache) + torch_tensor_size(value.value_cache)
+    if value.__class__.__name__ == "EncoderDecoderCache":
+        return torch_tensor_size(value.self_attention_cache) + torch_tensor_size(
+            value.cross_attention_cache
+        )
+    if value.__class__.__name__ == "MambaCache":
+        return torch_tensor_size(value.conv_states) + torch_tensor_size(value.ssm_states)
+    if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
+        args, spec = torch.utils._pytree.tree_flatten(value)
+        return sum(torch_tensor_size(a) for a in args)
+
+    # This should rather rely on serialization/deserialization, assuming a model
+    # cannot be exported without them.
+    raise NotImplementedError(f"torch_tensor_size not implemented for type {type(value)}")
+
+
 def model_statistics(model: torch.nn.Module):
     """Returns statistics on a model in a dictionary."""
     n_subs = len(list(model.modules()))
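
A quick sanity check of the size computation, assuming size_type returns the dtype's width in bytes (so a tensor accounts for numel() times the element size):

import torch
from onnx_diagnostic.helpers.torch_test_helper import torch_tensor_size

t = torch.zeros((2, 3), dtype=torch.float32)
assert torch_tensor_size(t) == 2 * 3 * 4             # 24 bytes of float32
assert torch_tensor_size([t, (t, t)]) == 3 * 24      # containers are summed recursively
assert torch_tensor_size({"a": t, "b": None}) == 24  # None, ints, floats, str count as 0
assert torch_tensor_size(None) == 0                  # mirrors test_torch_deepcopy_none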
