Skip to content

Commit 7c72212

Browse files
PraChetittensorflower-gardener
authored andcommitted
Adds as_gather_encoder utility.
PiperOrigin-RevId: 264910823
1 parent dbcba51 commit 7c72212

File tree

4 files changed

+47
-2
lines changed

4 files changed

+47
-2
lines changed

tensorflow_model_optimization/python/core/internal/tensor_encoding/encoders/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ py_library(
1919
deps = [
2020
# tensorflow dep1,
2121
"//tensorflow_model_optimization/python/core/internal/tensor_encoding/core:core_encoder",
22+
"//tensorflow_model_optimization/python/core/internal/tensor_encoding/core:gather_encoder",
2223
"//tensorflow_model_optimization/python/core/internal/tensor_encoding/core:simple_encoder",
2324
"//tensorflow_model_optimization/python/core/internal/tensor_encoding/stages:stages_impl",
2425
],
@@ -34,6 +35,7 @@ py_test(
3435
# tensorflow dep1,
3536
# python:util tensorflow dep2,
3637
"//tensorflow_model_optimization/python/core/internal/tensor_encoding/core:core_encoder",
38+
"//tensorflow_model_optimization/python/core/internal/tensor_encoding/core:gather_encoder",
3739
"//tensorflow_model_optimization/python/core/internal/tensor_encoding/core:simple_encoder",
3840
],
3941
)

tensorflow_model_optimization/python/core/internal/tensor_encoding/encoders/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
from tensorflow_model_optimization.python.core.internal.tensor_encoding.encoders.common_encoders import as_gather_encoder
2021
from tensorflow_model_optimization.python.core.internal.tensor_encoding.encoders.common_encoders import as_simple_encoder
2122
from tensorflow_model_optimization.python.core.internal.tensor_encoding.encoders.common_encoders import hadamard_quantization
2223
from tensorflow_model_optimization.python.core.internal.tensor_encoding.encoders.common_encoders import identity

tensorflow_model_optimization/python/core/internal/tensor_encoding/encoders/common_encoders.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import tensorflow as tf
2525

2626
from tensorflow_model_optimization.python.core.internal.tensor_encoding.core import core_encoder
27+
from tensorflow_model_optimization.python.core.internal.tensor_encoding.core import gather_encoder
2728
from tensorflow_model_optimization.python.core.internal.tensor_encoding.core import simple_encoder
2829
from tensorflow_model_optimization.python.core.internal.tensor_encoding.stages import stages_impl
2930

@@ -33,8 +34,8 @@ def as_simple_encoder(encoder, tensorspec):
3334
3435
Args:
3536
encoder: An `Encoder` object to be used to encoding.
36-
tensorspec: A `TensorSpec`. The created `SimpleEncoderV2` will be
37-
constrained to only encode input values compatible with `tensorspec`.
37+
tensorspec: A `TensorSpec`. The created `SimpleEncoder` will be constrained
38+
to only encode input values compatible with `tensorspec`.
3839
3940
Returns:
4041
A `SimpleEncoder`.
@@ -50,6 +51,28 @@ def as_simple_encoder(encoder, tensorspec):
5051
return simple_encoder.SimpleEncoder(encoder, tensorspec)
5152

5253

54+
def as_gather_encoder(encoder, tensorspec):
55+
"""Wraps an `Encoder` object as a `GahterEncoder`.
56+
57+
Args:
58+
encoder: An `Encoder` object to be used to encoding.
59+
tensorspec: A `TensorSpec`. The created `GahterEncoder` will be constrained
60+
to only encode input values compatible with `tensorspec`.
61+
62+
Returns:
63+
A `GahterEncoder`.
64+
65+
Raises:
66+
TypeError:
67+
If `encoder` is not an `Encoder` or `tensorspec` is not a `TensorSpec`.
68+
"""
69+
if not isinstance(encoder, core_encoder.Encoder):
70+
raise TypeError('The encoder must be an instance of `Encoder`.')
71+
if not isinstance(tensorspec, tf.TensorSpec):
72+
raise TypeError('The tensorspec must be a tf.TensorSpec.')
73+
return gather_encoder.GatherEncoder.from_encoder(encoder, tensorspec)
74+
75+
5376
def identity():
5477
"""Returns identity `Encoder`."""
5578
return core_encoder.EncoderComposer(

tensorflow_model_optimization/python/core/internal/tensor_encoding/encoders/common_encoders_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from tensorflow.python.util import nest as core_nest
2323
from tensorflow_model_optimization.python.core.internal.tensor_encoding.core import core_encoder
24+
from tensorflow_model_optimization.python.core.internal.tensor_encoding.core import gather_encoder
2425
from tensorflow_model_optimization.python.core.internal.tensor_encoding.core import simple_encoder
2526
from tensorflow_model_optimization.python.core.internal.tensor_encoding.encoders import common_encoders
2627

@@ -52,6 +53,24 @@ def test_as_simple_encoder_raises_tensorspec(self, not_a_tensorspec):
5253
common_encoders.as_simple_encoder(common_encoders.identity(),
5354
not_a_tensorspec)
5455

56+
@parameterized.parameters(_ENCODER_FNS)
57+
def test_as_gather_encoder(self, encoder_fn):
58+
encoder = common_encoders.as_gather_encoder(encoder_fn(),
59+
tf.TensorSpec((2,), tf.float32))
60+
self.assertIsInstance(encoder, gather_encoder.GatherEncoder)
61+
62+
@parameterized.parameters(None, [[]], 2.0, 'string')
63+
def test_as_gather_encoder_raises_encoder(self, not_an_encoder):
64+
with self.assertRaises(TypeError):
65+
common_encoders.as_gather_encoder(not_an_encoder,
66+
tf.TensorSpec((2,), tf.float32))
67+
68+
@parameterized.parameters(None, [[]], 2.0, 'string')
69+
def test_as_gather_encoder_raises_tensorspec(self, not_a_tensorspec):
70+
with self.assertRaises(TypeError):
71+
common_encoders.as_gather_encoder(common_encoders.identity(),
72+
not_a_tensorspec)
73+
5574
def test_identity(self):
5675
encoder = common_encoders.identity()
5776
self.assertIsInstance(encoder, core_encoder.Encoder)

0 commit comments

Comments
 (0)