|
14 | 14 | # See the License for the specific language governing permissions and |
15 | 15 | # limitations under the License. |
16 | 16 |
|
17 | | -import contextlib |
18 | 17 | import itertools |
19 | 18 | import math |
20 | 19 | import os |
@@ -562,9 +561,6 @@ def preprocessing_fn(inputs): |
562 | 561 | expected_metadata) |
563 | 562 |
|
564 | 563 | def testPyFuncs(self): |
565 | | - if not tft_unit.is_tf_api_version_1(): |
566 | | - raise unittest.SkipTest('Test disabled when TF 2.x behavior enabled.') |
567 | | - |
568 | 564 | def my_multiply(x, y): |
569 | 565 | return x*y |
570 | 566 |
|
@@ -628,14 +624,11 @@ def preprocessing_fn(inputs): |
628 | 624 | }) |
629 | 625 | self.assertAnalyzeAndTransformResults( |
630 | 626 | input_data, input_metadata, preprocessing_fn, expected_data, |
631 | | - expected_metadata) |
| 627 | + expected_metadata, force_tf_compat_v1=True) |
632 | 628 |
|
633 | 629 | def testAssertsNoReturnPyFunc(self): |
634 | 630 | # Asserts that apply_pyfunc raises an exception if the passed function does |
635 | 631 | # not return anything. |
636 | | - if not tft_unit.is_tf_api_version_1(): |
637 | | - raise unittest.SkipTest('Test disabled when TF 2.x behavior enabled.') |
638 | | - |
639 | 632 | self._SkipIfOutputRecordBatches() |
640 | 633 |
|
641 | 634 | def bad_func(): |
@@ -684,7 +677,8 @@ def preprocessing_fn(inputs): |
684 | 677 | preprocessing_fn, |
685 | 678 | expected_data, |
686 | 679 | expected_metadata, |
687 | | - desired_batch_size=batch_size) |
| 680 | + desired_batch_size=batch_size, |
| 681 | + force_tf_compat_v1=True) |
688 | 682 |
|
689 | 683 | def testWithUnicode(self): |
690 | 684 | def preprocessing_fn(inputs): |
@@ -4714,12 +4708,6 @@ def testEmptySchema(self): |
4714 | 4708 | preprocessing_fn=lambda inputs: inputs) # pyformat: disable |
4715 | 4709 |
|
4716 | 4710 | def testLoadKerasModelInPreprocessingFn(self): |
4717 | | - |
4718 | | - if tft_unit.is_tf_api_version_1(): |
4719 | | - raise unittest.SkipTest( |
4720 | | - '`tft.make_and_track_object` is only supported when TF2 behavior is ' |
4721 | | - 'enabled.') |
4722 | | - |
4723 | 4711 | def _create_model(features, target): |
4724 | 4712 | inputs = [ |
4725 | 4713 | tf.keras.Input(shape=(1,), name=f, dtype=tf.float32) for f in features |
@@ -4797,11 +4785,8 @@ def preprocessing_fn(inputs): |
4797 | 4785 | 'f3': 1 |
4798 | 4786 | }] |
4799 | 4787 |
|
4800 | | - with contextlib.ExitStack() as stack: |
4801 | | - if not tft_unit.is_tf_api_version_1(): |
4802 | | - stack.enter_context( |
4803 | | - self.assertRaisesRegex( |
4804 | | - RuntimeError, 'analyzers.*appears to be non-deterministic')) |
| 4788 | + with self.assertRaisesRegex( # pylint: disable=g-error-prone-assert-raises |
| 4789 | + RuntimeError, 'analyzers.*appears to be non-deterministic'): |
4805 | 4790 | self.assertAnalyzeAndTransformResults(input_data, input_metadata, |
4806 | 4791 | preprocessing_fn, expected_outputs) |
4807 | 4792 |
|
|
0 commit comments