Skip to content

Commit 0e43ced

Browse files
lingvo-botcopybara-github
authored andcommitted
Add a NestedMap compare using multi-line string compare for unittests.
PiperOrigin-RevId: 491355192
1 parent 3e19cc5 commit 0e43ced

File tree

2 files changed

+92
-0
lines changed

2 files changed

+92
-0
lines changed

lingvo/core/BUILD

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,12 @@ py_test(
664664
],
665665
)
666666

667+
pytype_strict_library(
668+
name = "compare",
669+
srcs = ["compare.py"],
670+
deps = [":py_utils"],
671+
)
672+
667673
py_library(
668674
name = "datasource",
669675
srcs = ["datasource.py"],

lingvo/core/compare.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Utility functions for comparing NestedMap in Python.
16+
17+
When comparing NestedMap variables, instead of getting this type of cryptic
18+
error message in your unit test:
19+
20+
self.assertEqual(expected, actual)
21+
> AssertionError: {'src[35 chars]ape=(16, 512), dtype=float32),
22+
> 'src_inputs': S[305 chars]32)}} != {'src[35 chars]ape=(8, 512),
23+
> dtype=float32), 'src_inputs': Sh[299 chars]32)}}
24+
25+
You get:
26+
self.assertNestedMapEqual(expected, actual)
27+
> AssertionError:
28+
- src.paddings ShapeDtypeStruct(shape=(16, 512), dtype=float32)
29+
? ^^
30+
+ src.paddings ShapeDtypeStruct(shape=(8, 512), dtype=float32)
31+
? ^
32+
- src.src_inputs ShapeDtypeStruct(shape=(16, 512, 240, 1), dtype=float32)
33+
? ^^
34+
+ src.src_inputs ShapeDtypeStruct(shape=(8, 512, 240, 1), dtype=float32)
35+
? ^
36+
- src.video ShapeDtypeStruct(shape=(16, 512, 128, 128), dtype=float32)
37+
? ^^
38+
+ src.video ShapeDtypeStruct(shape=(8, 512, 128, 128), dtype=float32)
39+
? ^
40+
- tgt.ids ShapeDtypeStruct(shape=(16, 128), dtype=int32)
41+
? ^^
42+
+ tgt.ids ShapeDtypeStruct(shape=(8, 128), dtype=int32)
43+
? ^
44+
- tgt.labels ShapeDtypeStruct(shape=(16, 128), dtype=int32)
45+
? ^^
46+
+ tgt.labels ShapeDtypeStruct(shape=(8, 128), dtype=int32)
47+
? ^
48+
- tgt.paddings ShapeDtypeStruct(shape=(16, 128), dtype=float32)
49+
? ^^
50+
+ tgt.paddings ShapeDtypeStruct(shape=(8, 128), dtype=float32)
51+
?
52+
"""
53+
54+
from typing import Any, Union
55+
import unittest
56+
from lingvo.core import py_utils
57+
58+
59+
# pyformat: disable
60+
def assertNestedMapEqual( # pylint: disable=invalid-name
61+
self: unittest.TestCase,
62+
expected: Union[dict[str, Any], py_utils.NestedMap],
63+
actual: py_utils.NestedMap):
64+
65+
if not hasattr(expected, 'DebugString'):
66+
expected = py_utils.NestedMap(expected)
67+
68+
self.assertMultiLineEqual(expected.DebugString(), actual.DebugString())
69+
# pyformat: enable
70+
71+
72+
class NestedMapAssertions(unittest.TestCase):
73+
"""Mix this into a googletest.TestCase class to get NestedMap asserts.
74+
75+
Usage:
76+
77+
class SomeTestCase(compare.NestedMapAssertions, googletest.TestCase):
78+
...
79+
def testSomething(self):
80+
...
81+
self.assertNestedMapEqual(expected, actual):
82+
"""
83+
84+
# pylint: disable=invalid-name
85+
def assertNestedMapEqual(self, *args, **kwargs):
86+
return assertNestedMapEqual(self, *args, **kwargs)

0 commit comments

Comments
 (0)