@@ -69,12 +69,14 @@ class Fun(torch.autograd.Function):  # type: ignore
     @staticmethod
     def forward(ctx: Any, *x: Any) -> Any:  # type: ignore
         # ctx.xdtype = [xi.dtype for xi in x]
-        ctx.xdtype = backend.tree_map(lambda s: s.dtype, x)
+        ctx.save_for_backward(*x)
+        x_detached = backend.tree_map(lambda s: s.detach(), x)
+        ctx.xdtype = backend.tree_map(lambda s: s.dtype, x_detached)
         # (x, )
         if len(ctx.xdtype) == 1:
             ctx.xdtype = ctx.xdtype[0]
-        ctx.device = (backend.tree_flatten(x)[0][0]).device
-        x = general_args_to_backend(x, enable_dlpack=enable_dlpack)
+        ctx.device = (backend.tree_flatten(x_detached)[0][0]).device
+        x = general_args_to_backend(x_detached, enable_dlpack=enable_dlpack)
         y = fun(*x)
         ctx.ydtype = backend.tree_map(lambda s: s.dtype, y)
         if len(x) == 1:
@@ -88,6 +90,9 @@ def forward(ctx: Any, *x: Any) -> Any:  # type: ignore
 
     @staticmethod
     def backward(ctx: Any, *grad_y: Any) -> Any:
+        x = ctx.saved_tensors
+        x_detached = backend.tree_map(lambda s: s.detach(), x)
+        x_backend = general_args_to_backend(x_detached, enable_dlpack=enable_dlpack)
         if len(grad_y) == 1:
             grad_y = grad_y[0]
         grad_y = backend.tree_map(lambda s: s.contiguous(), grad_y)
@@ -96,7 +101,12 @@ def backward(ctx: Any, *grad_y: Any) -> Any:
         )
         # grad_y = general_args_to_numpy(grad_y)
         # grad_y = numpy_args_to_backend(grad_y, dtype=ctx.ydtype)  # backend.dtype
-        _, g = vjp_fun(ctx.x, grad_y)
+        if len(x_backend) == 1:
+            x_backend_for_vjp = x_backend[0]
+        else:
+            x_backend_for_vjp = x_backend
+
+        _, g = vjp_fun(x_backend_for_vjp, grad_y)
         # a redundency due to current vjp API
 
         r = general_args_to_backend(
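
Note on the change: the commit stops passing the raw inputs around on ctx and instead registers them with ctx.save_for_backward(*x) in forward, recovers them via ctx.saved_tensors in backward, and detaches them before they are converted by general_args_to_backend and handed to vjp_fun. The sketch below illustrates that save/restore pattern in isolation; it is a minimal example, not the library's implementation. Plain torch arithmetic stands in for the backend conversion and vjp machinery of the surrounding interface, and SquareFun with its squaring function is a hypothetical stand-in, while ctx.save_for_backward / ctx.saved_tensors are the actual torch API being adopted.

import torch


class SquareFun(torch.autograd.Function):
    # Hypothetical stand-in for Fun: inputs are registered with
    # ctx.save_for_backward instead of being stored as plain ctx attributes,
    # and are detached before any out-of-graph (backend) computation.

    @staticmethod
    def forward(ctx, *x):
        ctx.save_for_backward(*x)  # replaces the old "keep x on ctx" pattern
        x_detached = tuple(xi.detach() for xi in x)
        # plain torch squaring stands in for the external fun(*x) call
        y = tuple(xi ** 2 for xi in x_detached)
        return y if len(y) > 1 else y[0]

    @staticmethod
    def backward(ctx, *grad_y):
        x = ctx.saved_tensors  # recovered here, as in the new backward
        x_detached = tuple(xi.detach() for xi in x)
        # hand-written vjp of y = x**2 stands in for vjp_fun
        g = tuple(2.0 * xi * gy for xi, gy in zip(x_detached, grad_y))
        return g if len(g) > 1 else g[0]


a = torch.ones(3, requires_grad=True)
SquareFun.apply(a).sum().backward()
print(a.grad)  # tensor([2., 2., 2.])

Using save_for_backward also lets torch track in-place modification of the saved inputs and manage their lifetime, which plain ctx attributes do not.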