Skip to content

Commit 5864978

Browse files
committed
fix bug
1 parent 2662894 commit 5864978

File tree

2 files changed

+49
-4
lines changed

2 files changed

+49
-4
lines changed

_unittests/ut_helpers/test_mini_onnx_builder.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,48 @@ def test_mini_onnx_builder_transformers_sep(self):
151151
restored = create_input_tensors_from_onnx_model(model, sep="#")
152152
self.assertEqualAny(inputs, restored)
153153

154+
def test_specific_data(self):
155+
data = {
156+
("amain", 0, "I"): (
157+
(
158+
torch.rand((2, 16, 3, 448, 448), dtype=torch.float16),
159+
torch.rand((2, 16, 32, 32), dtype=torch.float16),
160+
torch.rand((2, 2)).to(torch.int64),
161+
),
162+
{},
163+
),
164+
}
165+
model = create_onnx_model_from_input_tensors(data)
166+
shapes = [
167+
tuple(d.dim_value for d in i.type.tensor_type.shape.dim)
168+
for i in model.graph.output
169+
]
170+
self.assertEqual(shapes, [(2, 16, 3, 448, 448), (2, 16, 32, 32), (2, 2), (0,)])
171+
names = [i.name for i in model.graph.output]
172+
self.assertEqual(
173+
[
174+
"dict._((amain,0,I))___tuple_0___tuple_0___tensor",
175+
"dict._((amain,0,I))___tuple_0___tuple_1___tensor",
176+
"dict._((amain,0,I))___tuple_0___tuple_2.___tensor",
177+
"dict._((amain,0,I))___tuple_1.___dict.___empty",
178+
],
179+
names,
180+
)
181+
shapes = [tuple(i.dims) for i in model.graph.initializer]
182+
self.assertEqual(shapes, [(2, 16, 3, 448, 448), (2, 16, 32, 32), (2, 2), (0,)])
183+
names = [i.name for i in model.graph.initializer]
184+
self.assertEqual(
185+
[
186+
"t_dict._((amain,0,I))___tuple_0___tuple_0___tensor",
187+
"t_dict._((amain,0,I))___tuple_0___tuple_1___tensor",
188+
"t_dict._((amain,0,I))___tuple_0___tuple_2.___tensor",
189+
"t_dict._((amain,0,I))___tuple_1.___dict.___empty",
190+
],
191+
names,
192+
)
193+
restored = create_input_tensors_from_onnx_model(model)
194+
self.assertEqualAny(data, restored)
195+
154196

155197
if __name__ == "__main__":
156198
unittest.main(verbosity=2)

onnx_diagnostic/helpers/mini_onnx_builder.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,9 @@ def append_output_initializer(
139139
return
140140

141141
init_name = f"t_{name}"
142+
assert (
143+
init_name not in self.initializers_dict
144+
), f"name={init_name!r} already in {sorted(self.initializers_dict)}"
142145
self.initializers_dict[init_name] = tensor
143146
shape = tuple(map(int, tensor.shape))
144147
self.outputs.append(
@@ -324,21 +327,21 @@ def _flatten_iterator(obj: Any, sep: str) -> Iterator:
324327
for i, o in enumerate(obj):
325328
if i == len(obj) - 1:
326329
for p, oo in _flatten_iterator(o, sep):
327-
yield f"tuple.{sep}{p}", oo
330+
yield f"tuple_{i}.{sep}{p}", oo
328331
else:
329332
for p, oo in _flatten_iterator(o, sep):
330-
yield f"tuple{sep}{p}", oo
333+
yield f"tuple_{i}{sep}{p}", oo
331334
elif isinstance(obj, list):
332335
if not obj:
333336
yield f"list.{sep}empty", None
334337
else:
335338
for i, o in enumerate(obj):
336339
if i == len(obj) - 1:
337340
for p, oo in _flatten_iterator(o, sep):
338-
yield f"list.{sep}{p}", oo
341+
yield f"list_{i}.{sep}{p}", oo
339342
else:
340343
for p, oo in _flatten_iterator(o, sep):
341-
yield f"list{sep}{p}", oo
344+
yield f"list_{i}{sep}{p}", oo
342345
elif isinstance(obj, dict):
343346
if not obj:
344347
yield f"dict.{sep}empty", None

0 commit comments

Comments
 (0)