Skip to content

Commit 5456312

Browse files
dayeongltensorflower-gardener
authored andcommitted
Update Transform class comment.
PiperOrigin-RevId: 409877304
1 parent f13c253 commit 5456312

File tree

1 file changed

+55
-11
lines changed

1 file changed

+55
-11
lines changed

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py

Lines changed: 55 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,11 @@ def _normalize_tuple(value):
102102

103103

104104
class Conv2DBatchNormQuantize(transforms.Transform):
105-
"""Ensure FQ does not get placed between Conv and BatchNorm."""
105+
"""Transform to be applied to "Conv2D" + "BatchNorm" Graph.
106+
107+
This transform disables Quantization between Conv and BatchNorm
108+
to ensure FQ does not get placed between them.
109+
"""
106110

107111
def pattern(self):
108112
return LayerPattern(
@@ -135,7 +139,11 @@ def custom_objects(self):
135139

136140

137141
class Conv2DReshapeBatchNormQuantize(Conv2DBatchNormQuantize):
138-
"""Ensure FQ does not get placed between Conv, Reshape and BatchNorm."""
142+
"""Transform to be applied to "Conv2D" + "Reshape" + "BatchNorm" Graph.
143+
144+
This transform disables Quantization between Conv, Reshape and BatchNorm
145+
to ensure FQ does not get placed between them.
146+
"""
139147

140148
def pattern(self):
141149
return LayerPattern(
@@ -155,7 +163,11 @@ def replacement(self, match_layer):
155163

156164

157165
class Conv2DBatchNormReLUQuantize(Conv2DBatchNormQuantize):
158-
"""Ensure FQ does not get placed between Conv, BatchNorm and ReLU."""
166+
"""Transform to be applied to "Conv2D" + "BatchNorm" + "ReLU" Graph.
167+
168+
This transform disables Quantization between Conv, BatchNorm and ReLU
169+
to ensure FQ does not get placed between them.
170+
"""
159171

160172
def pattern(self):
161173
return LayerPattern(
@@ -184,7 +196,11 @@ def replacement(self, match_layer):
184196

185197

186198
class Conv2DBatchNormActivationQuantize(Conv2DBatchNormReLUQuantize):
187-
"""Ensure FQ does not get placed between Conv, BatchNorm and ReLU."""
199+
"""Transform to be applied to "Conv2D" + "BatchNorm" + "ReLU" Graph.
200+
201+
This transform disables Quantization between Conv, BatchNorm and ReLU
202+
to ensure FQ does not get placed between them.
203+
"""
188204

189205
def pattern(self):
190206
return LayerPattern(
@@ -194,7 +210,11 @@ def pattern(self):
194210

195211

196212
class Conv2DReshapeBatchNormReLUQuantize(Conv2DBatchNormReLUQuantize):
197-
"""Ensure FQ does not get placed between Conv, BatchNorm and ReLU."""
213+
"""Transform to be applied to "Conv2D" + "Reshape" + "BatchNorm" + "ReLU" Graph.
214+
215+
This transform disables Quantization between Conv, Reshape, BatchNorm and ReLU
216+
to ensure FQ does not get placed between them.
217+
"""
198218

199219
def pattern(self):
200220
return LayerPattern(
@@ -212,7 +232,11 @@ def replacement(self, match_layer):
212232

213233
class Conv2DReshapeBatchNormActivationQuantize(
214234
Conv2DReshapeBatchNormReLUQuantize):
215-
"""Ensure FQ does not get placed between Conv, BatchNorm and ReLU."""
235+
"""Transform to be applied to "Conv2D" + "Reshape" + "BatchNorm" + "ReLU" Graph.
236+
237+
This transform disables Quantization between Conv, Reshape, BatchNorm and ReLU
238+
to ensure FQ does not get placed between them.
239+
"""
216240

217241
def pattern(self):
218242
return LayerPattern(
@@ -222,7 +246,11 @@ def pattern(self):
222246

223247

224248
class DenseBatchNormQuantize(transforms.Transform):
225-
"""Ensure FQ does not get placed between Dense and BatchNorm."""
249+
"""Transform to be applied to "Dense"+ "BatchNorm" Graph.
250+
251+
This transform disables Quantization between Dense and BatchNorm
252+
to ensure FQ does not get placed between them.
253+
"""
226254

227255
def pattern(self):
228256
return LayerPattern(
@@ -254,7 +282,11 @@ def custom_objects(self):
254282

255283

256284
class DenseBatchNormReLUQuantize(DenseBatchNormQuantize):
257-
"""Ensure FQ does not get placed between Dense, BatchNorm and ReLU."""
285+
"""Transform to be applied to "Dense"+ "BatchNorm" + "ReLU" Graph.
286+
287+
This transform disables Quantization between Dense, BatchNorm and ReLU
288+
to ensure FQ does not get placed between them.
289+
"""
258290

259291
def pattern(self):
260292
return LayerPattern(
@@ -281,7 +313,11 @@ def replacement(self, match_layer):
281313

282314

283315
class DenseBatchNormActivationQuantize(DenseBatchNormReLUQuantize):
284-
"""Ensure FQ does not get placed between Dense, BatchNorm and ReLU."""
316+
"""Transform to be applied to "Dense"+ "BatchNorm" + "ReLU" Graph.
317+
318+
This transform disables Quantization between Dense, BatchNorm and ReLU
319+
to ensure FQ does not get placed between them.
320+
"""
285321

286322
def pattern(self):
287323
return LayerPattern(
@@ -501,7 +537,11 @@ def replacement(self, match_layer):
501537

502538

503539
class LayerReLUQuantize(transforms.Transform):
504-
"""Ensure FQ does not get placed between Add and ReLU."""
540+
"""Transform to be applied to "Add"+ "ReLU" Graph.
541+
542+
This transform disables Quantization between Add and ReLU
543+
to ensure FQ does not get placed between them.
544+
"""
505545

506546
def pattern(self):
507547
return LayerPattern(
@@ -523,7 +563,11 @@ def custom_objects(self):
523563

524564

525565
class LayerReluActivationQuantize(LayerReLUQuantize):
526-
"""Ensure FQ does not get placed between Add and ReLU."""
566+
"""Transform to be applied to "Add"+ "ReLU" Graph.
567+
568+
This transform disables Quantization between Add and ReLU
569+
to ensure FQ does not get placed between them.
570+
"""
527571

528572
def pattern(self):
529573
return LayerPattern(

0 commit comments

Comments
 (0)