Skip to content

Commit 20f16bb

Browse files
saberkunallenwang28
authored andcommitted
Internal change
PiperOrigin-RevId: 303780351
1 parent 4979bf2 commit 20f16bb

15 files changed

+18
-27
lines changed

official/vision/image_classification/augment.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from __future__ import print_function
2525

2626
import math
27-
import tensorflow.compat.v2 as tf
27+
import tensorflow as tf
2828
from typing import Any, Dict, Iterable, List, Optional, Text, Tuple, Union
2929

3030
from tensorflow.python.keras.layers.preprocessing import image_preprocessing as image_ops
@@ -75,8 +75,7 @@ def from_4d(image: tf.Tensor, ndims: int) -> tf.Tensor:
7575
return tf.reshape(image, new_shape)
7676

7777

78-
def _convert_translation_to_transform(
79-
translations: Iterable[int]) -> tf.Tensor:
78+
def _convert_translation_to_transform(translations) -> tf.Tensor:
8079
"""Converts translations to a projective transform.
8180
8281
The translation matrix looks like this:
@@ -166,8 +165,7 @@ def _convert_angles_to_transform(
166165
)
167166

168167

169-
def transform(image: tf.Tensor,
170-
transforms: Iterable[float]) -> tf.Tensor:
168+
def transform(image: tf.Tensor, transforms) -> tf.Tensor:
171169
"""Prepares input data for `image_ops.transform`."""
172170
original_ndims = tf.rank(image)
173171
transforms = tf.convert_to_tensor(transforms, dtype=tf.float32)
@@ -181,8 +179,7 @@ def transform(image: tf.Tensor,
181179
return from_4d(image, original_ndims)
182180

183181

184-
def translate(image: tf.Tensor,
185-
translations: Iterable[int]) -> tf.Tensor:
182+
def translate(image: tf.Tensor, translations) -> tf.Tensor:
186183
"""Translates image(s) by provided vectors.
187184
188185
Args:
@@ -577,7 +574,7 @@ def unwrap(image: tf.Tensor, replace: int) -> tf.Tensor:
577574
return image
578575

579576

580-
def _randomly_negate_tensor(tensor: tf.Tensor) -> tf.Tensor:
577+
def _randomly_negate_tensor(tensor):
581578
"""With 50% prob turn the tensor negative."""
582579
should_flip = tf.cast(tf.floor(tf.random.uniform([]) + 0.5), tf.bool)
583580
final_tensor = tf.cond(should_flip, lambda: tensor, lambda: -tensor)

official/vision/image_classification/augment_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from absl.testing import parameterized
2323

24-
import tensorflow.compat.v2 as tf
24+
import tensorflow as tf
2525

2626
from official.vision.image_classification import augment
2727

@@ -133,5 +133,4 @@ def test_all_policy_ops(self):
133133
self.assertEqual((224, 224, 3), image.shape)
134134

135135
if __name__ == '__main__':
136-
assert tf.version.VERSION.startswith('2.')
137136
tf.test.main()

official/vision/image_classification/classifier_trainer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from absl import app
2828
from absl import flags
2929
from absl import logging
30-
import tensorflow.compat.v2 as tf
30+
import tensorflow as tf
3131

3232
from official.modeling import performance
3333
from official.modeling.hyperparams import params_dict
@@ -423,5 +423,4 @@ def main(_):
423423
flags.mark_flag_as_required('model_type')
424424
flags.mark_flag_as_required('dataset')
425425

426-
assert tf.version.VERSION.startswith('2.')
427426
app.run(main)

official/vision/image_classification/classifier_trainer_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
from absl import flags
3232
from absl.testing import parameterized
33-
import tensorflow.compat.v2 as tf
33+
import tensorflow as tf
3434

3535
from tensorflow.python.distribute import combinations
3636
from tensorflow.python.distribute import strategy_combinations
@@ -313,5 +313,4 @@ def test_serialize_config(self):
313313
tf.io.gfile.rmtree(model_dir)
314314

315315
if __name__ == '__main__':
316-
assert tf.version.VERSION.startswith('2.')
317316
tf.test.main()

official/vision/image_classification/dataset_factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from typing import Any, List, Optional, Tuple, Mapping, Union
2424
from absl import logging
2525
from dataclasses import dataclass
26-
import tensorflow.compat.v2 as tf
26+
import tensorflow as tf
2727
import tensorflow_datasets as tfds
2828

2929
from official.modeling.hyperparams import base_config

official/vision/image_classification/efficientnet/common_modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import numpy as np
2222

2323
import tensorflow.compat.v1 as tf1
24-
import tensorflow.compat.v2 as tf
24+
import tensorflow as tf
2525
from typing import Text, Optional
2626

2727
from tensorflow.python.tpu import tpu_function

official/vision/image_classification/efficientnet/efficientnet_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
from absl import logging
3232
from dataclasses import dataclass
33-
import tensorflow.compat.v2 as tf
33+
import tensorflow as tf
3434

3535
from official.modeling import tf_utils
3636
from official.modeling.hyperparams import base_config

official/vision/image_classification/learning_rate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from typing import Any, List, Mapping
2222

23-
import tensorflow.compat.v2 as tf
23+
import tensorflow as tf
2424

2525
BASE_LEARNING_RATE = 0.1
2626

official/vision/image_classification/learning_rate_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21-
import tensorflow.compat.v2 as tf
21+
import tensorflow as tf
2222

2323
from official.vision.image_classification import learning_rate
2424

@@ -86,5 +86,4 @@ def test_piecewise_constant_decay_invalid_boundaries(self):
8686

8787

8888
if __name__ == '__main__':
89-
assert tf.version.VERSION.startswith('2.')
9089
tf.test.main()

official/vision/image_classification/mnist_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,5 +88,4 @@ def test_end_to_end(self, distribution):
8888

8989

9090
if __name__ == "__main__":
91-
tf.compat.v1.enable_v2_behavior()
9291
tf.test.main()

0 commit comments

Comments
 (0)