@@ -209,6 +209,41 @@ def forward(self, x, y=None):
209209 cst = torch .export .Dim .DYNAMIC
210210 self .assertEqual (({0 : cst , 1 : cst }, {1 : cst }), observer .infer_dynamic_shapes ())
211211
212+ def test_infer_arguments_optional (self ):
213+ class Model (torch .nn .Module ):
214+ def forward (self , x , y = None ):
215+ if y is None :
216+ return x
217+ return x - y
218+
219+ inputs = [
220+ (torch .randn ((5 , 6 )),),
221+ (torch .randn ((6 , 7 )), torch .randn ((1 , 7 ))),
222+ (torch .randn ((7 , 8 )), torch .randn ((1 , 8 ))),
223+ (torch .randn ((8 , 9 )), torch .randn ((1 , 9 ))),
224+ ]
225+
226+ model = Model ()
227+ expected = [model (* args ) for args in inputs ]
228+ observer = InputObserver ()
229+ with observer (model ):
230+ for args in inputs :
231+ model (* args )
232+ self .assertEqual (len (observer .info ), 3 )
233+ for i in range (3 ):
234+ self .assertEqual (len (observer .info .flat_outputs [i ]), 1 )
235+ torch .testing .assert_close (expected [i ], observer .info .flat_outputs [i ][0 ])
236+
237+ cst = torch .export .Dim .DYNAMIC
238+ self .assertEqual (({0 : cst , 1 : cst }, {1 : cst }), observer .infer_dynamic_shapes ())
239+ infer_args = observer .infer_arguments (0 )
240+ self .assertIsInstance (infer_args , tuple )
241+ self .assertEqual (len (infer_args ), 2 )
242+ self .assertIsInstance (infer_args [0 ], torch .Tensor )
243+ self .assertIsInstance (infer_args [1 ], torch .Tensor )
244+ self .assertEqual (infer_args [0 ].shape , (5 , 6 ))
245+ self .assertEqual (infer_args [1 ].shape , (1 , 0 ))
246+
212247 def test_io_captured_optional_kwargs (self ):
213248 class Model (torch .nn .Module ):
214249 def forward (self , x , y = None ):
@@ -589,11 +624,11 @@ def forward(self, x, y, z=None, w=None):
589624 cst = torch .export .Dim .DYNAMIC
590625 self .assertEqual (
591626 dict (x = {0 : cst , 1 : cst }, y = {1 : cst }, z = {0 : cst , 1 : cst }, w = {1 : cst }),
592- observer .infer_dynamic_shapes (add_batch_dimension_for = {0 , "z" }),
627+ observer .infer_dynamic_shapes (set_batch_dimension_for = {0 , "z" }),
593628 )
594629 self .assertEqual (
595630 dict (x = {0 : cst , 1 : cst }, y = {1 : cst }, z = {0 : cst , 1 : cst }, w = {1 : cst }),
596- observer .infer_dynamic_shapes (add_batch_dimension_for = {"x" , "z" }),
631+ observer .infer_dynamic_shapes (set_batch_dimension_for = {"x" , "z" }),
597632 )
598633
599634
0 commit comments