
Commit 0f89bb7

fix jax dlpack and torch interface with detach

1 parent a541c19

File tree

3 files changed: +26 -6 lines changed


CHANGELOG.md

Lines changed: 6 additions & 0 deletions

@@ -12,6 +12,12 @@
 
 - Add `su4` as a generic parameterized two-qubit gates.
 
+### Fixed
+
+- Fix the breaking logic change in jax from dlpack API, dlcapsule -> tensor.
+
+- Better torch interface for dlpack translation.
+
 ## v1.4.0
 
 ### Added
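
The first "Fixed" entry refers to the upstream change in jax's DLPack handling: newer jax releases expect `jax.dlpack.from_dlpack` to receive the producing tensor itself (anything implementing `__dlpack__`) rather than a PyCapsule. A minimal sketch of the resulting fallback pattern, assuming recent torch and jax installs; `jax.dlpack.from_dlpack` and `torch.utils.dlpack.to_dlpack` are the real APIs, the surrounding flow is illustrative only.

import jax
import torch

t = torch.ones(3)
try:
    # older jax releases accepted a raw DLPack capsule
    arr = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t))
except TypeError:
    # newer jax consumes the tensor directly via the __dlpack__ protocol
    arr = jax.dlpack.from_dlpack(t)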

tensorcircuit/interfaces/tensortrans.py

Lines changed: 6 additions & 2 deletions

@@ -132,13 +132,17 @@ def general_args_to_backend(
         target_backend = backend
     elif isinstance(target_backend, str):
         target_backend = get_backend(target_backend)
+    try:
+        t = backend.tree_map(target_backend.from_dlpack, caps)
+    except TypeError:
+        t = backend.tree_map(target_backend.from_dlpack, args)
+
     if dtype is None:
-        return backend.tree_map(target_backend.from_dlpack, caps)
+        return t
     if isinstance(dtype, str):
         leaves, treedef = backend.tree_flatten(args)
         dtype = [dtype for _ in range(len(leaves))]
     dtype = backend.tree_unflatten(treedef, dtype)
-    t = backend.tree_map(target_backend.from_dlpack, caps)
     t = backend.tree_map(target_backend.cast, t, dtype)
     return t
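
For orientation, a hedged usage sketch of the patched helper: the positional args plus the target_backend/enable_dlpack keywords are inferred from this hunk and from the call site in torch.py below, not from the full signature.

import torch
from tensorcircuit.interfaces.tensortrans import general_args_to_backend

# a pytree of torch tensors translated to the jax backend; with this patch,
# a from_dlpack that rejects capsules (TypeError) falls back to consuming
# the original tensors directly
x = (torch.ones(2), torch.zeros(3))
y = general_args_to_backend(x, target_backend="jax", enable_dlpack=True)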

tensorcircuit/interfaces/torch.py

Lines changed: 14 additions & 4 deletions

@@ -69,12 +69,14 @@ class Fun(torch.autograd.Function):  # type: ignore
         @staticmethod
         def forward(ctx: Any, *x: Any) -> Any:  # type: ignore
             # ctx.xdtype = [xi.dtype for xi in x]
-            ctx.xdtype = backend.tree_map(lambda s: s.dtype, x)
+            ctx.save_for_backward(*x)
+            x_detached = backend.tree_map(lambda s: s.detach(), x)
+            ctx.xdtype = backend.tree_map(lambda s: s.dtype, x_detached)
             # (x, )
             if len(ctx.xdtype) == 1:
                 ctx.xdtype = ctx.xdtype[0]
-            ctx.device = (backend.tree_flatten(x)[0][0]).device
-            x = general_args_to_backend(x, enable_dlpack=enable_dlpack)
+            ctx.device = (backend.tree_flatten(x_detached)[0][0]).device
+            x = general_args_to_backend(x_detached, enable_dlpack=enable_dlpack)
             y = fun(*x)
             ctx.ydtype = backend.tree_map(lambda s: s.dtype, y)
             if len(x) == 1:

@@ -88,6 +90,9 @@ def forward(ctx: Any, *x: Any) -> Any:  # type: ignore
 
         @staticmethod
         def backward(ctx: Any, *grad_y: Any) -> Any:
+            x = ctx.saved_tensors
+            x_detached = backend.tree_map(lambda s: s.detach(), x)
+            x_backend = general_args_to_backend(x_detached, enable_dlpack=enable_dlpack)
             if len(grad_y) == 1:
                 grad_y = grad_y[0]
             grad_y = backend.tree_map(lambda s: s.contiguous(), grad_y)

@@ -96,7 +101,12 @@ def backward(ctx: Any, *grad_y: Any) -> Any:
             )
             # grad_y = general_args_to_numpy(grad_y)
             # grad_y = numpy_args_to_backend(grad_y, dtype=ctx.ydtype)  # backend.dtype
-            _, g = vjp_fun(ctx.x, grad_y)
+            if len(x_backend) == 1:
+                x_backend_for_vjp = x_backend[0]
+            else:
+                x_backend_for_vjp = x_backend
+
+            _, g = vjp_fun(x_backend_for_vjp, grad_y)
             # a redundency due to current vjp API
 
             r = general_args_to_backend(
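
Taken together, the forward/backward changes let autograd leaf tensors cross the DLPack boundary safely (torch refuses to export tensors that require grad, hence the detach), while the original inputs are kept via ctx.save_for_backward and re-translated for the vjp. A hedged end-to-end sketch using the torch_interface wrapper defined in this file; the circuit function and the jit flag are illustrative assumptions, not part of this diff.

import torch
import tensorcircuit as tc

K = tc.set_backend("jax")

def energy(params):
    # toy two-qubit expectation value computed on the jax backend
    c = tc.Circuit(2)
    c.rx(0, theta=params[0])
    c.rx(1, theta=params[1])
    return K.real(c.expectation_ps(z=[0, 1]))

f = tc.interfaces.torch_interface(energy, jit=True)

params = torch.ones(2, requires_grad=True)  # a leaf tensor with autograd enabled
loss = f(params)   # forward detaches before the DLPack handoff
loss.backward()    # backward rebuilds backend inputs from ctx.saved_tensors
print(params.grad)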
