We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent d8ea4ce commit 84b91ceCopy full SHA for 84b91ce
torch/_inductor/test_operators.py
@@ -1,4 +1,5 @@
1
-# mypy: allow-untyped-defs
+from typing import Any
2
+
3
import torch.library
4
from torch import Tensor
5
from torch.autograd import Function
@@ -16,12 +17,13 @@
16
17
18
class Realize(Function):
19
@staticmethod
- def forward(ctx, x):
20
+ def forward(ctx: object, x: Tensor) -> Tensor:
21
return torch.ops._inductor_test.realize(x)
22
23
- def backward(ctx, grad_output):
24
- return grad_output
+ # types need to stay consistent with _SingleLevelFunction
25
+ def backward(ctx: Any, *grad_output: Any) -> Any:
26
+ return grad_output[0]
27
28
def realize(x: Tensor) -> Tensor:
29
return Realize.apply(x)
0 commit comments