Skip to content

Commit f6a8e0f

Browse files
committed
refactor
1 parent cd872fa commit f6a8e0f

File tree

2 files changed

+296
-132
lines changed

2 files changed

+296
-132
lines changed

_unittests/ut_investigate/test_input_observer.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,41 @@ def forward(self, x, y=None):
209209
cst = torch.export.Dim.DYNAMIC
210210
self.assertEqual(({0: cst, 1: cst}, {1: cst}), observer.infer_dynamic_shapes())
211211

212+
def test_infer_arguments_optional(self):
213+
class Model(torch.nn.Module):
214+
def forward(self, x, y=None):
215+
if y is None:
216+
return x
217+
return x - y
218+
219+
inputs = [
220+
(torch.randn((5, 6)),),
221+
(torch.randn((6, 7)), torch.randn((1, 7))),
222+
(torch.randn((7, 8)), torch.randn((1, 8))),
223+
(torch.randn((8, 9)), torch.randn((1, 9))),
224+
]
225+
226+
model = Model()
227+
expected = [model(*args) for args in inputs]
228+
observer = InputObserver()
229+
with observer(model):
230+
for args in inputs:
231+
model(*args)
232+
self.assertEqual(len(observer.info), 3)
233+
for i in range(3):
234+
self.assertEqual(len(observer.info.flat_outputs[i]), 1)
235+
torch.testing.assert_close(expected[i], observer.info.flat_outputs[i][0])
236+
237+
cst = torch.export.Dim.DYNAMIC
238+
self.assertEqual(({0: cst, 1: cst}, {1: cst}), observer.infer_dynamic_shapes())
239+
infer_args = observer.infer_arguments(0)
240+
self.assertIsInstance(infer_args, tuple)
241+
self.assertEqual(len(infer_args), 2)
242+
self.assertIsInstance(infer_args[0], torch.Tensor)
243+
self.assertIsInstance(infer_args[1], torch.Tensor)
244+
self.assertEqual(infer_args[0].shape, (5, 6))
245+
self.assertEqual(infer_args[1].shape, (1, 0))
246+
212247
def test_io_captured_optional_kwargs(self):
213248
class Model(torch.nn.Module):
214249
def forward(self, x, y=None):
@@ -589,11 +624,11 @@ def forward(self, x, y, z=None, w=None):
589624
cst = torch.export.Dim.DYNAMIC
590625
self.assertEqual(
591626
dict(x={0: cst, 1: cst}, y={1: cst}, z={0: cst, 1: cst}, w={1: cst}),
592-
observer.infer_dynamic_shapes(add_batch_dimension_for={0, "z"}),
627+
observer.infer_dynamic_shapes(set_batch_dimension_for={0, "z"}),
593628
)
594629
self.assertEqual(
595630
dict(x={0: cst, 1: cst}, y={1: cst}, z={0: cst, 1: cst}, w={1: cst}),
596-
observer.infer_dynamic_shapes(add_batch_dimension_for={"x", "z"}),
631+
observer.infer_dynamic_shapes(set_batch_dimension_for={"x", "z"}),
597632
)
598633

599634

0 commit comments

Comments
 (0)