Skip to content

Commit 75e3dfe

Browse files
committed
add model_statistics
1 parent 374538b commit 75e3dfe

File tree

3 files changed

+53
-1
lines changed

3 files changed

+53
-1
lines changed

_unittests/ut_helpers/test_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def test_print_pretty_onnx(self):
127127
)
128128
self.print_onnx(proto)
129129
self.print_model(proto)
130-
self.dump_onnx("test_print_pretty_onnx", proto)
130+
self.dump_onnx("test_print_pretty.onnx", proto)
131131
self.check_ort(proto)
132132
self.assertNotEmpty(proto)
133133
self.assertEmpty(None)

_unittests/ut_helpers/test_torch_test_helper.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
dummy_llm,
1010
to_numpy,
1111
is_torchdynamo_exporting,
12+
model_statistics,
1213
steal_forward,
1314
replace_string_by_dynamic,
1415
to_any,
@@ -279,6 +280,25 @@ def test_torch_deepcopy_sliding_windon_cache(self):
279280
def test_torch_deepcopy_none(self):
280281
self.assertEmpty(torch_deepcopy(None))
281282

283+
def test_model_statistics(self):
284+
class Model(torch.nn.Module):
285+
def __init__(self):
286+
super().__init__()
287+
self.p1 = torch.nn.Parameter(torch.tensor([1], dtype=torch.float32))
288+
self.b1 = torch.nn.Buffer(torch.tensor([1], dtype=torch.float32))
289+
290+
def forward(self, x, y=None):
291+
return x + y + self.p1 + self.b1
292+
293+
model = Model()
294+
x, y = torch.rand((3, 4)), torch.rand((3, 4))
295+
model(x, y)
296+
stat = model_statistics(model)
297+
self.assertEqual(
298+
{"type": "Model", "n_modules": 1, "param_size": 4, "buffer_size": 4, "float32": 8},
299+
stat,
300+
)
301+
282302

283303
if __name__ == "__main__":
284304
unittest.main(verbosity=2)

onnx_diagnostic/helpers/torch_test_helper.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,3 +487,35 @@ def torch_deepcopy(value: Any) -> Any:
487487
# We should have a code using serialization, deserialization assuming a model
488488
# cannot be exported without them.
489489
raise NotImplementedError(f"torch_deepcopy not implemented for type {type(value)}")
490+
491+
492+
def model_statistics(model: torch.nn.Module):
493+
"""Returns statistics on a model in a dictionary."""
494+
n_subs = len(list(model.modules()))
495+
sizes = {}
496+
param_size = 0
497+
for param in model.parameters():
498+
size = param.nelement() * param.element_size()
499+
param_size += size
500+
name = str(param.dtype).replace("torch.", "")
501+
if name not in sizes:
502+
sizes[name] = 0
503+
sizes[name] += size
504+
505+
buffer_size = 0
506+
for buffer in model.buffers():
507+
size = buffer.nelement() * buffer.element_size()
508+
buffer_size += size
509+
name = str(buffer.dtype).replace("torch.", "")
510+
if name not in sizes:
511+
sizes[name] = 0
512+
sizes[name] += size
513+
514+
res = dict(
515+
type=model.__class__.__name__,
516+
n_modules=n_subs,
517+
param_size=size,
518+
buffer_size=buffer_size,
519+
)
520+
res.update(sizes)
521+
return res

0 commit comments

Comments
 (0)