Skip to content

Commit cb2a94c

Browse files
hawkinsptensorflower-gardener
authored andcommitted
[JAX] Remove references to omnistaging.
Omnistaging has been the default and only option for a long time. PiperOrigin-RevId: 455703225
1 parent a7b35af commit cb2a94c

File tree

1 file changed

+0
-3
lines changed

1 file changed

+0
-3
lines changed

spinoffs/oryx/oryx/core/trace_util.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import threading
1818
from typing import Any, Dict, Generator, List
1919

20-
import jax
2120
from jax import abstract_arrays
2221
from jax import api_util
2322
from jax import core as jax_core
@@ -63,8 +62,6 @@ def wrapped(*args, **kwargs):
6362
flat_args, in_tree = tree_util.tree_flatten(args)
6463
flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
6564
flat_avals = safe_map(get_shaped_aval, flat_args)
66-
if not jax.config.omnistaging_enabled:
67-
raise ValueError('Oryx must be used with JAX omnistaging enabled.')
6865
if dynamic:
6966
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
7067
flat_fun,

0 commit comments

Comments
 (0)