Skip to content

Commit 89a4efe

Browse files
Johannes Ballécopybara-github
authored andcommitted
Updates GDN/SignalConv to TF2 API and refactors reparameterizers.
From TF 2.5, we can rely on tf.Variables contained in Layer attributes (e.g. tf.Modules) to show up recursively in self.weights/self.trainable_weights. We can use this here to simplify reparameterization. With this change, we don't rely on Layer.add_weight() any more. Instead, we just assign a tf.Variable as an attribute of the Layer. Alternatively, we allow the attribute to be a callable (e.g. for conditioning kernels on some other computation), or a special callable 'Parameter' which implements reparameterization. Parameter derives from tf.Module. As a result, the contained reparameterized tf.Variable is associated with the Layer, and will be trained. PiperOrigin-RevId: 355009145 Change-Id: I9936142d0ce999b3ebc8f0e8e50ef6cc0a0b3e20
1 parent 89eefc8 commit 89a4efe

File tree

13 files changed

+1184
-689
lines changed

13 files changed

+1184
-689
lines changed

BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ py_library(
1717
"//tensorflow_compression/python/entropy_models:universal",
1818
"//tensorflow_compression/python/layers:gdn",
1919
"//tensorflow_compression/python/layers:initializers",
20-
"//tensorflow_compression/python/layers:parameterizers",
20+
"//tensorflow_compression/python/layers:parameters",
2121
"//tensorflow_compression/python/layers:signal_conv",
2222
"//tensorflow_compression/python/layers:soft_round",
2323
"//tensorflow_compression/python/ops:math_ops",

tensorflow_compression/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
from tensorflow_compression.python.layers.gdn import *
3636
from tensorflow_compression.python.layers.initializers import *
37-
from tensorflow_compression.python.layers.parameterizers import *
37+
from tensorflow_compression.python.layers.parameters import *
3838
from tensorflow_compression.python.layers.signal_conv import *
3939
from tensorflow_compression.python.layers.soft_round import *
4040

tensorflow_compression/all_tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
from tensorflow_compression.python.layers.gdn_test import *
3333
from tensorflow_compression.python.layers.initializers_test import *
34-
from tensorflow_compression.python.layers.parameterizers_test import *
34+
from tensorflow_compression.python.layers.parameters_test import *
3535
from tensorflow_compression.python.layers.signal_conv_test import *
3636
from tensorflow_compression.python.layers.soft_round_test import *
3737

tensorflow_compression/python/layers/BUILD

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,17 @@ py_library(
88
name = "gdn",
99
srcs = ["gdn.py"],
1010
srcs_version = "PY3",
11-
deps = [":parameterizers"],
11+
deps = [":parameters"],
1212
)
1313

1414
py_test(
1515
name = "gdn_test",
1616
srcs = ["gdn_test.py"],
1717
python_version = "PY3",
18-
deps = [":gdn"],
18+
deps = [
19+
":gdn",
20+
":parameters",
21+
],
1922
)
2023

2124
py_library(
@@ -24,47 +27,40 @@ py_library(
2427
srcs_version = "PY3",
2528
)
2629

30+
py_test(
31+
name = "initializers_test",
32+
srcs = ["initializers_test.py"],
33+
python_version = "PY3",
34+
deps = [":initializers"],
35+
)
36+
2737
py_library(
28-
name = "parameterizers",
29-
srcs = ["parameterizers.py"],
38+
name = "parameters",
39+
srcs = ["parameters.py"],
3040
srcs_version = "PY3",
3141
deps = [
3242
"//tensorflow_compression/python/ops:math_ops",
3343
"//tensorflow_compression/python/ops:spectral_ops",
3444
],
3545
)
3646

47+
py_test(
48+
name = "parameters_test",
49+
srcs = ["parameters_test.py"],
50+
python_version = "PY3",
51+
deps = [":parameters"],
52+
)
53+
3754
py_library(
3855
name = "signal_conv",
3956
srcs = ["signal_conv.py"],
4057
srcs_version = "PY3",
4158
deps = [
42-
":parameterizers",
59+
":parameters",
4360
"//tensorflow_compression/python/ops:padding_ops",
4461
],
4562
)
4663

47-
py_library(
48-
name = "soft_round",
49-
srcs = ["soft_round.py"],
50-
srcs_version = "PY3",
51-
deps = ["//tensorflow_compression/python/ops:soft_round_ops"],
52-
)
53-
54-
py_test(
55-
name = "initializers_test",
56-
srcs = ["initializers_test.py"],
57-
python_version = "PY3",
58-
deps = [":initializers"],
59-
)
60-
61-
py_test(
62-
name = "parameterizers_test",
63-
srcs = ["parameterizers_test.py"],
64-
python_version = "PY3",
65-
deps = [":parameterizers"],
66-
)
67-
6864
py_test(
6965
name = "signal_conv_test",
7066
timeout = "long",
@@ -73,16 +69,26 @@ py_test(
7369
shard_count = 3,
7470
deps = [
7571
":initializers",
76-
":parameterizers",
72+
":parameters",
7773
":signal_conv",
7874
],
7975
)
8076

77+
py_library(
78+
name = "soft_round",
79+
srcs = ["soft_round.py"],
80+
srcs_version = "PY3",
81+
deps = ["//tensorflow_compression/python/ops:soft_round_ops"],
82+
)
83+
8184
py_test(
8285
name = "soft_round_test",
8386
srcs = ["soft_round_test.py"],
8487
python_version = "PY3",
85-
deps = [":soft_round"],
88+
deps = [
89+
":soft_round",
90+
"//tensorflow_compression/python/ops:soft_round_ops",
91+
],
8692
)
8793

8894
filegroup(

0 commit comments

Comments
 (0)