File tree Expand file tree Collapse file tree 2 files changed +31
-2
lines changed
Expand file tree Collapse file tree 2 files changed +31
-2
lines changed Original file line number Diff line number Diff line change @@ -225,6 +225,31 @@ def _(x: torch.Tensor) -> torch.Tensor:
225225 example = torch .zeros ([10 , 20 ], device = device )
226226 torch .library .opcheck (f , args = [example ])
227227
228+ # https://github.com/pytorch/pytorch/issues/150472
229+ def test_single_element_tuple_output (self , device ):
230+ # Helper function to register id_tuple custom and the fake tensor implementation
231+ # so that Dynamo has the fake tensor implementation
232+ def get_id_tuple ():
233+ @torch .library .custom_op ("test::id_tuple" , mutates_args = [])
234+ def id_tuple (x : torch .Tensor ) -> Tuple [torch .Tensor ]:
235+ return (x .clone (),)
236+
237+ @id_tuple .register_fake
238+ def _ (
239+ x : torch .Tensor ,
240+ ) -> Tuple [torch .Tensor ]:
241+ return (x .clone (),)
242+
243+ return id_tuple
244+
245+ id_tuple = get_id_tuple ()
246+ x = torch .randn (3 , device = device )
247+ ret = id_tuple (x )
248+ # Check if ret is a tuple and has exactly one and the same element
249+ self .assertIsInstance (ret , tuple )
250+ self .assertEqual (len (ret ), 1 )
251+ self .assertEqual (x , ret [0 ])
252+
228253 def test_missing_abstract_impl (self , device ):
229254 lib = self .lib ()
230255 lib .define ("foo(Tensor x) -> Tensor" )
Original file line number Diff line number Diff line change @@ -77,7 +77,7 @@ def convert_type_string(annotation_type: str):
7777 )
7878
7979 def unstringify_types (
80- tys : tuple [Union [type [object ], str ], ...]
80+ tys : tuple [Union [type [object ], str ], ...],
8181 ) -> tuple [tuple [typing .Any , ...], bool ]:
8282 res = []
8383 changed = False
@@ -282,8 +282,12 @@ def parse_return(annotation, error_fn):
282282 f"Return has unsupported type { annotation } . "
283283 f"The valid types are: { SUPPORTED_RETURN_TYPES } ."
284284 )
285+ output_ty = ", " .join ([SUPPORTED_RETURN_TYPES [arg ] for arg in args ])
285286
286- return "(" + ", " .join ([SUPPORTED_RETURN_TYPES [arg ] for arg in args ]) + ")"
287+ # use (()) to represent tuple with single element
288+ if len (args ) == 1 :
289+ output_ty = "(" + output_ty + ")"
290+ return "(" + output_ty + ")"
287291
288292
289293SUPPORTED_PARAM_TYPES = get_supported_param_types ()
You can’t perform that action at this time.
0 commit comments