|
| 1 | +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# ============================================================================== |
| 15 | +"""Utility functions for comparing NestedMap in Python. |
| 16 | +
|
| 17 | +When comparing NestedMap variables, instead of getting this type of cryptic |
| 18 | +error message in your unit test: |
| 19 | +
|
| 20 | +self.assertEqual(expected, actual) |
| 21 | +> AssertionError: {'src[35 chars]ape=(16, 512), dtype=float32), |
| 22 | +> 'src_inputs': S[305 chars]32)}} != {'src[35 chars]ape=(8, 512), |
| 23 | +> dtype=float32), 'src_inputs': Sh[299 chars]32)}} |
| 24 | +
|
| 25 | +You get: |
| 26 | +self.assertNestedMapEqual(expected, actual) |
| 27 | +> AssertionError: |
| 28 | +- src.paddings ShapeDtypeStruct(shape=(16, 512), dtype=float32) |
| 29 | +? ^^ |
| 30 | ++ src.paddings ShapeDtypeStruct(shape=(8, 512), dtype=float32) |
| 31 | +? ^ |
| 32 | +- src.src_inputs ShapeDtypeStruct(shape=(16, 512, 240, 1), dtype=float32) |
| 33 | +? ^^ |
| 34 | ++ src.src_inputs ShapeDtypeStruct(shape=(8, 512, 240, 1), dtype=float32) |
| 35 | +? ^ |
| 36 | +- src.video ShapeDtypeStruct(shape=(16, 512, 128, 128), dtype=float32) |
| 37 | +? ^^ |
| 38 | ++ src.video ShapeDtypeStruct(shape=(8, 512, 128, 128), dtype=float32) |
| 39 | +? ^ |
| 40 | +- tgt.ids ShapeDtypeStruct(shape=(16, 128), dtype=int32) |
| 41 | +? ^^ |
| 42 | ++ tgt.ids ShapeDtypeStruct(shape=(8, 128), dtype=int32) |
| 43 | +? ^ |
| 44 | +- tgt.labels ShapeDtypeStruct(shape=(16, 128), dtype=int32) |
| 45 | +? ^^ |
| 46 | ++ tgt.labels ShapeDtypeStruct(shape=(8, 128), dtype=int32) |
| 47 | +? ^ |
| 48 | +- tgt.paddings ShapeDtypeStruct(shape=(16, 128), dtype=float32) |
| 49 | +? ^^ |
| 50 | ++ tgt.paddings ShapeDtypeStruct(shape=(8, 128), dtype=float32) |
| 51 | +? |
| 52 | +""" |
| 53 | + |
| 54 | +from typing import Any, Union |
| 55 | +import unittest |
| 56 | +from lingvo.core import py_utils |
| 57 | + |
| 58 | + |
| 59 | +# pyformat: disable |
| 60 | +def assertNestedMapEqual( # pylint: disable=invalid-name |
| 61 | + self: unittest.TestCase, |
| 62 | + expected: Union[dict[str, Any], py_utils.NestedMap], |
| 63 | + actual: py_utils.NestedMap): |
| 64 | + |
| 65 | + if not hasattr(expected, 'DebugString'): |
| 66 | + expected = py_utils.NestedMap(expected) |
| 67 | + |
| 68 | + self.assertMultiLineEqual(expected.DebugString(), actual.DebugString()) |
| 69 | +# pyformat: enable |
| 70 | + |
| 71 | + |
| 72 | +class NestedMapAssertions(unittest.TestCase): |
| 73 | + """Mix this into a googletest.TestCase class to get NestedMap asserts. |
| 74 | +
|
| 75 | + Usage: |
| 76 | +
|
| 77 | + class SomeTestCase(compare.NestedMapAssertions, googletest.TestCase): |
| 78 | + ... |
| 79 | + def testSomething(self): |
| 80 | + ... |
| 81 | + self.assertNestedMapEqual(expected, actual): |
| 82 | + """ |
| 83 | + |
| 84 | + # pylint: disable=invalid-name |
| 85 | + def assertNestedMapEqual(self, *args, **kwargs): |
| 86 | + return assertNestedMapEqual(self, *args, **kwargs) |
0 commit comments