Skip to content

Commit ef48d38

Browse files
Johannes Ballécopybara-github
authored andcommitted
Fixes handling of input tensor rank.
PiperOrigin-RevId: 360561777 Change-Id: I92f15753c033c40d1965c8d9e1e3ce6fa0f25848
1 parent 2cdb7b5 commit ef48d38

File tree

3 files changed

+19
-17
lines changed

3 files changed

+19
-17
lines changed

tensorflow_compression/python/layers/gdn.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,6 @@ def __init__(self,
152152
**kwargs: Other keyword arguments passed to superclass (`Layer`).
153153
"""
154154
super().__init__(**kwargs)
155-
self.input_spec = tf.keras.layers.InputSpec(min_ndim=2)
156155
self.inverse = inverse
157156
self.rectify = rectify
158157
self.data_format = data_format
@@ -329,14 +328,13 @@ def _channel_axis(self):
329328
return {"channels_first": 1, "channels_last": -1}[self.data_format]
330329

331330
def build(self, input_shape):
332-
channel_axis = self._channel_axis
333331
input_shape = tf.TensorShape(input_shape)
334-
num_channels = input_shape[channel_axis]
332+
if input_shape.rank is None or input_shape.rank < 2:
333+
raise ValueError(f"Input tensor must have at least rank 2, received "
334+
f"shape {input_shape}.")
335+
num_channels = input_shape[self._channel_axis]
335336
if num_channels is None:
336-
raise ValueError("The channel dimension of the inputs to `GDN` "
337-
"must be defined.")
338-
self.input_spec = tf.keras.layers.InputSpec(
339-
min_ndim=2, axes={channel_axis: num_channels})
337+
raise ValueError("The channel dimension of the inputs must be defined.")
340338

341339
if self.alpha_parameter is None:
342340
initial_value = self.alpha_initializer(
@@ -367,8 +365,9 @@ def build(self, input_shape):
367365
def call(self, inputs) -> tf.Tensor:
368366
inputs = tf.convert_to_tensor(inputs, dtype=self.dtype)
369367
rank = inputs.shape.rank
370-
if rank is None:
371-
raise RuntimeError("Input tensor rank must be defined.")
368+
if rank is None or rank < 2:
369+
raise ValueError(f"Input tensor must have at least rank 2, received "
370+
f"shape {inputs.shape}.")
372371

373372
if self.rectify:
374373
inputs = tf.nn.relu(inputs)

tensorflow_compression/python/layers/signal_conv.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,6 @@ def __init__(self, filters, kernel_support,
315315
**kwargs: Keyword arguments passed to superclass (`Layer`).
316316
"""
317317
super().__init__(**kwargs)
318-
self.input_spec = tf.keras.layers.InputSpec(ndim=self._rank + 2)
319318
self.filters = filters
320319
self.kernel_support = kernel_support
321320
self.corr = corr
@@ -574,12 +573,13 @@ def _raise_notimplemented(self):
574573

575574
def build(self, input_shape):
576575
input_shape = tf.TensorShape(input_shape)
576+
if input_shape.rank != self._rank + 2:
577+
raise ValueError(f"Input tensor must have rank {self._rank + 2}, "
578+
f"received shape {input_shape}.")
577579
channel_axis = {"channels_first": 1, "channels_last": -1}[self.data_format]
578580
input_channels = input_shape[channel_axis]
579581
if input_channels is None:
580582
raise ValueError("The channel dimension of the inputs must be defined.")
581-
self.input_spec = tf.keras.layers.InputSpec(
582-
ndim=self._rank + 2, axes={channel_axis: input_channels})
583583

584584
kernel_shape = self.kernel_support + (input_channels, self.filters)
585585
if self.channel_separable:
@@ -837,7 +837,10 @@ def _up_convolve_transpose_explicit(self, inputs, kernel, prepadding):
837837
return outputs
838838

839839
def call(self, inputs) -> tf.Tensor:
840-
inputs = tf.convert_to_tensor(inputs)
840+
inputs = tf.convert_to_tensor(inputs, dtype=self.dtype)
841+
if inputs.shape.rank != self._rank + 2:
842+
raise ValueError(f"Input tensor must have rank {self._rank + 2}, "
843+
f"received shape {inputs.shape}.")
841844
outputs = inputs
842845

843846
# Not for all possible combinations of (`kernel_support`, `corr`,

tensorflow_compression/python/layers/signal_conv_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,15 @@ def test_invalid_data_format_raises_error(self):
3131

3232
def test_variables_are_enumerated(self):
3333
layer = signal_conv.SignalConv2D(3, 1, use_bias=True)
34-
layer.build((None, None, 2))
34+
layer.build((None, None, None, 2))
3535
self.assertLen(layer.weights, 2)
3636
self.assertLen(layer.trainable_weights, 2)
3737
weight_names = [w.name for w in layer.weights]
3838
self.assertSameElements(weight_names, ["kernel_rdft:0", "bias:0"])
3939

4040
def test_bias_variable_is_not_unnecessarily_created(self):
4141
layer = signal_conv.SignalConv2D(5, 3, use_bias=False)
42-
layer.build((None, None, 3))
42+
layer.build((None, None, None, 3))
4343
self.assertLen(layer.weights, 1)
4444
self.assertLen(layer.trainable_weights, 1)
4545
weight_names = [w.name for w in layer.weights]
@@ -49,14 +49,14 @@ def test_variables_are_not_enumerated_when_overridden(self):
4949
layer = signal_conv.SignalConv2D(1, 1)
5050
layer.kernel_parameter = [[[[1]]]]
5151
layer.bias_parameter = [0]
52-
layer.build((None, 1))
52+
layer.build((None, None, None, 1))
5353
self.assertEmpty(layer.weights)
5454
self.assertEmpty(layer.trainable_weights)
5555

5656
def test_variables_trainable_state_follows_layer(self):
5757
layer = signal_conv.SignalConv2D(1, 1, use_bias=True)
5858
layer.trainable = False
59-
layer.build((None, 1))
59+
layer.build((None, None, None, 1))
6060
self.assertLen(layer.weights, 2)
6161
self.assertEmpty(layer.trainable_weights)
6262

0 commit comments

Comments
 (0)