Skip to content

Commit 575df27

Browse files
Johannes Ballécopybara-github
authored andcommitted
Updates initializers to TF2/Keras API and adds unit tests.
PiperOrigin-RevId: 354464149 Change-Id: Ibd3fdfbc0e581c15b1bd9c4404e9b159330b6a20
1 parent 994ca84 commit 575df27

File tree

4 files changed

+74
-14
lines changed

4 files changed

+74
-14
lines changed

tensorflow_compression/all_tests.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
from tensorflow_compression.python.layers.entropy_models_test import *
3333
from tensorflow_compression.python.layers.gdn_test import *
34+
from tensorflow_compression.python.layers.initializers_test import *
3435
from tensorflow_compression.python.layers.parameterizers_test import *
3536
from tensorflow_compression.python.layers.signal_conv_test import *
3637
from tensorflow_compression.python.layers.soft_round_test import *

tensorflow_compression/python/layers/BUILD

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,13 @@ py_test(
6868
deps = [":entropy_models"],
6969
)
7070

71+
py_test(
72+
name = "initializers_test",
73+
srcs = ["initializers_test.py"],
74+
python_version = "PY3",
75+
deps = [":initializers"],
76+
)
77+
7178
py_test(
7279
name = "parameterizers_test",
7380
srcs = ["parameterizers_test.py"],

tensorflow_compression/python/layers/initializers.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,39 +14,42 @@
1414
# ==============================================================================
1515
"""Initializers for layer classes."""
1616

17-
import tensorflow.compat.v1 as tf
17+
import tensorflow as tf
1818

1919

2020
__all__ = [
2121
"IdentityInitializer",
2222
]
2323

2424

25-
class IdentityInitializer(object):
25+
class IdentityInitializer(tf.keras.initializers.Initializer):
2626
"""Initialize to the identity kernel with the given shape.
2727
2828
This creates an n-D kernel suitable for `SignalConv*` with the requested
2929
support that produces an output identical to its input (except possibly at the
3030
signal boundaries).
3131
32-
Note: The identity initializer in `tf.initializers` is only suitable for
32+
Note: The identity initializer in `tf.keras.initializers` is only suitable for
3333
matrices, not for n-D convolution kernels (i.e., no spatial support).
3434
"""
3535

3636
def __init__(self, gain=1):
37-
self.gain = float(gain)
37+
super().__init__()
38+
self.gain = gain
3839

39-
def __call__(self, shape, dtype=None, partition_info=None):
40-
del partition_info # unused
41-
assert len(shape) > 2, shape
40+
def __call__(self, shape, dtype=None, **kwargs):
41+
del kwargs # unused
42+
shape = tf.TensorShape(shape)
43+
if shape.rank <= 2:
44+
raise ValueError(f"shape must be at least rank 3, got {shape}.")
4245

43-
support = tuple(shape[:-2]) + (1, 1)
46+
support = shape.as_list()[:-2] + [1, 1]
4447
indices = [[s // 2 for s in support]]
4548
updates = tf.constant([self.gain], dtype=dtype)
46-
kernel = tf.scatter_nd(indices, updates, support)
49+
spatial_kernel = tf.scatter_nd(indices, updates, support)
50+
return spatial_kernel * tf.eye(shape[-2], shape[-1], dtype=dtype)
4751

48-
assert shape[-2] == shape[-1], shape
49-
if shape[-1] != 1:
50-
kernel *= tf.eye(shape[-1], dtype=dtype)
51-
52-
return kernel
52+
def get_config(self):
53+
config = super().get_config()
54+
config.update(gain=self.gain)
55+
return config
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright 2021 Google LLC. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Tests of initializers."""
16+
17+
import tensorflow as tf
18+
from tensorflow_compression.python.layers import initializers
19+
20+
21+
class InitializerTest(tf.test.TestCase):
22+
23+
def test_creates_1d_kernel(self):
24+
expected_kernel = tf.transpose([
25+
[[0, 3, 0], [0, 0, 0], [0, 0, 0]],
26+
[[0, 0, 0], [0, 3, 0], [0, 0, 0]],
27+
[[0, 0, 0], [0, 0, 0], [0, 3, 0]],
28+
[[0, 0, 0], [0, 0, 0], [0, 0, 0]],
29+
], (2, 0, 1))
30+
kernel = initializers.IdentityInitializer(gain=3)((3, 4, 3), dtype=tf.int32)
31+
self.assertAllEqual(expected_kernel, kernel)
32+
33+
def test_creates_2d_kernel(self):
34+
expected_kernel = tf.constant([
35+
[0, 0, 0, 0, 0],
36+
[0, 0, 0, 0, 0],
37+
[0, 0, 1, 0, 0],
38+
[0, 0, 0, 0, 0],
39+
])[:, :, None, None]
40+
kernel = initializers.IdentityInitializer()((4, 5, 1, 1), dtype=tf.float32)
41+
self.assertAllEqual(expected_kernel, kernel)
42+
43+
def test_fails_for_invalid_shape(self):
44+
with self.assertRaises(ValueError):
45+
initializers.IdentityInitializer()((2, 3), dtype=tf.float32)
46+
47+
48+
if __name__ == "__main__":
49+
tf.test.main()

0 commit comments

Comments
 (0)