Skip to content

Commit c954287

Browse files
Xharktensorflower-gardener
authored andcommitted
Change the compress, training function input argument format as same as decompress.
PiperOrigin-RevId: 338989166
1 parent a05a2cc commit c954287

File tree

5 files changed

+15
-20
lines changed

5 files changed

+15
-20
lines changed

tensorflow_model_optimization/python/core/common/keras/compression/algorithm.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def init_training_weights_repr(
6262
`tf.keras.layers.Layer.add_weight`for each tf.Variable to create.
6363
"""
6464

65-
def compress(self, training_weights: List[tf.Tensor]) -> List[tf.Tensor]:
65+
def compress(self, *training_weights: tf.Tensor) -> List[tf.Tensor]:
6666
"""Define the operations to compress a single weight after training.
6767
6868
'Compress' can refer to making the weight more amenable to compression
@@ -71,22 +71,22 @@ def compress(self, training_weights: List[tf.Tensor]) -> List[tf.Tensor]:
7171
The default is an identity.
7272
7373
Args:
74-
training_weights: tf.Tensors representing all variables used during
74+
*training_weights: tf.Tensors representing all variables used during
7575
training, for a single compressible weight, in the order returned in
7676
`init_training_weights_repr`.
7777
7878
Returns:
7979
List of tf.Tensors to set to compressed or more compressible form.
8080
"""
81-
return training_weights
81+
return list(training_weights)
8282

83-
def decompress(self, compressed_weights: List[tf.Tensor]) -> tf.Tensor:
83+
def decompress(self, *compressed_weights: tf.Tensor) -> tf.Tensor:
8484
"""Define the operations to decompress a single weight’s compressed form during inference.
8585
8686
The default is an identity. TODO(): actually isn't.
8787
8888
Args:
89-
compressed_weights: tf.Tensors representing a single weight’s compressed
89+
*compressed_weights: tf.Tensors representing a single weight’s compressed
9090
form, coming from what’s returned in `compress`.
9191
9292
Returns:
@@ -95,14 +95,14 @@ def decompress(self, compressed_weights: List[tf.Tensor]) -> tf.Tensor:
9595
return compressed_weights[0]
9696

9797
@abc.abstractmethod
98-
def training(self, training_weights: List[tf.Tensor]) -> tf.Tensor:
98+
def training(self, *training_weights: tf.Tensor) -> tf.Tensor:
9999
"""Define a piece of the forward pass during training, which operates on a single compressible weight.
100100
101101
TODO(tfmot): throw this error.
102102
The default throws an error when training occurs.
103103
104104
Args:
105-
training_weights: tf.Tensors representing any variables used during
105+
*training_weights: tf.Tensors representing any variables used during
106106
training, for a single compressible weight, in the order returned in
107107
`init_training_weights_repr`.
108108

tensorflow_model_optimization/python/core/common/keras/compression/algorithms/different_training_and_inference.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,7 @@ def init_training_weights_repr(
4848
def decompress(self, u: tf.Tensor, sv: tf.Tensor) -> tf.Tensor:
4949
return tf.matmul(u, sv)
5050

51-
def compress(self, training_weights: List[tf.Tensor]) -> List[tf.Tensor]:
52-
assert len(training_weights) == 1
53-
weight = training_weights[0]
54-
51+
def compress(self, weight: tf.Tensor) -> List[tf.Tensor]:
5552
rank = self.params.rank
5653
s, u, v = tf.linalg.svd(weight)
5754

@@ -73,8 +70,8 @@ def compress(self, training_weights: List[tf.Tensor]) -> List[tf.Tensor]:
7370
return [u, sv]
7471

7572
# TODO(tfmot): remove in this example, which is just post-training.
76-
def training(self, training_weights: List[tf.Tensor]) -> tf.Tensor:
77-
return training_weights[0]
73+
def training(self, weight: tf.Tensor) -> tf.Tensor:
74+
return weight
7875

7976

8077
# TODO(tfmot): consider if we can simplify `create_model_for_training` and

tensorflow_model_optimization/python/core/common/keras/compression/algorithms/different_training_and_inference_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def testSVD_PreservesPretrainedWeights(self):
234234

235235
# kernel
236236
algorithm = svd.SVD(params)
237-
w1, w2 = algorithm.compress([tf.constant(dense_layer_weights[0])])
237+
w1, w2 = algorithm.compress(tf.constant(dense_layer_weights[0]))
238238
assert (w1 == dense_layer_compressed_weights[0]).numpy().all()
239239
assert (w2 == dense_layer_compressed_weights[1]).numpy().all()
240240

tensorflow_model_optimization/python/core/common/keras/compression/algorithms/same_training_and_inference.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,7 @@ def init_training_weights_repr(
7373
def decompress(self, u: tf.Tensor, sv: tf.Tensor) -> tf.Tensor:
7474
return tf.matmul(u, sv)
7575

76-
def training(self, training_weights: List[tf.Tensor]) -> tf.Tensor:
77-
u = training_weights[0]
78-
sv = training_weights[1]
76+
def training(self, u: tf.Tensor, sv: tf.Tensor) -> tf.Tensor:
7977
return self.decompress(u, sv)
8078

8179

tensorflow_model_optimization/python/core/common/keras/compression/internal/optimize.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def call(self, inputs):
119119
training_weight_tensors.append(
120120
_prevent_constant_folding(v.read_value(), inputs))
121121

122-
weight_tensor = self.algorithm.training(training_weight_tensors)
122+
weight_tensor = self.algorithm.training(*training_weight_tensors)
123123
setattr(self.layer, attr_name, weight_tensor)
124124

125125
# This assumes that all changes to the forward pass happen "prior" to
@@ -186,7 +186,7 @@ def build(self, input_shape):
186186
self.compressed_weights = {}
187187
for attr_name in self.training_tensors:
188188
training_tensors = self.training_tensors[attr_name]
189-
compressed_tensors = self.algorithm.compress(training_tensors)
189+
compressed_tensors = self.algorithm.compress(*training_tensors)
190190
weights = []
191191
for t in compressed_tensors:
192192
weight = self.add_weight(name='TODO', shape=t.shape)
@@ -297,7 +297,7 @@ def _map_to_inference_weights(training_weights, algorithm, training_tensors):
297297
layer_weights_i = 0
298298
for weight in weights:
299299
if weight in training_tensors:
300-
compressed = algorithm.compress(training_tensors[weight])
300+
compressed = algorithm.compress(*training_tensors[weight])
301301
for c in compressed:
302302
compressed_weights.append(c.numpy())
303303
layer_weights_i += len(training_tensors[weight])

0 commit comments

Comments
 (0)