Skip to content

Commit e7f003a

Browse files
PraChetittensorflower-gardener
authored andcommitted
Implements GatherEncoder.
GatherEncoder is a basic class for realizing encoding in the "many-to-one" case, where multiple locations hold a Tensor of the same shape and dtype, and one needs to compute their sum at a central location, while only encoded representations are communicated between the locations. PiperOrigin-RevId: 263120120
1 parent 3c8d945 commit e7f003a

File tree

6 files changed

+986
-42
lines changed

6 files changed

+986
-42
lines changed

tensorflow_model_optimization/python/core/internal/tensor_encoding/core/BUILD

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ py_library(
1111
deps = [
1212
":core_encoder",
1313
":encoding_stage",
14+
":gather_encoder",
1415
":simple_encoder",
1516
],
1617
)
@@ -63,6 +64,33 @@ py_test(
6364
],
6465
)
6566

67+
py_library(
68+
name = "gather_encoder",
69+
srcs = ["gather_encoder.py"],
70+
deps = [
71+
":core_encoder",
72+
# tensorflow dep1,
73+
# python:util tensorflow dep2,
74+
"//tensorflow_model_optimization/python/core/internal/tensor_encoding/utils:py_utils",
75+
],
76+
)
77+
78+
py_test(
79+
name = "gather_encoder_test",
80+
size = "small",
81+
srcs = ["gather_encoder_test.py"],
82+
deps = [
83+
":core_encoder",
84+
":encoding_stage",
85+
":gather_encoder",
86+
# absl/testing:parameterized dep1,
87+
# numpy dep1,
88+
# tensorflow dep1,
89+
# python:framework_test_lib tensorflow dep2,
90+
"//tensorflow_model_optimization/python/core/internal/tensor_encoding/testing:test_utils",
91+
],
92+
)
93+
6694
py_library(
6795
name = "simple_encoder",
6896
srcs = ["simple_encoder.py"],

tensorflow_model_optimization/python/core/internal/tensor_encoding/core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,5 @@
2626
from tensorflow_model_optimization.python.core.internal.tensor_encoding.core.encoding_stage import tf_style_adaptive_encoding_stage
2727
from tensorflow_model_optimization.python.core.internal.tensor_encoding.core.encoding_stage import tf_style_encoding_stage
2828

29+
from tensorflow_model_optimization.python.core.internal.tensor_encoding.core.gather_encoder import GatherEncoder
2930
from tensorflow_model_optimization.python.core.internal.tensor_encoding.core.simple_encoder import SimpleEncoder

0 commit comments

Comments
 (0)