Commit 7c4bc5b

add jax interface
1 parent 3f2ec85 commit 7c4bc5b

4 files changed: +325 −3 lines

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -2,6 +2,10 @@
 
 ## Unreleased
 
+### Added
+
+- Add `jax_interface`
+
 ## 1.1.0
 
 ### Added

tensorcircuit/interfaces/__init__.py

Lines changed: 1 addition & 3 deletions
@@ -14,6 +14,4 @@
 from .scipy import scipy_interface, scipy_optimize_interface
 from .torch import torch_interface, pytorch_interface, torch_interface_kws
 from .tensorflow import tensorflow_interface, tf_interface
-
-
-# TODO(@refraction-ray): jax interface using puer_callback and custom_vjp
+from .jax import jax_interface
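
With the TODO removed and this re-export in place, the new wrapper is reachable from the package namespace as well as from the submodule. A minimal import sketch, assuming only that tensorcircuit is installed:

# Both names refer to the function defined in tensorcircuit/interfaces/jax.py below.
import tensorcircuit as tc
from tensorcircuit.interfaces import jax_interface

assert jax_interface is tc.interfaces.jax_interface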

tensorcircuit/interfaces/jax.py

Lines changed: 187 additions & 0 deletions
@@ -0,0 +1,187 @@
+"""
+Interface wraps quantum function as a jax function
+"""
+
+from typing import Any, Callable, Tuple, Optional, Union, Sequence
+from functools import wraps, partial
+
+import jax
+from jax import custom_vjp
+
+from ..cons import backend
+from .tensortrans import general_args_to_backend
+
+Tensor = Any
+
+
+def jax_wrapper(
+    fun: Callable[..., Any],
+    enable_dlpack: bool = False,
+    output_shape: Optional[
+        Union[Tuple[int, ...], Tuple[int, ...], Sequence[Tuple[int, ...]]]
+    ] = None,
+    output_dtype: Optional[Union[Any, Sequence[Any]]] = None,
+) -> Callable[..., Any]:
+    @wraps(fun)
+    def fun_jax(*x: Any) -> Any:
+        def wrapped_fun(*args: Any) -> Any:
+            args = general_args_to_backend(args, enable_dlpack=enable_dlpack)
+            y = fun(*args)
+            y = general_args_to_backend(
+                y, target_backend="jax", enable_dlpack=enable_dlpack
+            )
+            return y
+
+        # Use provided shape and dtype if available, otherwise run test
+        if output_shape is not None and output_dtype is not None:
+            if isinstance(output_shape, Sequence) and not isinstance(
+                output_shape[0], int
+            ):
+                # Multiple outputs case
+                out_shape = tuple(
+                    jax.ShapeDtypeStruct(s, d)
+                    for s, d in zip(output_shape, output_dtype)
+                )
+            else:
+                # Single output case
+                out_shape = jax.ShapeDtypeStruct(output_shape, output_dtype)  # type: ignore
+        else:
+            # Get expected output shape by running function once
+            test_out = wrapped_fun(*x)
+            if isinstance(test_out, tuple):
+                # Multiple outputs case
+                out_shape = tuple(
+                    jax.ShapeDtypeStruct(
+                        t.shape if hasattr(t, "shape") else (),
+                        t.dtype if hasattr(t, "dtype") else x[0].dtype,
+                    )
+                    for t in test_out
+                )
+            else:
+                # Single output case
+                out_shape = jax.ShapeDtypeStruct(  # type: ignore
+                    test_out.shape if hasattr(test_out, "shape") else (),
+                    test_out.dtype if hasattr(test_out, "dtype") else x[0].dtype,
+                )
+
+        # Use pure_callback with correct output shape
+        result = jax.pure_callback(wrapped_fun, out_shape, *x)
+        return result
+
+    return fun_jax
+
+
+def jax_interface(
+    fun: Callable[..., Any],
+    jit: bool = False,
+    enable_dlpack: bool = False,
+    output_shape: Optional[Union[Tuple[int, ...], Tuple[()]]] = None,
+    output_dtype: Optional[Any] = None,
+) -> Callable[..., Any]:
+    """
+    Wrap a function on different ML backend with a jax interface.
+
+    :Example:
+
+    .. code-block:: python
+
+        tc.set_backend("tensorflow")
+
+        def f(params):
+            c = tc.Circuit(1)
+            c.rx(0, theta=params[0])
+            c.ry(0, theta=params[1])
+            return tc.backend.real(c.expectation([tc.gates.z(), [0]]))
+
+        f = tc.interfaces.jax_interface(f, jit=True)
+
+        params = jnp.ones(2)
+        value, grad = jax.value_and_grad(f)(params)
+
+    :param fun: The quantum function with tensor in and tensor out
+    :type fun: Callable[..., Any]
+    :param jit: whether to jit ``fun``, defaults to False
+    :type jit: bool, optional
+    :param enable_dlpack: whether transform tensor backend via dlpack, defaults to False
+    :type enable_dlpack: bool, optional
+    :param output_shape: Optional shape of the function output, defaults to None
+    :type output_shape: Optional[Union[Tuple[int, ...], Tuple[()]]], optional
+    :param output_dtype: Optional dtype of the function output, defaults to None
+    :type output_dtype: Optional[Any], optional
+    :return: The same quantum function but now with jax array in and jax array out
+        while AD is also supported
+    :rtype: Callable[..., Any]
+    """
+    jax_fun = create_jax_function(
+        fun,
+        enable_dlpack=enable_dlpack,
+        jit=jit,
+        output_shape=output_shape,
+        output_dtype=output_dtype,
+    )
+    return jax_fun
+
+
+def create_jax_function(
+    fun: Callable[..., Any],
+    enable_dlpack: bool = False,
+    jit: bool = False,
+    output_shape: Optional[Union[Tuple[int, ...], Tuple[()]]] = None,
+    output_dtype: Optional[Any] = None,
+) -> Callable[..., Any]:
+    if jit:
+        fun = backend.jit(fun)
+
+    wrapped = jax_wrapper(
+        fun,
+        enable_dlpack=enable_dlpack,
+        output_shape=output_shape,
+        output_dtype=output_dtype,
+    )
+
+    @custom_vjp
+    def f(*x: Any) -> Any:
+        return wrapped(*x)
+
+    def f_fwd(*x: Any) -> Tuple[Any, Tuple[Any, ...]]:
+        y = wrapped(*x)
+        return y, x
+
+    def f_bwd(res: Tuple[Any, ...], g: Any) -> Tuple[Any, ...]:
+        x = res
+
+        if len(x) == 1:
+            x = x[0]
+
+        vjp_fun = partial(backend.vjp, fun)
+        if jit:
+            vjp_fun = backend.jit(vjp_fun)  # type: ignore
+
+        def vjp_wrapped(args: Any) -> Any:
+            args = general_args_to_backend(args, enable_dlpack=enable_dlpack)
+            gb = general_args_to_backend(g, enable_dlpack=enable_dlpack)
+            r = vjp_fun(args, gb)[1]
+            r = general_args_to_backend(
+                r, target_backend="jax", enable_dlpack=enable_dlpack
+            )
+            return r
+
+        # Handle gradient shape for both single input and tuple inputs
+        if isinstance(x, tuple):
+            # Create a tuple of ShapeDtypeStruct for each input
+            grad_shape = tuple(jax.ShapeDtypeStruct(xi.shape, xi.dtype) for xi in x)
+        else:
+            grad_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
+
+        dx = jax.pure_callback(
+            vjp_wrapped,
+            grad_shape,
+            x,
+        )
+
+        if not isinstance(dx, tuple):
+            dx = (dx,)
+        return dx  # type: ignore
+
+    f.defvjp(f_fwd, f_bwd)
+    return f
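
For orientation, an end-to-end usage sketch assembled from the docstring example above and `test_jax_interface_basic` below; the explicit `tc.set_backend("tensorflow")` call stands in for the `tfb` pytest fixture and is an assumption about the intended setup:

# Usage sketch: wrap a TensorFlow-backend tensorcircuit function so it becomes
# differentiable from jax via pure_callback (forward) and custom_vjp (backward).
import jax
from jax import numpy as jnp
import tensorcircuit as tc

tc.set_backend("tensorflow")  # the wrapped function itself runs on the TF backend


def f(params):
    c = tc.Circuit(1)
    c.rx(0, theta=params[0])
    c.ry(0, theta=params[1])
    return tc.backend.real(c.expectation([tc.gates.z(), [0]]))


f_jax = tc.interfaces.jax_interface(f, jit=True)

params = jnp.ones(2)
value, grad = jax.value_and_grad(f_jax)(params)  # value ≈ 0.291927 per the basic test below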

tests/test_interfaces.py

Lines changed: 133 additions & 0 deletions
@@ -8,6 +8,7 @@
 from scipy import optimize
 import tensorflow as tf
 import jax
+from jax import numpy as jnp
 
 thisfile = os.path.abspath(__file__)
 modulepath = os.path.dirname(os.path.dirname(thisfile))
@@ -427,3 +428,135 @@ def g(a, b, c):
     assert tc.backend.shape_tuple(a[1]) == (2, 2, 2, 2)
     assert tc.backend.shape_tuple(b.eval()) == (2, 2, 2, 2, 2, 2)
     assert tc.backend.shape_tuple(c) == (2, 2, 2, 2)
+
+
+def test_jax_interface_basic(tfb):
+
+    def f(params):
+        c = tc.Circuit(1)
+        c.rx(0, theta=params[0])
+        c.ry(0, theta=params[1])
+        return tc.backend.real(c.expectation_ps(z=[0]))
+
+    f_jax = tc.interfaces.jax_interface(f, jit=True)
+    params = jnp.ones(2)
+
+    # Test forward pass
+    val = f_jax(params)
+    assert isinstance(val, jnp.ndarray)
+    np.testing.assert_allclose(val, 0.291927, atol=1e-5)
+
+    # Test gradient computation
+    val, grad = jax.value_and_grad(f_jax)(params)
+    assert isinstance(grad, jnp.ndarray)
+    assert grad.shape == params.shape
+
+
+def test_jax_interface_multiple_inputs(tfb):
+
+    def f(params1, params2):
+        c = tc.Circuit(2)
+        c.rx(0, theta=params1[0])
+        c.ry(1, theta=params2[0])
+        return tc.backend.real(c.expectation([tc.gates.z(), [0]]))
+
+    f_jax = tc.interfaces.jax_interface(f, jit=False)
+    p1 = jnp.array([1.0])
+    p2 = jnp.array([2.0])
+
+    # Test forward pass
+    val = f_jax(p1, p2)
+    assert isinstance(val, jnp.ndarray)
+
+    # Test gradient computation
+
+    val, (grad1, grad2) = jax.value_and_grad(f_jax, argnums=(0, 1))(p1, p2)
+    assert isinstance(grad1, jnp.ndarray)
+    assert isinstance(grad2, jnp.ndarray)
+    assert grad1.shape == p1.shape
+    assert grad2.shape == p2.shape
+
+
+@pytest.mark.skip(
+    reason="might fail when testing with other function",
+)
+def test_jax_interface_jit_dlpack(tfb):
+
+    def f(params):
+        c = tc.Circuit(2)
+        c.rx(range(2), theta=params)
+        return tc.backend.real(c.expectation([tc.gates.z(), [0]]))
+
+    # Test with JIT
+    f_jax = tc.interfaces.jax_interface(f, jit=True, enable_dlpack=True)
+    params = jnp.array([np.pi, np.pi], dtype=jnp.float32)
+
+    # First call compiles
+    val1 = f_jax(params)
+    # Second call should be faster
+    val2, gs = jax.value_and_grad(f_jax)(params)
+
+    assert isinstance(val1, jnp.ndarray)
+    assert isinstance(gs, jnp.ndarray)
+    np.testing.assert_allclose(val1, val2, atol=1e-5)
+
+
+def test_jax_interface_pure_callback(tfb):
+
+    def f(params):
+        # Use TF operation to test pure_callback
+        return tf.square(params)
+
+    def f_jax1(params):
+        return jnp.sum(tc.interfaces.jax_interface(f)(params))
+
+    def f_jax2(params):
+        return jnp.sum(
+            tc.interfaces.jax_interface(
+                f, jit=True, output_shape=[2], output_dtype=jnp.float32
+            )(params)
+        )
+
+    params = jnp.array([1.0, 2.0])
+
+    for f_jax in [f_jax1, f_jax2]:
+        val = f_jax(params)
+        assert isinstance(val, jnp.ndarray)
+        np.testing.assert_allclose(val, 5.0, atol=1e-5)
+
+        # Test gradient
+        grad = jax.grad(f_jax)(params)
+        assert isinstance(grad, jnp.ndarray)
+        np.testing.assert_allclose(grad, [2.0, 4.0], atol=1e-5)
+
+
+def test_jax_interface_multiple_outputs(tfb):
+
+    def f(params):
+        # Use TF operation to test pure_callback
+        return tf.square(params), params
+
+    def f_jax1(params):
+        r = tc.interfaces.jax_interface(f)(params)
+        return jnp.sum(r[0] + r[1] ** 2) / 2
+
+    def f_jax2(params):
+        r = tc.interfaces.jax_interface(
+            f,
+            jit=True,
+            output_shape=([2], [2]),
+            output_dtype=(jnp.float32, jnp.float32),
+        )(params)
+        return jnp.sum(r[0] + r[1] ** 2) / 2
+
+    params = jnp.array([1.0, 2.0])
+
+    for f_jax in [f_jax1, f_jax2]:
+        val = f_jax(params)
+        assert isinstance(val, jnp.ndarray)
+        np.testing.assert_allclose(val, 5.0, atol=1e-5)
+
+        # Test gradient
+        grad = jax.grad(f_jax)(params)
+        assert isinstance(grad, jnp.ndarray)
+        np.testing.assert_allclose(grad, [2.0, 4.0], atol=1e-5)
