Skip to content

Commit 110ae0f

Browse files
azahed98pytorchmergebot
authored andcommitted
Custom Op handle 1-element tuples (pytorch#155447)
Fixes pytorch#150472 Modification of [PR 151408](pytorch#151408). This PR modifies the return parsing in `infer_schema` to handle the case of a Tuple with a single element. Pull Request resolved: pytorch#155447 Approved by: https://github.com/bdhirsh, https://github.com/zou3519
1 parent a2b0b26 commit 110ae0f

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

test/test_custom_ops.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,31 @@ def _(x: torch.Tensor) -> torch.Tensor:
225225
example = torch.zeros([10, 20], device=device)
226226
torch.library.opcheck(f, args=[example])
227227

228+
# https://github.com/pytorch/pytorch/issues/150472
229+
def test_single_element_tuple_output(self, device):
230+
# Helper function to register id_tuple custom and the fake tensor implementation
231+
# so that Dynamo has the fake tensor implementation
232+
def get_id_tuple():
233+
@torch.library.custom_op("test::id_tuple", mutates_args=[])
234+
def id_tuple(x: torch.Tensor) -> Tuple[torch.Tensor]:
235+
return (x.clone(),)
236+
237+
@id_tuple.register_fake
238+
def _(
239+
x: torch.Tensor,
240+
) -> Tuple[torch.Tensor]:
241+
return (x.clone(),)
242+
243+
return id_tuple
244+
245+
id_tuple = get_id_tuple()
246+
x = torch.randn(3, device=device)
247+
ret = id_tuple(x)
248+
# Check if ret is a tuple and has exactly one and the same element
249+
self.assertIsInstance(ret, tuple)
250+
self.assertEqual(len(ret), 1)
251+
self.assertEqual(x, ret[0])
252+
228253
def test_missing_abstract_impl(self, device):
229254
lib = self.lib()
230255
lib.define("foo(Tensor x) -> Tensor")

torch/_library/infer_schema.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def convert_type_string(annotation_type: str):
7777
)
7878

7979
def unstringify_types(
80-
tys: tuple[Union[type[object], str], ...]
80+
tys: tuple[Union[type[object], str], ...],
8181
) -> tuple[tuple[typing.Any, ...], bool]:
8282
res = []
8383
changed = False
@@ -282,8 +282,12 @@ def parse_return(annotation, error_fn):
282282
f"Return has unsupported type {annotation}. "
283283
f"The valid types are: {SUPPORTED_RETURN_TYPES}."
284284
)
285+
output_ty = ", ".join([SUPPORTED_RETURN_TYPES[arg] for arg in args])
285286

286-
return "(" + ", ".join([SUPPORTED_RETURN_TYPES[arg] for arg in args]) + ")"
287+
# use (()) to represent tuple with single element
288+
if len(args) == 1:
289+
output_ty = "(" + output_ty + ")"
290+
return "(" + output_ty + ")"
287291

288292

289293
SUPPORTED_PARAM_TYPES = get_supported_param_types()

0 commit comments

Comments
 (0)