We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent a7b35af commit cb2a94cCopy full SHA for cb2a94c
spinoffs/oryx/oryx/core/trace_util.py
@@ -17,7 +17,6 @@
17
import threading
18
from typing import Any, Dict, Generator, List
19
20
-import jax
21
from jax import abstract_arrays
22
from jax import api_util
23
from jax import core as jax_core
@@ -63,8 +62,6 @@ def wrapped(*args, **kwargs):
63
62
flat_args, in_tree = tree_util.tree_flatten(args)
64
flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
65
flat_avals = safe_map(get_shaped_aval, flat_args)
66
- if not jax.config.omnistaging_enabled:
67
- raise ValueError('Oryx must be used with JAX omnistaging enabled.')
68
if dynamic:
69
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
70
flat_fun,
0 commit comments