Skip to content

Commit 1e20745

Browse files
yiming0416 authored and pytorchmergebot committed
[ez][AOTI] Fix index offset for Optional Tensor Return (pytorch#155073)
Summary: As title. See the added test for more context.

Test Plan: buck2 run mode/dev-nosan caffe2/test/inductor:test_aot_inductor_custom_ops -- -r test_fn_with_optional_tensor_output_2

Rollback Plan:

Differential Revision: D75900658

Pull Request resolved: pytorch#155073
Approved by: https://github.com/angelayi
1 parent d2bfd97 commit 1e20745

File tree

3 files changed

+32
-1
lines changed

3 files changed

+32
-1
lines changed

test/inductor/custom_ops.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,22 @@ std::tuple<Tensor, std::optional<Tensor>, std::optional<Tensor>> fn_with_optiona
2727
return {t3, t4, t5};
2828
}
2929

30+
// CompositeExplicitAutograd kernel for aoti_custom_ops::fn_with_optional_tensor_output_2.
// Returns (t1 + t2, <undefined>, t1 - t2). The middle Tensor is deliberately
// left default-constructed (undefined) — presumably surfaced as None through the
// (Tensor, Tensor?, Tensor?) schema, exercising the AOTI optional-return
// index-offset handling this commit fixes (TODO: confirm against the test).
std::tuple<Tensor, std::optional<Tensor>, std::optional<Tensor>>
fn_with_optional_tensor_output_2_impl(Tensor t1, Tensor t2) {
  Tensor sum = t1 + t2;
  Tensor empty; // intentionally undefined -> optional slot without a value
  Tensor diff = t1 - t2;
  return {sum, empty, diff};
}
36+
37+
// Meta kernel for aoti_custom_ops::fn_with_optional_tensor_output_2.
// Mirrors the impl's output structure for shape/dtype propagation: two defined
// tensors cloned from t1 and an undefined middle tensor, matching the
// (Tensor, Tensor?, Tensor?) schema without touching real data.
std::tuple<Tensor, std::optional<Tensor>, std::optional<Tensor>>
fn_with_optional_tensor_output_2_meta(Tensor t1, Tensor t2) {
  Tensor first = t1.clone();
  Tensor middle; // intentionally undefined, same as the real kernel
  Tensor last = t1.clone();
  return {first, middle, last};
}
43+
44+
45+
3046
Tensor fn_with_all_inputs_impl(
3147
const Tensor& tensor,
3248
const c10::List<Tensor>& tensors,
@@ -364,6 +380,7 @@ extern "C" {
364380
TORCH_LIBRARY(aoti_custom_ops, m) {
365381
m.def("custom_add(Tensor t1, Tensor t2) -> Tensor");
366382
m.def("fn_with_optional_tensor_output(Tensor t1, Tensor t2) -> (Tensor, Tensor?, Tensor?)");
383+
m.def("fn_with_optional_tensor_output_2(Tensor t1, Tensor t2) -> (Tensor, Tensor?, Tensor?)");
367384
m.def(
368385
"fn_with_all_inputs(Tensor tensor, "
369386
"Tensor[] tensors, "
@@ -410,6 +427,7 @@ TORCH_LIBRARY(aoti_custom_ops, m) {
410427
TORCH_LIBRARY_IMPL(aoti_custom_ops, CompositeExplicitAutograd, m) {
411428
m.impl("custom_add", at::custom_add_impl);
412429
m.impl("fn_with_optional_tensor_output", at::fn_with_optional_tensor_output_impl);
430+
m.impl("fn_with_optional_tensor_output_2", at::fn_with_optional_tensor_output_2_impl);
413431
m.impl("fn_with_all_inputs", at::fn_with_all_inputs_impl);
414432
m.impl("fn_with_default_input", at::fn_with_default_input_impl);
415433
m.impl("fn_with_tuple_output", at::fn_with_tuple_output_impl);
@@ -422,6 +440,7 @@ TORCH_LIBRARY_IMPL(aoti_custom_ops, CompositeExplicitAutograd, m) {
422440

423441
TORCH_LIBRARY_IMPL(aoti_custom_ops, Meta, m) {
424442
m.impl("fn_with_optional_tensor_output", at::fn_with_optional_tensor_output_meta);
443+
m.impl("fn_with_optional_tensor_output_2", at::fn_with_optional_tensor_output_2_meta);
425444
m.impl("fn_with_all_inputs", at::fn_with_all_inputs_meta);
426445
m.impl("fn_with_default_input", at::fn_with_default_input_meta);
427446
m.impl("fn_with_tuple_output", at::fn_with_tuple_output_meta);

test/inductor/test_aot_inductor_custom_ops.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,18 @@ def forward(self, x, y):
149149
)
150150
self.check_model(m, args)
151151

152+
def test_fn_with_optional_tensor_output_2(self) -> None:
    """Compile a module wrapping a custom op whose schema is
    (Tensor, Tensor?, Tensor?) where the middle optional output is
    produced as None, and check AOTI output matches eager."""

    class M(torch.nn.Module):
        def forward(self, x, y):
            return torch.ops.aoti_custom_ops.fn_with_optional_tensor_output_2(x, y)

    mod = M().to(device=self.device)
    example_inputs = (
        torch.randn(3, 3, device=self.device),
        torch.randn(3, 3, device=self.device),
    )
    self.check_model(mod, example_inputs)
163+
152164
def test_custom_op_all_inputs(self) -> None:
153165
class MyModel(torch.nn.Module):
154166
# pyre-fixme[3]: Return type must be annotated.

torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -849,7 +849,7 @@ void OSSProxyExecutor::call_function(
849849
TORCH_CHECK(false, "Expected tensor, got None");
850850
}
851851
} else {
852-
continue;
852+
index++;
853853
}
854854
} else {
855855
TORCH_CHECK(

0 commit comments

Comments (0)