Skip to content

Commit 45f1a86

Browse files
authored
merge g3 cl (#8138)
1 parent 7a1cc67 commit 45f1a86

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

tfjs-converter/python/tensorflowjs/converters/jax_conversion.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,13 @@
1717
from typing import Any, Callable, Optional, Sequence, Tuple, Union
1818

1919
from jax.experimental import jax2tf
20-
from jax.experimental.export import shape_poly
2120
import tensorflow as tf
2221
from tensorflowjs.converters import tf_saved_model_conversion_v2 as saved_model_conversion
2322

2423

2524
_TF_SERVING_KEY = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
2625
Array = Any
2726
DType = Any
28-
PolyShape = shape_poly.PolyShape
2927

3028

3129
class _ReusableSavedModelWrapper(tf.train.Checkpoint):
@@ -60,7 +58,7 @@ def convert_jax(
6058
*,
6159
input_signatures: Sequence[Tuple[Sequence[Union[int, None]], DType]],
6260
model_dir: str,
63-
polymorphic_shapes: Optional[Sequence[Union[str, PolyShape]]] = None,
61+
polymorphic_shapes: Optional[Sequence[str]] = None,
6462
**tfjs_converter_params):
6563
"""Converts a JAX function `jax_apply_fn` and model parameters to a TensorflowJS model.
6664

0 commit comments

Comments
 (0)