Skip to content

Commit b37ec29

Browse files
committed
refactor
1 parent 58b8f64 commit b37ec29

File tree

5 files changed

+363
-314
lines changed

5 files changed

+363
-314
lines changed

_doc/api/torch_export_patches/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ onnx_diagnostic.torch_export_patches
88
eval/index
99
onnx_export_errors
1010
onnx_export_serialization
11+
onnx_export_serialization_impl
1112
patches/index
1213
patch_expressions
1314
patch_inputs
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.torch_export_patches.onnx_export_serialization_impl
3+
===================================================================
4+
5+
.. automodule:: onnx_diagnostic.torch_export_patches.onnx_export_serialization_impl
6+
:members:
7+
:no-undoc-members:

_unittests/ut_torch_export_patches/test_patch_serialization.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,13 @@ def test_unet_2d_condition_output(self):
224224

225225
bo = UNet2DConditionOutput(sample=torch.rand((4, 4, 4)))
226226
self.assertEqual(bo.__class__.__name__, "UNet2DConditionOutput")
227+
bo2 = torch_deepcopy([bo])
228+
self.assertIsInstance(bo2, list)
229+
self.assertEqual(
230+
"UNet2DConditionOutput(sample:T1s4x4x4)",
231+
self.string_type(bo, with_shape=True),
232+
)
233+
227234
with torch_export_patches():
228235
# internal function
229236
bo2 = torch_deepcopy([bo])

0 commit comments

Comments
 (0)