Skip to content

Commit 11283f3

Browse files
authored
[3D Object Detection] Rename GlobalRandomFlipY -> GlobalRandomFlip (keras-team#1303)
* Rename GlobalRandomFlipY -> GlobalRandomFlip * Rename test * Switch to 3-boolean params for flip_x/y/z
1 parent eb9d363 commit 11283f3

File tree

6 files changed

+49
-17
lines changed

6 files changed

+49
-17
lines changed

examples/training/object_detection_3d/waymo/train_pillars.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def pad_tensors(x):
7171

7272
# Augment the training data
7373
AUGMENTATION_LAYERS = [
74-
preprocessing3d.GlobalRandomFlipY(),
74+
preprocessing3d.GlobalRandomFlip(),
7575
preprocessing3d.GlobalRandomDroppingPoints(drop_rate=0.02),
7676
preprocessing3d.GlobalRandomRotation(max_rotation_angle_x=3.14),
7777
preprocessing3d.GlobalRandomScaling(scaling_factor_z=(0.5, 1.5)),

keras_cv/layers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@
8383
from keras_cv.layers.preprocessing_3d.global_random_dropping_points import (
8484
GlobalRandomDroppingPoints,
8585
)
86-
from keras_cv.layers.preprocessing_3d.global_random_flip_y import GlobalRandomFlipY
86+
from keras_cv.layers.preprocessing_3d.global_random_flip import GlobalRandomFlip
8787
from keras_cv.layers.preprocessing_3d.global_random_rotation import GlobalRandomRotation
8888
from keras_cv.layers.preprocessing_3d.global_random_scaling import GlobalRandomScaling
8989
from keras_cv.layers.preprocessing_3d.global_random_translation import (

keras_cv/layers/preprocessing_3d/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from keras_cv.layers.preprocessing_3d.global_random_dropping_points import (
2525
GlobalRandomDroppingPoints,
2626
)
27-
from keras_cv.layers.preprocessing_3d.global_random_flip_y import GlobalRandomFlipY
27+
from keras_cv.layers.preprocessing_3d.global_random_flip import GlobalRandomFlip
2828
from keras_cv.layers.preprocessing_3d.global_random_rotation import GlobalRandomRotation
2929
from keras_cv.layers.preprocessing_3d.global_random_scaling import GlobalRandomScaling
3030
from keras_cv.layers.preprocessing_3d.global_random_translation import (

keras_cv/layers/preprocessing_3d/global_random_flip_y.py renamed to keras_cv/layers/preprocessing_3d/global_random_flip.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,11 @@
2323

2424

2525
@tf.keras.utils.register_keras_serializable(package="keras_cv")
26-
class GlobalRandomFlipY(base_augmentation_layer_3d.BaseAugmentationLayer3D):
27-
"""A preprocessing layer which flips point clouds and bounding boxes with respect to the X axis during training.
26+
class GlobalRandomFlip(base_augmentation_layer_3d.BaseAugmentationLayer3D):
27+
"""A preprocessing layer which flips point clouds and bounding boxes with respect to the specified axis during training.
2828
29-
This layer will flip the whole scene with respect to the X axis.
29+
This layer will flip the whole scene with respect to the specified axes.
30+
Note that this layer currently only supports flipping over the Y axis.
3031
During inference time, the output will be identical to input. Call the layer with `training=True` to flip the input.
3132
3233
Input shape:
@@ -42,13 +43,26 @@ class GlobalRandomFlipY(base_augmentation_layer_3d.BaseAugmentationLayer3D):
4243
Output shape:
4344
A dictionary of Tensors with the same shape as input Tensors.
4445
46+
Args:
47+
flip_x: Whether or not to flip over the X axis. Defaults to False.
48+
flip_y: Whether or not to flip over the Y axis. Defaults to True.
49+
flip_z: Whether or not to flip over the Z axis. Defaults to False.
4550
"""
4651

47-
def __init__(self, **kwargs):
48-
super().__init__(**kwargs)
52+
def __init__(self, flip_x=False, flip_y=True, flip_z=False, **kwargs):
53+
if flip_x or flip_z:
54+
raise ValueError(
55+
"GlobalRandomFlip currently only supports flipping over the Y "
56+
f"axis. Received flip_x={flip_x}, flip_y={flip_y}, flip_z={flip_z}."
57+
)
4958

50-
def get_config(self):
51-
return {}
59+
if not (flip_x or flip_y or flip_z):
60+
raise ValueError("GlobalRandomFlip must flip over at least 1 axis.")
61+
self.flip_x = flip_x
62+
self.flip_y = flip_y
63+
self.flip_z = flip_z
64+
65+
super().__init__(**kwargs)
5266

5367
def augment_point_clouds_bounding_boxes(
5468
self, point_clouds, bounding_boxes, transformation, **kwargs
@@ -93,3 +107,10 @@ def augment_point_clouds_bounding_boxes(
93107
)
94108

95109
return (point_clouds, bounding_boxes)
110+
111+
def get_config(self):
112+
return {
113+
"flip_x": self.flip_x,
114+
"flip_y": self.flip_y,
115+
"flip_z": self.flip_z,
116+
}

keras_cv/layers/preprocessing_3d/global_random_flip_y_test.py renamed to keras_cv/layers/preprocessing_3d/global_random_flip_test.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,23 @@
1515
import tensorflow as tf
1616

1717
from keras_cv.layers.preprocessing_3d import base_augmentation_layer_3d
18-
from keras_cv.layers.preprocessing_3d.global_random_flip_y import GlobalRandomFlipY
18+
from keras_cv.layers.preprocessing_3d.global_random_flip import GlobalRandomFlip
1919

2020
POINT_CLOUDS = base_augmentation_layer_3d.POINT_CLOUDS
2121
BOUNDING_BOXES = base_augmentation_layer_3d.BOUNDING_BOXES
2222

2323

24-
class GlobalFlippingYTest(tf.test.TestCase):
24+
class GlobalRandomFlipTest(tf.test.TestCase):
2525
def test_augment_random_point_clouds_and_bounding_boxes(self):
26-
add_layer = GlobalRandomFlipY()
26+
add_layer = GlobalRandomFlip()
2727
point_clouds = np.random.random(size=(2, 50, 10)).astype("float32")
2828
bounding_boxes = np.random.random(size=(2, 10, 7)).astype("float32")
2929
inputs = {POINT_CLOUDS: point_clouds, BOUNDING_BOXES: bounding_boxes}
3030
outputs = add_layer(inputs)
3131
self.assertNotAllClose(inputs, outputs)
3232

3333
def test_augment_specific_random_point_clouds_and_bounding_boxes(self):
34-
add_layer = GlobalRandomFlipY()
34+
add_layer = GlobalRandomFlip()
3535
point_clouds = np.array([[[1, 1, 2, 3, 4, 5, 6, 7, 8, 9]] * 2] * 2).astype(
3636
"float32"
3737
)
@@ -49,9 +49,20 @@ def test_augment_specific_random_point_clouds_and_bounding_boxes(self):
4949
self.assertAllClose(outputs[BOUNDING_BOXES], flipped_bounding_boxes)
5050

5151
def test_augment_batch_point_clouds_and_bounding_boxes(self):
52-
add_layer = GlobalRandomFlipY()
52+
add_layer = GlobalRandomFlip()
5353
point_clouds = np.random.random(size=(3, 2, 50, 10)).astype("float32")
5454
bounding_boxes = np.random.random(size=(3, 2, 10, 7)).astype("float32")
5555
inputs = {POINT_CLOUDS: point_clouds, BOUNDING_BOXES: bounding_boxes}
5656
outputs = add_layer(inputs)
5757
self.assertNotAllClose(inputs, outputs)
58+
59+
def test_noop_raises_error(self):
60+
with self.assertRaisesRegexp(ValueError, "must flip over at least 1 axis"):
61+
_ = GlobalRandomFlip(flip_x=False, flip_y=False, flip_z=False)
62+
63+
def test_flip_x_or_z_raises_error(self):
64+
with self.assertRaisesRegexp(ValueError, "only supports flipping over the Y"):
65+
_ = GlobalRandomFlip(flip_x=True, flip_y=False, flip_z=False)
66+
67+
with self.assertRaisesRegexp(ValueError, "only supports flipping over the Y"):
68+
_ = GlobalRandomFlip(flip_x=False, flip_y=False, flip_z=True)

keras_cv/layers/serialization_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -329,8 +329,8 @@ class SerializationTest(tf.test.TestCase, parameterized.TestCase):
329329
{"drop_rate": 0.1},
330330
),
331331
(
332-
"GlobalRandomFlipY",
333-
cv_layers.GlobalRandomFlipY,
332+
"GlobalRandomFlip",
333+
cv_layers.GlobalRandomFlip,
334334
{},
335335
),
336336
(

0 commit comments

Comments
 (0)