Commit 9c1e83d: first step for gemma3

1 parent 85ed339

File tree: 7 files changed, +77 / -4 lines

_unittests/ut_helpers/test_torch_helper.py

Lines changed: 33 additions & 0 deletions
@@ -181,6 +181,39 @@ def forward(self, x, y):
                 set(restored),
             )

+    @hide_stdout()
+    def test_steal_forward_dump_file_steal_append_drop(self):
+        class SubModel(torch.nn.Module):
+            def forward(self, x):
+                return x * x
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.s1 = SubModel()
+                self.s2 = SubModel()
+
+            def forward(self, x, y):
+                sx = self.s1(x)
+                steal_append("sx", sx)
+                return sx + self.s2(y)
+
+        inputs = dict(x=torch.rand(3, 4), y=torch.rand(3, 4))
+        model = Model()
+        dump_file = self.get_dump_file("test_steal_forward_dump_file_drop.onnx")
+        with steal_forward(model, dump_file=dump_file, dump_drop={"x"}):
+            model(**inputs)
+            model(**inputs)
+        self.assertExists(dump_file)
+        restored = create_input_tensors_from_onnx_model(dump_file)
+        self.assertEqual(
+            {("", 1, "I"), ("", 1, "O"), "sx", ("", 0, "O"), "sx_1", ("", 0, "I")},
+            set(restored),
+        )
+        first = restored[("", 0, "I")]
+        _a, kws = first
+        self.assertNotIn("x", kws)
+
     @hide_stdout()
     def test_steal_forward_submodules(self):
         class SubModel(torch.nn.Module):
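
The assertion in the new test also documents the key layout of a restored dump: values recorded with steal_append keep the given name ("sx", then "sx_1" for the second call), while stolen inputs and outputs are keyed by a (module_name, call_index, "I"/"O") tuple. A minimal sketch of inspecting such a dump outside the test harness, assuming the dump file written above exists on disk:

    from onnx_diagnostic.helpers.mini_onnx_builder import (
        create_input_tensors_from_onnx_model,
    )

    # path is illustrative; it is the file written by the test above
    restored = create_input_tensors_from_onnx_model("test_steal_forward_dump_file_drop.onnx")
    for key in restored:
        # tuple keys (module_name, call_index, "I"/"O") hold stolen inputs/outputs,
        # string keys come from steal_append
        print(key, type(restored[key]))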

_unittests/ut_tasks/test_data.py

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+import unittest
+from onnx_diagnostic.ext_test_case import ExtTestCase
+from onnx_diagnostic.tasks.data import get_data
+
+
+class TestTasks(ExtTestCase):
+    def test_get_data(self):
+        name = "dummies_imagetext2text_generation_gemma3.onnx"
+        data = get_data(name)
+        print(data)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)

_unittests/ut_tasks/try_tasks.py

Lines changed: 1 addition & 0 deletions
@@ -878,6 +878,7 @@ def test_imagetext2text_generation_gemma3_4b_it(self):
             model,
             dump_file=self.get_dump_file("test_imagetext2text_generation_gemma3_4b_it.onnx"),
             dump_drop={"attention_mask", "past_key_values", "pixel_values"},
+            save_as_external_data=False,
         ):
             generated_ids = model.generate(
                 **inputs, max_new_tokens=282, do_sample=False, cache_implementation="static"

onnx_diagnostic/helpers/torch_helper.py

Lines changed: 12 additions & 3 deletions
@@ -287,6 +287,7 @@ def steal_forward(
     submodules: bool = False,
     verbose: int = 0,
     storage_limit: int = 2**27,
+    save_as_external_data: bool = True,
     **kwargs,
 ):
     """
@@ -305,6 +306,8 @@ def steal_forward(
        they can be restored with :func:`create_input_tensors_from_onnx_model
        <onnx_diagnostic.helpers.mini_onnx_builder.create_input_tensors_from_onnx_model>`
     :param dump_drop: to drop some inputs too big (only if dump_file is specified)
+    :param save_as_external_data: True by default, but maybe better to have everything
+        in a single file if possible
     :param submodules: if True and model is a module, the list extended with all the submodules
         the module contains
     :param verbose: verbosity
@@ -414,8 +417,14 @@ def forward(self, x, y):
             size = torch_tensor_size(storage)
             print(f"-- gather stored {len(storage)} objects, size={size // 2 ** 20} Mb")
         if dump_drop:
-            print(string_type(dump_drop))
-            stop
+            for k, v in storage.items():
+                if k[-1] == "I":
+                    _args, kwargs = v
+                    ii = set(kwargs) & dump_drop
+                    if ii:
+                        for i in ii:
+                            print("---", i)
+                            del kwargs[i]
         proto = create_onnx_model_from_input_tensors(storage)
         if verbose:
             print("-- dumps stored objects")
@@ -425,7 +434,7 @@ def forward(self, x, y):
         onnx.save(
             proto,
             dump_file,
-            save_as_external_data=True,
+            save_as_external_data=save_as_external_data,
             all_tensors_to_one_file=True,
             location=location,
         )
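
A minimal usage sketch of the two knobs touched here, assuming an arbitrary module and a placeholder file name: dump_drop removes the named keyword inputs from the stored calls before the dump is written (the filtering above only inspects keyword arguments), and save_as_external_data=False keeps the whole dump in a single ONNX file.

    import torch
    from onnx_diagnostic.helpers.torch_helper import steal_forward

    class Tiny(torch.nn.Module):
        def forward(self, x, y):
            return x + y

    model = Tiny()
    # drop the keyword input "y" from the dump and write one self-contained file
    with steal_forward(
        model,
        dump_file="dump_single_file.onnx",
        dump_drop={"y"},
        save_as_external_data=False,
    ):
        model(x=torch.rand(3, 4), y=torch.rand(3, 4))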
onnx_diagnostic/tasks/data/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+import os
+
+
+def get_data(name: str):
+    """Returns data stored in this folder."""
+    filename = os.path.join(os.path.dirname(__file__), name)
+    assert os.path.exists(
+        filename
+    ), f"Unable to find a file with {name!r}, looked for {filename!r}"
+
+    from ...helpers.mini_onnx_builder import create_input_tensors_from_onnx_model
+
+    return create_input_tensors_from_onnx_model(filename)
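
get_data resolves name next to the module and restores the stored tensors with create_input_tensors_from_onnx_model, which is why the packaging changes below declare tasks/data/*.onnx as package data. A minimal sketch, assuming the dummy .onnx file referenced by the new unit test is present in the installed package:

    from onnx_diagnostic.tasks.data import get_data

    # file name taken from the new unit test above
    data = get_data("dummies_imagetext2text_generation_gemma3.onnx")
    print(data)

    # an unknown name fails fast through the assertion in get_data
    try:
        get_data("does_not_exist.onnx")
    except AssertionError as exc:
        print(exc)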

pyproject.toml

Lines changed: 3 additions & 0 deletions
@@ -1,3 +1,6 @@
+[tool.setuptools.package-data]
+onnx_diagnostic = ["tasks/data/*.onnx"]
+
 [tool.black]
 line-length = 95
 extend-exclude = '''.*clones.*'''

setup.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
 here = os.path.dirname(__file__)
 if here == "":
     here = "."
-package_data = {"onnx_diagnostic.validation": ["*.css", "*.js"]}
+package_data = {"onnx_diagnostic.tasks.data": ["*.onnx"]}

 try:
     with open(os.path.join(here, "requirements.txt"), "r") as f:
