Skip to content

Unable to jit logpdf #21

@patel-zeel

Description

@patel-zeel

Description of the bug

Hi @wesselb,

I am trying to write some GP code in JAX and accelerate it with jax.jit, but it fails because a NumPy conversion happens during tracing. One potential solution is to comment out the code that checks for NaN values in the logpdf function (with that change, jitting works), but perhaps you can suggest a better solution. Also, chex mentions that it allows testing code both with and without jitting; it could be used in the test suite at some point in the future.

Code

import jax
import jax.numpy as jnp

from stheno.jax import GP, EQ

x = jnp.arange(10)
y = jnp.arange(10)
lengthscale = jnp.array(1.0)
loss_fn = lambda lengthscale: GP(EQ().stretch(lengthscale))(x).logpdf(y)
grad_fn = jax.jit(jax.grad(loss_fn))
grad_fn(lengthscale)

Output

---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
[<ipython-input-6-9ed89be9714a>](https://localhost:8080/#) in <module>
     10 grad_fn = jax.jit(jax.grad(loss_fn))
---> 11 grad_fn(lengthscale)

45 frames
[/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py](https://localhost:8080/#) in reraise_with_filtered_traceback(*args, **kwargs)
    161     try:
--> 162       return fun(*args, **kwargs)
    163     except Exception as e:

[/usr/local/lib/python3.7/dist-packages/jax/_src/api.py](https://localhost:8080/#) in cache_miss(*args, **kwargs)
    530         device=device, backend=backend, name=flat_fun.__name__,
--> 531         donated_invars=donated_invars, inline=inline, keep_unused=keep_unused)
    532     out_pytree_def = out_tree()

[/usr/local/lib/python3.7/dist-packages/jax/core.py](https://localhost:8080/#) in bind(self, fun, *args, **params)
   1962   def bind(self, fun, *args, **params):
-> 1963     return call_bind(self, fun, *args, **params)
   1964 

[/usr/local/lib/python3.7/dist-packages/jax/core.py](https://localhost:8080/#) in call_bind(primitive, fun, *args, **params)
   1978   fun_ = lu.annotate(fun_, fun.in_type)
-> 1979   outs = top_trace.process_call(primitive, fun_, tracers, params)
   1980   return map(full_lower, apply_todos(env_trace_todo(), outs))

[/usr/local/lib/python3.7/dist-packages/jax/core.py](https://localhost:8080/#) in process_call(self, primitive, f, tracers, params)
    688   def process_call(self, primitive, f, tracers, params):
--> 689     return primitive.impl(f, *tracers, **params)
    690   process_map = process_call

[/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in _xla_call_impl(***failed resolving arguments***)
    233   compiled_fun = xla_callable(fun, device, backend, name, donated_invars,
--> 234                               keep_unused, *arg_specs)
    235   try:

[/usr/local/lib/python3.7/dist-packages/jax/linear_util.py](https://localhost:8080/#) in memoized_fun(fun, *args)
    294     else:
--> 295       ans = call(fun, *args)
    296       cache[key] = (ans, fun.stores)

[/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in _xla_callable_uncached(fun, device, backend, name, donated_invars, keep_unused, *arg_specs)
    324     return lower_xla_callable(fun, device, backend, name, donated_invars, False,
--> 325                               keep_unused, *arg_specs).compile().unsafe_call
    326 

[/usr/local/lib/python3.7/dist-packages/jax/_src/profiler.py](https://localhost:8080/#) in wrapper(*args, **kwargs)
    312     with TraceAnnotation(name, **decorator_kwargs):
--> 313       return func(*args, **kwargs)
    314     return wrapper

[/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in lower_xla_callable(fun, device, backend, name, donated_invars, always_lower, keep_unused, *arg_specs)
    400     jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
--> 401         fun, pe.debug_info_final(fun, "jit"))
    402   out_avals, kept_outputs = util.unzip2(out_type)

[/usr/local/lib/python3.7/dist-packages/jax/_src/profiler.py](https://localhost:8080/#) in wrapper(*args, **kwargs)
    312     with TraceAnnotation(name, **decorator_kwargs):
--> 313       return func(*args, **kwargs)
    314     return wrapper

[/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py](https://localhost:8080/#) in trace_to_jaxpr_final2(fun, debug_info)
   2024     with core.new_sublevel():
-> 2025       jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
   2026     del fun, main

[/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py](https://localhost:8080/#) in trace_to_subjaxpr_dynamic2(fun, main, debug_info)
   1974     in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
-> 1975     ans = fun.call_wrapped(*in_tracers_)
   1976     out_tracers = map(trace.full_raise, ans)

[/usr/local/lib/python3.7/dist-packages/jax/linear_util.py](https://localhost:8080/#) in call_wrapped(self, *args, **kwargs)
    167     try:
--> 168       ans = self.f(*args, **dict(self.params, **kwargs))
    169     except:

[/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py](https://localhost:8080/#) in reraise_with_filtered_traceback(*args, **kwargs)
    161     try:
--> 162       return fun(*args, **kwargs)
    163     except Exception as e:

[/usr/local/lib/python3.7/dist-packages/jax/_src/api.py](https://localhost:8080/#) in grad_f(*args, **kwargs)
   1002   def grad_f(*args, **kwargs):
-> 1003     _, g = value_and_grad_f(*args, **kwargs)
   1004     return g

[/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py](https://localhost:8080/#) in reraise_with_filtered_traceback(*args, **kwargs)
    161     try:
--> 162       return fun(*args, **kwargs)
    163     except Exception as e:

[/usr/local/lib/python3.7/dist-packages/jax/_src/api.py](https://localhost:8080/#) in value_and_grad_f(*args, **kwargs)
   1078     if not has_aux:
-> 1079       ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
   1080     else:

[/usr/local/lib/python3.7/dist-packages/jax/_src/api.py](https://localhost:8080/#) in _vjp(fun, has_aux, reduce_axes, *primals)
   2497     out_primal, out_vjp = ad.vjp(
-> 2498         flat_fun, primals_flat, reduce_axes=reduce_axes)
   2499     out_tree = out_tree()

[/usr/local/lib/python3.7/dist-packages/jax/interpreters/ad.py](https://localhost:8080/#) in vjp(traceable, primals, has_aux, reduce_axes)
    132   if not has_aux:
--> 133     out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
    134   else:

[/usr/local/lib/python3.7/dist-packages/jax/interpreters/ad.py](https://localhost:8080/#) in linearize(traceable, *primals, **kwargs)
    121   jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
--> 122   jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
    123   out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)

[/usr/local/lib/python3.7/dist-packages/jax/_src/profiler.py](https://localhost:8080/#) in wrapper(*args, **kwargs)
    312     with TraceAnnotation(name, **decorator_kwargs):
--> 313       return func(*args, **kwargs)
    314     return wrapper

[/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py](https://localhost:8080/#) in trace_to_jaxpr_nounits(fun, pvals, instantiate)
    768     fun = trace_to_subjaxpr_nounits(fun, main, instantiate)
--> 769     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    770     assert not env

[/usr/local/lib/python3.7/dist-packages/jax/linear_util.py](https://localhost:8080/#) in call_wrapped(self, *args, **kwargs)
    167     try:
--> 168       ans = self.f(*args, **dict(self.params, **kwargs))
    169     except:

[<ipython-input-6-9ed89be9714a>](https://localhost:8080/#) in <lambda>(lengthscale)
      8 lengthscale = jnp.array(1.0)
----> 9 loss_fn = lambda lengthscale: GP(EQ().stretch(lengthscale))(x).logpdf(y)
     10 grad_fn = jax.jit(jax.grad(loss_fn))

[/usr/local/lib/python3.7/dist-packages/stheno/random.py](https://localhost:8080/#) in logpdf(self, x)
    261         if B.rank(x) == 2 and B.shape(x, 1) == 1:
--> 262             available = B.jit_to_numpy(~B.isnan(x[:, 0]))
    263             if not B.all(available):

[/usr/local/lib/python3.7/dist-packages/plum/function.py](https://localhost:8080/#) in __call__(self, *args, **kw_args)
    583             if return_type is default_obj_type:
--> 584                 return method(*args, **kw_args)
    585             else:

[/usr/local/lib/python3.7/dist-packages/lab/generic.py](https://localhost:8080/#) in jit_to_numpy(*args)
   1532     else:
-> 1533         res = B.to_numpy(*args)
   1534         if B.control_flow.caching:

[/usr/local/lib/python3.7/dist-packages/plum/function.py](https://localhost:8080/#) in __call__(self, *args, **kw_args)
    583             if return_type is default_obj_type:
--> 584                 return method(*args, **kw_args)
    585             else:

[/usr/local/lib/python3.7/dist-packages/lab/generic.py](https://localhost:8080/#) in to_numpy(a)
   1496     """
-> 1497     return convert(a, NPOrNum)
   1498 

[/usr/local/lib/python3.7/dist-packages/plum/function.py](https://localhost:8080/#) in __call__(self, *args, **kw_args)
    583             if return_type is default_obj_type:
--> 584                 return method(*args, **kw_args)
    585             else:

[/usr/local/lib/python3.7/dist-packages/plum/promotion.py](https://localhost:8080/#) in convert(obj, type_to)
     31     """
---> 32     return _convert.invoke(type_of(obj), type_to)(obj, type_to)
     33 

[/usr/local/lib/python3.7/dist-packages/plum/function.py](https://localhost:8080/#) in wrapped_method(*args, **kw_args)
    606         def wrapped_method(*args, **kw_args):
--> 607             return _convert(method(*args, **kw_args), return_type)
    608 

[/usr/local/lib/python3.7/dist-packages/plum/promotion.py](https://localhost:8080/#) in perform_conversion(obj, _)
     60     def perform_conversion(obj: type_from, _: type_to):
---> 61         return f(obj)
     62 

[/usr/local/lib/python3.7/dist-packages/jax/core.py](https://localhost:8080/#) in __array__(self, *args, **kw)
    535   def __array__(self, *args, **kw):
--> 536     raise TracerArrayConversionError(self)
    537 

UnfilteredStackTrace: jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(bool[10])>with<DynamicJaxprTrace(level=0/1)>
The error occurred while tracing the function <lambda> at <ipython-input-6-9ed89be9714a>:9 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i64[10,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(10, 1)] b
    from line /usr/local/lib/python3.7/dist-packages/lab/jax/shaping.py:17 (_expand_dims)

  operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line /usr/local/lib/python3.7/dist-packages/stheno/random.py:262 (logpdf)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

TracerArrayConversionError                Traceback (most recent call last)
[<ipython-input-6-9ed89be9714a>](https://localhost:8080/#) in <module>
      9 loss_fn = lambda lengthscale: GP(EQ().stretch(lengthscale))(x).logpdf(y)
     10 grad_fn = jax.jit(jax.grad(loss_fn))
---> 11 grad_fn(lengthscale)

[<ipython-input-6-9ed89be9714a>](https://localhost:8080/#) in <lambda>(lengthscale)
      7 y = jnp.arange(10)
      8 lengthscale = jnp.array(1.0)
----> 9 loss_fn = lambda lengthscale: GP(EQ().stretch(lengthscale))(x).logpdf(y)
     10 grad_fn = jax.jit(jax.grad(loss_fn))
     11 grad_fn(lengthscale)

[/usr/local/lib/python3.7/dist-packages/stheno/random.py](https://localhost:8080/#) in logpdf(self, x)
    260         # Handle missing data. We don't handle missing data for batched computation.
    261         if B.rank(x) == 2 and B.shape(x, 1) == 1:
--> 262             available = B.jit_to_numpy(~B.isnan(x[:, 0]))
    263             if not B.all(available):
    264                 # Take the elements of the mean, variance, and inputs corresponding to

[/usr/local/lib/python3.7/dist-packages/plum/function.py](https://localhost:8080/#) in __call__(self, *args, **kw_args)
    582             # to speed up the common case.
    583             if return_type is default_obj_type:
--> 584                 return method(*args, **kw_args)
    585             else:
    586                 return _convert(method(*args, **kw_args), return_type)

[/usr/local/lib/python3.7/dist-packages/lab/generic.py](https://localhost:8080/#) in jit_to_numpy(*args)
   1531         return B.control_flow.get_outcome("to_numpy")
   1532     else:
-> 1533         res = B.to_numpy(*args)
   1534         if B.control_flow.caching:
   1535             B.control_flow.set_outcome("to_numpy", res)

[/usr/local/lib/python3.7/dist-packages/plum/function.py](https://localhost:8080/#) in __call__(self, *args, **kw_args)
    582             # to speed up the common case.
    583             if return_type is default_obj_type:
--> 584                 return method(*args, **kw_args)
    585             else:
    586                 return _convert(method(*args, **kw_args), return_type)

[/usr/local/lib/python3.7/dist-packages/lab/generic.py](https://localhost:8080/#) in to_numpy(a)
   1495         `np.ndarray`: `a` as NumPy.
   1496     """
-> 1497     return convert(a, NPOrNum)
   1498 
   1499 

[/usr/local/lib/python3.7/dist-packages/plum/function.py](https://localhost:8080/#) in __call__(self, *args, **kw_args)
    582             # to speed up the common case.
    583             if return_type is default_obj_type:
--> 584                 return method(*args, **kw_args)
    585             else:
    586                 return _convert(method(*args, **kw_args), return_type)

[/usr/local/lib/python3.7/dist-packages/plum/promotion.py](https://localhost:8080/#) in convert(obj, type_to)
     30         object: `obj` converted to type `type_to`.
     31     """
---> 32     return _convert.invoke(type_of(obj), type_to)(obj, type_to)
     33 
     34 

[/usr/local/lib/python3.7/dist-packages/plum/function.py](https://localhost:8080/#) in wrapped_method(*args, **kw_args)
    605         @wraps(self._f)
    606         def wrapped_method(*args, **kw_args):
--> 607             return _convert(method(*args, **kw_args), return_type)
    608 
    609         return wrapped_method

[/usr/local/lib/python3.7/dist-packages/plum/promotion.py](https://localhost:8080/#) in perform_conversion(obj, _)
     59     @_convert.dispatch
     60     def perform_conversion(obj: type_from, _: type_to):
---> 61         return f(obj)
     62 
     63 

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(bool[10])>with<DynamicJaxprTrace(level=0/1)>
The error occurred while tracing the function <lambda> at <ipython-input-6-9ed89be9714a>:9 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i64[10,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(10, 1)] b
    from line /usr/local/lib/python3.7/dist-packages/lab/jax/shaping.py:17 (_expand_dims)

  operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line /usr/local/lib/python3.7/dist-packages/stheno/random.py:262 (logpdf)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

Description of your environment

I tried this in Google Colab.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions