Skip to content

Commit d003bb6

Browse files
committed
add mb
1 parent 75e3dfe commit d003bb6

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

_unittests/ut_helpers/test_torch_test_helper.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,14 @@ def forward(self, x, y=None):
295295
model(x, y)
296296
stat = model_statistics(model)
297297
self.assertEqual(
298-
{"type": "Model", "n_modules": 1, "param_size": 4, "buffer_size": 4, "float32": 8},
298+
{
299+
"type": "Model",
300+
"n_modules": 1,
301+
"param_size": 4,
302+
"buffer_size": 4,
303+
"float32": 8,
304+
"size_mb": 0,
305+
},
299306
stat,
300307
)
301308

onnx_diagnostic/helpers/torch_test_helper.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,8 +514,9 @@ def model_statistics(model: torch.nn.Module):
514514
res = dict(
515515
type=model.__class__.__name__,
516516
n_modules=n_subs,
517-
param_size=size,
517+
param_size=param_size,
518518
buffer_size=buffer_size,
519+
size_mb=(param_size + buffer_size) // 2**20,
519520
)
520521
res.update(sizes)
521522
return res

0 commit comments

Comments
 (0)