-
Notifications
You must be signed in to change notification settings - Fork 20
Open
Description
Description of the bug
Hi @wesselb,
I am trying to write some GP code in JAX and accelerate it with `jax.jit`, but it fails because a NumPy conversion happens during tracing. Commenting out the code that checks for NaN values in the `logpdf` function works around the problem, but you may be able to suggest a better solution. Also, chex supports running tests both with and without jitting, so it could be used in the test suite at some point in the future.
Code
import jax
import jax.numpy as jnp
from stheno.jax import GP, EQ

# Reproduction: differentiate the GP log-pdf w.r.t. the kernel lengthscale
# and jit the resulting gradient function.
x = jnp.arange(10)
y = jnp.arange(10)
lengthscale = jnp.array(1.0)
# Log-pdf of observations y under a GP with a stretched EQ kernel, evaluated
# at inputs x. (Kept as a lambda so it matches the traceback below.)
loss_fn = lambda lengthscale: GP(EQ().stretch(lengthscale))(x).logpdf(y)
# Fails while tracing for jit: stheno's `logpdf` checks for missing data via
# `B.jit_to_numpy(~B.isnan(x[:, 0]))`, which tries to convert a traced JAX
# array to NumPy and raises TracerArrayConversionError.
grad_fn = jax.jit(jax.grad(loss_fn))
grad_fn(lengthscale)
Output
---------------------------------------------------------------------------
UnfilteredStackTrace Traceback (most recent call last)
[<ipython-input-6-9ed89be9714a>](https://localhost:8080/#) in <module>
10 grad_fn = jax.jit(jax.grad(loss_fn))
---> 11 grad_fn(lengthscale)
45 frames
[/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py](https://localhost:8080/#) in reraise_with_filtered_traceback(*args, **kwargs)
161 try:
--> 162 return fun(*args, **kwargs)
163 except Exception as e:
[/usr/local/lib/python3.7/dist-packages/jax/_src/api.py](https://localhost:8080/#) in cache_miss(*args, **kwargs)
530 device=device, backend=backend, name=flat_fun.__name__,
--> 531 donated_invars=donated_invars, inline=inline, keep_unused=keep_unused)
532 out_pytree_def = out_tree()
[/usr/local/lib/python3.7/dist-packages/jax/core.py](https://localhost:8080/#) in bind(self, fun, *args, **params)
1962 def bind(self, fun, *args, **params):
-> 1963 return call_bind(self, fun, *args, **params)
1964
[/usr/local/lib/python3.7/dist-packages/jax/core.py](https://localhost:8080/#) in call_bind(primitive, fun, *args, **params)
1978 fun_ = lu.annotate(fun_, fun.in_type)
-> 1979 outs = top_trace.process_call(primitive, fun_, tracers, params)
1980 return map(full_lower, apply_todos(env_trace_todo(), outs))
[/usr/local/lib/python3.7/dist-packages/jax/core.py](https://localhost:8080/#) in process_call(self, primitive, f, tracers, params)
688 def process_call(self, primitive, f, tracers, params):
--> 689 return primitive.impl(f, *tracers, **params)
690 process_map = process_call
[/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in _xla_call_impl(***failed resolving arguments***)
233 compiled_fun = xla_callable(fun, device, backend, name, donated_invars,
--> 234 keep_unused, *arg_specs)
235 try:
[/usr/local/lib/python3.7/dist-packages/jax/linear_util.py](https://localhost:8080/#) in memoized_fun(fun, *args)
294 else:
--> 295 ans = call(fun, *args)
296 cache[key] = (ans, fun.stores)
[/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in _xla_callable_uncached(fun, device, backend, name, donated_invars, keep_unused, *arg_specs)
324 return lower_xla_callable(fun, device, backend, name, donated_invars, False,
--> 325 keep_unused, *arg_specs).compile().unsafe_call
326
[/usr/local/lib/python3.7/dist-packages/jax/_src/profiler.py](https://localhost:8080/#) in wrapper(*args, **kwargs)
312 with TraceAnnotation(name, **decorator_kwargs):
--> 313 return func(*args, **kwargs)
314 return wrapper
[/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in lower_xla_callable(fun, device, backend, name, donated_invars, always_lower, keep_unused, *arg_specs)
400 jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
--> 401 fun, pe.debug_info_final(fun, "jit"))
402 out_avals, kept_outputs = util.unzip2(out_type)
[/usr/local/lib/python3.7/dist-packages/jax/_src/profiler.py](https://localhost:8080/#) in wrapper(*args, **kwargs)
312 with TraceAnnotation(name, **decorator_kwargs):
--> 313 return func(*args, **kwargs)
314 return wrapper
[/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py](https://localhost:8080/#) in trace_to_jaxpr_final2(fun, debug_info)
2024 with core.new_sublevel():
-> 2025 jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
2026 del fun, main
[/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py](https://localhost:8080/#) in trace_to_subjaxpr_dynamic2(fun, main, debug_info)
1974 in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
-> 1975 ans = fun.call_wrapped(*in_tracers_)
1976 out_tracers = map(trace.full_raise, ans)
[/usr/local/lib/python3.7/dist-packages/jax/linear_util.py](https://localhost:8080/#) in call_wrapped(self, *args, **kwargs)
167 try:
--> 168 ans = self.f(*args, **dict(self.params, **kwargs))
169 except:
[/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py](https://localhost:8080/#) in reraise_with_filtered_traceback(*args, **kwargs)
161 try:
--> 162 return fun(*args, **kwargs)
163 except Exception as e:
[/usr/local/lib/python3.7/dist-packages/jax/_src/api.py](https://localhost:8080/#) in grad_f(*args, **kwargs)
1002 def grad_f(*args, **kwargs):
-> 1003 _, g = value_and_grad_f(*args, **kwargs)
1004 return g
[/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py](https://localhost:8080/#) in reraise_with_filtered_traceback(*args, **kwargs)
161 try:
--> 162 return fun(*args, **kwargs)
163 except Exception as e:
[/usr/local/lib/python3.7/dist-packages/jax/_src/api.py](https://localhost:8080/#) in value_and_grad_f(*args, **kwargs)
1078 if not has_aux:
-> 1079 ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
1080 else:
[/usr/local/lib/python3.7/dist-packages/jax/_src/api.py](https://localhost:8080/#) in _vjp(fun, has_aux, reduce_axes, *primals)
2497 out_primal, out_vjp = ad.vjp(
-> 2498 flat_fun, primals_flat, reduce_axes=reduce_axes)
2499 out_tree = out_tree()
[/usr/local/lib/python3.7/dist-packages/jax/interpreters/ad.py](https://localhost:8080/#) in vjp(traceable, primals, has_aux, reduce_axes)
132 if not has_aux:
--> 133 out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
134 else:
[/usr/local/lib/python3.7/dist-packages/jax/interpreters/ad.py](https://localhost:8080/#) in linearize(traceable, *primals, **kwargs)
121 jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
--> 122 jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
123 out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
[/usr/local/lib/python3.7/dist-packages/jax/_src/profiler.py](https://localhost:8080/#) in wrapper(*args, **kwargs)
312 with TraceAnnotation(name, **decorator_kwargs):
--> 313 return func(*args, **kwargs)
314 return wrapper
[/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py](https://localhost:8080/#) in trace_to_jaxpr_nounits(fun, pvals, instantiate)
768 fun = trace_to_subjaxpr_nounits(fun, main, instantiate)
--> 769 jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
770 assert not env
[/usr/local/lib/python3.7/dist-packages/jax/linear_util.py](https://localhost:8080/#) in call_wrapped(self, *args, **kwargs)
167 try:
--> 168 ans = self.f(*args, **dict(self.params, **kwargs))
169 except:
[<ipython-input-6-9ed89be9714a>](https://localhost:8080/#) in <lambda>(lengthscale)
8 lengthscale = jnp.array(1.0)
----> 9 loss_fn = lambda lengthscale: GP(EQ().stretch(lengthscale))(x).logpdf(y)
10 grad_fn = jax.jit(jax.grad(loss_fn))
[/usr/local/lib/python3.7/dist-packages/stheno/random.py](https://localhost:8080/#) in logpdf(self, x)
261 if B.rank(x) == 2 and B.shape(x, 1) == 1:
--> 262 available = B.jit_to_numpy(~B.isnan(x[:, 0]))
263 if not B.all(available):
[/usr/local/lib/python3.7/dist-packages/plum/function.py](https://localhost:8080/#) in __call__(self, *args, **kw_args)
583 if return_type is default_obj_type:
--> 584 return method(*args, **kw_args)
585 else:
[/usr/local/lib/python3.7/dist-packages/lab/generic.py](https://localhost:8080/#) in jit_to_numpy(*args)
1532 else:
-> 1533 res = B.to_numpy(*args)
1534 if B.control_flow.caching:
[/usr/local/lib/python3.7/dist-packages/plum/function.py](https://localhost:8080/#) in __call__(self, *args, **kw_args)
583 if return_type is default_obj_type:
--> 584 return method(*args, **kw_args)
585 else:
[/usr/local/lib/python3.7/dist-packages/lab/generic.py](https://localhost:8080/#) in to_numpy(a)
1496 """
-> 1497 return convert(a, NPOrNum)
1498
[/usr/local/lib/python3.7/dist-packages/plum/function.py](https://localhost:8080/#) in __call__(self, *args, **kw_args)
583 if return_type is default_obj_type:
--> 584 return method(*args, **kw_args)
585 else:
[/usr/local/lib/python3.7/dist-packages/plum/promotion.py](https://localhost:8080/#) in convert(obj, type_to)
31 """
---> 32 return _convert.invoke(type_of(obj), type_to)(obj, type_to)
33
[/usr/local/lib/python3.7/dist-packages/plum/function.py](https://localhost:8080/#) in wrapped_method(*args, **kw_args)
606 def wrapped_method(*args, **kw_args):
--> 607 return _convert(method(*args, **kw_args), return_type)
608
[/usr/local/lib/python3.7/dist-packages/plum/promotion.py](https://localhost:8080/#) in perform_conversion(obj, _)
60 def perform_conversion(obj: type_from, _: type_to):
---> 61 return f(obj)
62
[/usr/local/lib/python3.7/dist-packages/jax/core.py](https://localhost:8080/#) in __array__(self, *args, **kw)
535 def __array__(self, *args, **kw):
--> 536 raise TracerArrayConversionError(self)
537
UnfilteredStackTrace: jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(bool[10])>with<DynamicJaxprTrace(level=0/1)>
The error occurred while tracing the function <lambda> at <ipython-input-6-9ed89be9714a>:9 for jit. This value became a tracer due to JAX operations on these lines:
operation a:i64[10,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(10, 1)] b
from line /usr/local/lib/python3.7/dist-packages/lab/jax/shaping.py:17 (_expand_dims)
operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
from line /usr/local/lib/python3.7/dist-packages/stheno/random.py:262 (logpdf)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
TracerArrayConversionError Traceback (most recent call last)
[<ipython-input-6-9ed89be9714a>](https://localhost:8080/#) in <module>
9 loss_fn = lambda lengthscale: GP(EQ().stretch(lengthscale))(x).logpdf(y)
10 grad_fn = jax.jit(jax.grad(loss_fn))
---> 11 grad_fn(lengthscale)
[<ipython-input-6-9ed89be9714a>](https://localhost:8080/#) in <lambda>(lengthscale)
7 y = jnp.arange(10)
8 lengthscale = jnp.array(1.0)
----> 9 loss_fn = lambda lengthscale: GP(EQ().stretch(lengthscale))(x).logpdf(y)
10 grad_fn = jax.jit(jax.grad(loss_fn))
11 grad_fn(lengthscale)
[/usr/local/lib/python3.7/dist-packages/stheno/random.py](https://localhost:8080/#) in logpdf(self, x)
260 # Handle missing data. We don't handle missing data for batched computation.
261 if B.rank(x) == 2 and B.shape(x, 1) == 1:
--> 262 available = B.jit_to_numpy(~B.isnan(x[:, 0]))
263 if not B.all(available):
264 # Take the elements of the mean, variance, and inputs corresponding to
[/usr/local/lib/python3.7/dist-packages/plum/function.py](https://localhost:8080/#) in __call__(self, *args, **kw_args)
582 # to speed up the common case.
583 if return_type is default_obj_type:
--> 584 return method(*args, **kw_args)
585 else:
586 return _convert(method(*args, **kw_args), return_type)
[/usr/local/lib/python3.7/dist-packages/lab/generic.py](https://localhost:8080/#) in jit_to_numpy(*args)
1531 return B.control_flow.get_outcome("to_numpy")
1532 else:
-> 1533 res = B.to_numpy(*args)
1534 if B.control_flow.caching:
1535 B.control_flow.set_outcome("to_numpy", res)
[/usr/local/lib/python3.7/dist-packages/plum/function.py](https://localhost:8080/#) in __call__(self, *args, **kw_args)
582 # to speed up the common case.
583 if return_type is default_obj_type:
--> 584 return method(*args, **kw_args)
585 else:
586 return _convert(method(*args, **kw_args), return_type)
[/usr/local/lib/python3.7/dist-packages/lab/generic.py](https://localhost:8080/#) in to_numpy(a)
1495 `np.ndarray`: `a` as NumPy.
1496 """
-> 1497 return convert(a, NPOrNum)
1498
1499
[/usr/local/lib/python3.7/dist-packages/plum/function.py](https://localhost:8080/#) in __call__(self, *args, **kw_args)
582 # to speed up the common case.
583 if return_type is default_obj_type:
--> 584 return method(*args, **kw_args)
585 else:
586 return _convert(method(*args, **kw_args), return_type)
[/usr/local/lib/python3.7/dist-packages/plum/promotion.py](https://localhost:8080/#) in convert(obj, type_to)
30 object: `obj` converted to type `type_to`.
31 """
---> 32 return _convert.invoke(type_of(obj), type_to)(obj, type_to)
33
34
[/usr/local/lib/python3.7/dist-packages/plum/function.py](https://localhost:8080/#) in wrapped_method(*args, **kw_args)
605 @wraps(self._f)
606 def wrapped_method(*args, **kw_args):
--> 607 return _convert(method(*args, **kw_args), return_type)
608
609 return wrapped_method
[/usr/local/lib/python3.7/dist-packages/plum/promotion.py](https://localhost:8080/#) in perform_conversion(obj, _)
59 @_convert.dispatch
60 def perform_conversion(obj: type_from, _: type_to):
---> 61 return f(obj)
62
63
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(bool[10])>with<DynamicJaxprTrace(level=0/1)>
The error occurred while tracing the function <lambda> at <ipython-input-6-9ed89be9714a>:9 for jit. This value became a tracer due to JAX operations on these lines:
operation a:i64[10,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(10, 1)] b
from line /usr/local/lib/python3.7/dist-packages/lab/jax/shaping.py:17 (_expand_dims)
operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
from line /usr/local/lib/python3.7/dist-packages/stheno/random.py:262 (logpdf)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
Description of your environment
Tried this in Google Colab.
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels