Skip to content

Commit 5b890d8

Browse files
Johannes Ballé (copybara-github)
authored and committed
Adds workaround for keras.Model variable aggregation in TF 2.4.
PiperOrigin-RevId: 351446278 Change-Id: Ib6da26428e37ec0446d9182ad1fa463ed13b2063
1 parent fce05bc commit 5b890d8

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

models/toy_sources/compression_model.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,28 @@ def __init__(self, source, lmbda, distortion_loss, **kwargs):
1616
self.lmbda = float(lmbda)
1717
self.distortion_loss = str(distortion_loss)
1818

19+
############################################################################
# In TF <= 2.4, `Model` doesn't aggregate variables from nested `Module`s.
# We fall back to aggregating them the `Module` way. Note: this ignores the
# `trainable` attribute of any nested `Layer`s.
@property
def variables(self):
  """Returns all variables, gathered via `tf.Module`'s recursive walk."""
  # Invoke the `tf.Module` property getter directly, bypassing the
  # `keras.Model` override that misses nested `Module`s.
  module_variables = tf.Module.variables
  return module_variables.fget(self)
26+
27+
@property
def trainable_variables(self):
  """Returns trainable variables, gathered the `tf.Module` way."""
  # Bypass the `keras.Model` override; call the base property's getter.
  module_trainable = tf.Module.trainable_variables
  return module_trainable.fget(self)
30+
31+
# Keras exposes the same collections under `weights`/`trainable_weights`;
# alias them to the overridden properties above so both APIs stay consistent.
weights = variables
trainable_weights = trainable_variables
33+
34+
# This seems to be necessary to prevent a comparison between class objects.
# NOTE(review): presumably `_compiled_trainable_state` trips up
# `tf.Module`'s attribute traversal in TF 2.4 — confirm against the Keras
# `Model` internals for that release before removing.
_TF_MODULE_IGNORED_PROPERTIES = (
    tf.keras.Model._TF_MODULE_IGNORED_PROPERTIES.union(
        ("_compiled_trainable_state",)
    ))
############################################################################
40+
1941
@property
def ndim_source(self):
  """Size of the first event axis of the source distribution."""
  event_shape = self.source.event_shape
  return event_shape[0]

0 commit comments

Comments
 (0)