@@ -44,116 +44,110 @@ class GlobalRandomScaling(base_augmentation_layer_3d.BaseAugmentationLayer3D):
4444 A dictionary of Tensors with the same shape as input Tensors.
4545
4646 Arguments:
47- scaling_factor_x : A tuple of float scalars or a float scalar sets the minimum and maximum scaling factors for the X axis.
48- scaling_factor_y : A tuple of float scalars or a float scalar sets the minimum and maximum scaling factors for the Y axis.
49- scaling_factor_z : A tuple of float scalars or a float scalar sets the minimum and maximum scaling factors for the Z axis.
47+ x_factor : A tuple of float scalars or a float scalar sets the minimum and maximum scaling factors for the X axis.
48+ y_factor : A tuple of float scalars or a float scalar sets the minimum and maximum scaling factors for the Y axis.
49+ z_factor : A tuple of float scalars or a float scalar sets the minimum and maximum scaling factors for the Z axis.
5050 """
5151
5252 def __init__ (
5353 self ,
54- scaling_factor_x = None ,
55- scaling_factor_y = None ,
56- scaling_factor_z = None ,
57- same_scaling_xyz = False ,
54+ x_factor = None ,
55+ y_factor = None ,
56+ z_factor = None ,
57+ preserve_aspect_ratio = False ,
5858 ** kwargs
5959 ):
6060 super ().__init__ (** kwargs )
61- if not scaling_factor_x :
62- min_scaling_factor_x = 1.0
63- max_scaling_factor_x = 1.0
64- elif type (scaling_factor_x ) is float :
65- min_scaling_factor_x = scaling_factor_x
66- max_scaling_factor_x = scaling_factor_x
61+ if not x_factor :
62+ min_x_factor = 1.0
63+ max_x_factor = 1.0
64+ elif type (x_factor ) is float :
65+ min_x_factor = x_factor
66+ max_x_factor = x_factor
6767 else :
68- min_scaling_factor_x = scaling_factor_x [0 ]
69- max_scaling_factor_x = scaling_factor_x [1 ]
70- if not scaling_factor_y :
71- min_scaling_factor_y = 1.0
72- max_scaling_factor_y = 1.0
73- elif type (scaling_factor_y ) is float :
74- min_scaling_factor_y = scaling_factor_y
75- max_scaling_factor_y = scaling_factor_y
68+ min_x_factor = x_factor [0 ]
69+ max_x_factor = x_factor [1 ]
70+ if not y_factor :
71+ min_y_factor = 1.0
72+ max_y_factor = 1.0
73+ elif type (y_factor ) is float :
74+ min_y_factor = y_factor
75+ max_y_factor = y_factor
7676 else :
77- min_scaling_factor_y = scaling_factor_y [0 ]
78- max_scaling_factor_y = scaling_factor_y [1 ]
79- if not scaling_factor_z :
80- min_scaling_factor_z = 1.0
81- max_scaling_factor_z = 1.0
82- elif type (scaling_factor_z ) is float :
83- min_scaling_factor_z = scaling_factor_z
84- max_scaling_factor_z = scaling_factor_z
77+ min_y_factor = y_factor [0 ]
78+ max_y_factor = y_factor [1 ]
79+ if not z_factor :
80+ min_z_factor = 1.0
81+ max_z_factor = 1.0
82+ elif type (z_factor ) is float :
83+ min_z_factor = z_factor
84+ max_z_factor = z_factor
8585 else :
86- min_scaling_factor_z = scaling_factor_z [0 ]
87- max_scaling_factor_z = scaling_factor_z [1 ]
86+ min_z_factor = z_factor [0 ]
87+ max_z_factor = z_factor [1 ]
8888
8989 if (
90- min_scaling_factor_x < 0
91- or max_scaling_factor_x < 0
92- or min_scaling_factor_y < 0
93- or max_scaling_factor_y < 0
94- or min_scaling_factor_z < 0
95- or max_scaling_factor_z < 0
90+ min_x_factor < 0
91+ or max_x_factor < 0
92+ or min_y_factor < 0
93+ or max_y_factor < 0
94+ or min_z_factor < 0
95+ or max_z_factor < 0
9696 ):
97- raise ValueError ("min_scaling_factor and max_scaling_factor must be >=0." )
97+ raise ValueError ("min_factor and max_factor must be >=0." )
9898 if (
99- min_scaling_factor_x > max_scaling_factor_x
100- or min_scaling_factor_y > max_scaling_factor_y
101- or min_scaling_factor_z > max_scaling_factor_z
99+ min_x_factor > max_x_factor
100+ or min_y_factor > max_y_factor
101+ or min_z_factor > max_z_factor
102102 ):
103- raise ValueError ("min_scaling_factor must be less than max_scaling_factor." )
104- if same_scaling_xyz :
105- if (
106- min_scaling_factor_x != min_scaling_factor_y
107- or min_scaling_factor_y != min_scaling_factor_z
108- ):
103+ raise ValueError ("min_factor must be less than max_factor." )
104+ if preserve_aspect_ratio :
105+ if min_x_factor != min_y_factor or min_y_factor != min_z_factor :
109106 raise ValueError (
110- "min_scaling_factor must be the same when same_scaling_xyz is true."
107+ "min_factor must be the same when preserve_aspect_ratio is true."
111108 )
112- if (
113- max_scaling_factor_x != max_scaling_factor_y
114- or max_scaling_factor_y != max_scaling_factor_z
115- ):
109+ if max_x_factor != max_y_factor or max_y_factor != max_z_factor :
116110 raise ValueError (
117- "max_scaling_factor must be the same when same_scaling_xyz is true."
111+ "max_factor must be the same when preserve_aspect_ratio is true."
118112 )
119113
120- self ._min_scaling_factor_x = min_scaling_factor_x
121- self ._max_scaling_factor_x = max_scaling_factor_x
122- self ._min_scaling_factor_y = min_scaling_factor_y
123- self ._max_scaling_factor_y = max_scaling_factor_y
124- self ._min_scaling_factor_z = min_scaling_factor_z
125- self ._max_scaling_factor_z = max_scaling_factor_z
126- self ._same_scaling_xyz = same_scaling_xyz
114+ self ._min_x_factor = min_x_factor
115+ self ._max_x_factor = max_x_factor
116+ self ._min_y_factor = min_y_factor
117+ self ._max_y_factor = max_y_factor
118+ self ._min_z_factor = min_z_factor
119+ self ._max_z_factor = max_z_factor
120+ self ._preserve_aspect_ratio = preserve_aspect_ratio
127121
128122 def get_config (self ):
129123 return {
130- "scaling_factor_x " : (
131- self ._min_scaling_factor_x ,
132- self ._max_scaling_factor_x ,
124+ "x_factor " : (
125+ self ._min_x_factor ,
126+ self ._max_x_factor ,
133127 ),
134- "scaling_factor_y " : (
135- self ._min_scaling_factor_y ,
136- self ._max_scaling_factor_y ,
128+ "y_factor " : (
129+ self ._min_y_factor ,
130+ self ._max_y_factor ,
137131 ),
138- "scaling_factor_z " : (
139- self ._min_scaling_factor_z ,
140- self ._max_scaling_factor_z ,
132+ "z_factor " : (
133+ self ._min_z_factor ,
134+ self ._max_z_factor ,
141135 ),
142- "same_scaling_xyz " : self ._same_scaling_xyz ,
136+ "preserve_aspect_ratio " : self ._preserve_aspect_ratio ,
143137 }
144138
145139 def get_random_transformation (self , ** kwargs ):
146140
147141 random_scaling_x = self ._random_generator .random_uniform (
148- (), minval = self ._min_scaling_factor_x , maxval = self ._max_scaling_factor_x
142+ (), minval = self ._min_x_factor , maxval = self ._max_x_factor
149143 )
150144 random_scaling_y = self ._random_generator .random_uniform (
151- (), minval = self ._min_scaling_factor_y , maxval = self ._max_scaling_factor_y
145+ (), minval = self ._min_y_factor , maxval = self ._max_y_factor
152146 )
153147 random_scaling_z = self ._random_generator .random_uniform (
154- (), minval = self ._min_scaling_factor_z , maxval = self ._max_scaling_factor_z
148+ (), minval = self ._min_z_factor , maxval = self ._max_z_factor
155149 )
156- if not self ._same_scaling_xyz :
150+ if not self ._preserve_aspect_ratio :
157151 return {
158152 "scale" : tf .stack (
159153 [random_scaling_x , random_scaling_y , random_scaling_z ]
0 commit comments