Skip to content

Commit f296b81

Browse files
lingvo-botcopybara-github
authored andcommitted
Improve DefineAndTrace type signature; lingvo 0.12.5 release
PiperOrigin-RevId: 488677433
1 parent 9150f56 commit f296b81

File tree

5 files changed

+100
-17
lines changed

5 files changed

+100
-17
lines changed

lingvo/core/BUILD

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ load(
77
"lingvo_proto_py",
88
"lingvo_py_binary",
99
)
10-
load("//lingvo:lingvo.bzl", "pytype_library")
10+
load("//lingvo:lingvo.bzl", "pytype_library", "pytype_strict_library", "pytype_strict_test")
1111

1212
package(
1313
default_visibility = ["//visibility:public"],
@@ -1414,20 +1414,28 @@ py_library(
14141414
],
14151415
)
14161416

1417-
py_test(
1417+
pytype_strict_test(
14181418
name = "py_utils_test",
14191419
size = "medium",
14201420
srcs = ["py_utils_test.py"],
14211421
args = ["--noenable_eager_execution"],
1422-
deps = [":py_utils_test_lib"],
1422+
deps = [
1423+
":py_utils_test_lib",
1424+
# Implicit tensorflow functional_ops dependency.
1425+
# Implicit tensorflow init_ops dependency.
1426+
],
14231427
)
14241428

1425-
py_test(
1429+
pytype_strict_test(
14261430
name = "py_utils_eager_test",
14271431
srcs = ["py_utils_test.py"],
14281432
args = ["--enable_eager_execution"],
14291433
main = "py_utils_test.py",
1430-
deps = [":py_utils_test_lib"],
1434+
deps = [
1435+
":py_utils_test_lib",
1436+
# Implicit tensorflow functional_ops dependency.
1437+
# Implicit tensorflow init_ops dependency.
1438+
],
14311439
)
14321440

14331441
pytype_library(
@@ -1788,19 +1796,20 @@ py_library(
17881796
],
17891797
)
17901798

1791-
py_library(
1799+
pytype_strict_library(
17921800
name = "test_utils",
17931801
srcs = ["test_utils.py"],
17941802
deps = [
17951803
":cluster_factory",
17961804
":py_utils",
1805+
":pytypes",
17971806
"//lingvo:compat",
17981807
# Implicit numpy dependency.
17991808
# Implicit tensorboard/backend/event_processing:event_file_inspector dependency.
18001809
],
18011810
)
18021811

1803-
py_test(
1812+
pytype_strict_test(
18041813
name = "test_utils_eager_test",
18051814
srcs = ["test_utils_eager_test.py"],
18061815
args = ["--enable_eager_execution"],
@@ -1810,11 +1819,12 @@ py_test(
18101819
],
18111820
)
18121821

1813-
py_test(
1822+
pytype_strict_test(
18141823
name = "test_utils_test",
18151824
srcs = ["test_utils_test.py"],
18161825
args = ["--noenable_eager_execution"],
18171826
deps = [
1827+
":py_utils",
18181828
":test_utils",
18191829
"//lingvo:compat",
18201830
],
@@ -2600,3 +2610,14 @@ py_strict_test(
26002610
# Implicit numpy dependency.
26012611
],
26022612
)
2613+
2614+
pytype_strict_library(
2615+
name = "pytypes",
2616+
srcs = ["pytypes.py"],
2617+
deps = [
2618+
":hyperparams",
2619+
":nested_map",
2620+
":py_utils",
2621+
# Implicit numpy dependency.
2622+
],
2623+
)

lingvo/core/pytypes.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""PyType Typing utilities for lingvo.
16+
17+
A subset of third_party/py/praxis/pytypes.py.
18+
"""
19+
20+
from typing import List, Tuple, TypeVar, Union, Mapping
21+
22+
from lingvo.core import hyperparams
23+
from lingvo.core import nested_map
24+
from lingvo.core import py_utils
25+
import numpy as np
26+
27+
NpTensor = np.ndarray
28+
29+
NestedMap = nested_map.NestedMap
30+
Params = hyperparams.Params
31+
InstantiableParams = hyperparams.InstantiableParams
32+
33+
T = TypeVar('T')
34+
Nested = Union[T, Tuple[T, ...], List[T], Mapping[str, T], py_utils.NestedMap]
35+
NestedBool = Nested[bool]
36+
NestedInt = Nested[int]

lingvo/core/test_utils.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import lingvo.compat as tf
2727
from lingvo.core import cluster_factory
2828
from lingvo.core import py_utils
29+
from lingvo.core import pytypes
2930
import numpy as np
3031

3132
from tensorboard.backend.event_processing import event_file_inspector
@@ -103,17 +104,40 @@ def gradient(self, ys, xs) -> List[tf.Tensor]: # pylint: disable=invalid-name
103104
return tf.gradients(ys=ys, xs=xs)
104105

105106

107+
Placeholder = typing.TypeVar('Placeholder', bound=tf.Tensor)
108+
RetT = typing.TypeVar('RetT')
109+
110+
111+
class OverloadedForEager(typing.Protocol, typing.Generic[RetT]):
112+
113+
def __call__(self, *args: pytypes.Nested[tf.Tensor]) -> RetT:
114+
"""Needed b/c *args typing support was only added in Py3.11."""
115+
...
116+
117+
118+
class OverloadedForGraph(typing.Protocol, typing.Generic[RetT]):
119+
120+
def __call__(self, *args: pytypes.Nested[Placeholder]) -> RetT:
121+
"""Needed b/c *args typing support was only added in Py3.11."""
122+
...
123+
124+
106125
@typing.overload
107126
def DefineAndTrace(
108-
*tensor_specs_or_placeholders: tf.TensorSpec
109-
) -> Callable[Callable, Callable]: # pylint: disable=g-bare-generic
127+
*tensor_specs_or_placeholders: pytypes.Nested[tf.TensorSpec]
128+
) -> Callable[[OverloadedForEager[RetT]],
129+
tf.types.experimental.ConcreteFunction]:
130+
"""Eager case."""
110131
...
111132

112133

113134
# TODO(jlipschultz): Consider changing the behavior of the graph-mode version of
114135
# DefineAndTrace to also return a Callable[Callable, Callable] for simplicity.
115136
@typing.overload
116-
def DefineAndTrace(*tensor_specs_or_placeholders: tf.Tensor) -> Callable: # pylint: disable=g-bare-generic
137+
def DefineAndTrace(
138+
*tensor_specs_or_placeholders: pytypes.Nested[Placeholder]
139+
) -> Callable[[OverloadedForGraph[RetT]], RetT]:
140+
"""Graph-mode case."""
117141
...
118142

119143

lingvo/lingvo.bzl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def lingvo_proto_py(name, src, deps = []):
173173
srcs = [basename + "_pb2.py"],
174174
)
175175

176-
# Placeholder to use until bazel supports pytype_library.
176+
# Placeholders to use until bazel supports pytype_{,strict_}{library,test,binary}.
177177
def pytype_library(name, **kwargs):
178178
native.py_library(name = name, **kwargs)
179179

@@ -183,13 +183,15 @@ def pytype_strict_library(name, **kwargs):
183183
def pytype_strict_binary(name, **kwargs):
184184
native.py_binary(name = name, **kwargs)
185185

186+
def py_strict_test(name, **kwargs):
187+
native.py_test(name = name, **kwargs)
188+
189+
def pytype_strict_test(name, **kwargs):
190+
native.py_test(name = name, **kwargs)
191+
186192
def lingvo_portable_pytype_library(name, deps = [], nonportable_deps = [], **kwargs):
187193
pytype_library(
188194
name = name,
189195
deps = deps + nonportable_deps,
190196
**kwargs
191197
)
192-
193-
# Placeholder to use until bazel supports py_strict_test.
194-
def py_strict_test(name, **kwargs):
195-
native.py_test(name = name, **kwargs)

pip_package/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from setuptools.command.install import install
2121
from setuptools.dist import Distribution
2222

23-
__version__ = '0.12.4'
23+
__version__ = '0.12.5'
2424
project_name = 'lingvo'
2525
if '--project_name' in sys.argv:
2626
project_name_idx = sys.argv.index('--project_name')

0 commit comments

Comments
 (0)