Skip to content

Commit 847489b

Browse files
emilyfertigtensorflower-gardener
authored andcommitted
Update LinearOperators with rewrite_equivalence_test, update Numpy backend after changes to LinOp with CompositeTensor.
PiperOrigin-RevId: 375142062
1 parent a163a15 commit 847489b

26 files changed

+426
-10
lines changed

tensorflow_probability/python/internal/backend/jax/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ FILENAMES = [
3030
"composite_tensor",
3131
"config",
3232
"control_flow",
33+
"data_structures",
3334
"debugging",
3435
"deprecation",
3536
"dtype",
@@ -58,9 +59,11 @@ FILENAMES = [
5859
"sparse_lib",
5960
"tensor_array_ops",
6061
"tensor_array_ops_test",
62+
"tensor_spec",
6163
"test_lib",
6264
"tf_inspect",
6365
"type_spec",
66+
"variables",
6467
"v1",
6568
"v2",
6669
"_utils",

tensorflow_probability/python/internal/backend/meta/gen_linear_operators.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
'from tensorflow.python.framework import tensor_util',
5959
'@tf_export',
6060
'@dispatch',
61+
'@linear_operator.make_composite_tensor',
6162
'self._check_input_dtype',
6263
]
6364

@@ -111,6 +112,33 @@ def gen_module(module_name):
111112
code = code.replace(
112113
'from tensorflow.python.platform import tf_logging',
113114
'from absl import logging')
115+
code = code.replace(
116+
'from tensorflow.python.framework import '
117+
'composite_tensor',
118+
'from tensorflow_probability.python.internal.backend.numpy '
119+
'import composite_tensor')
120+
code = code.replace(
121+
'from tensorflow.python.ops import '
122+
'resource_variable_ops',
123+
'from tensorflow_probability.python.internal.backend.numpy '
124+
'import resource_variable_ops')
125+
code = code.replace(
126+
'from tensorflow.python.framework import tensor_spec',
127+
'from tensorflow_probability.python.internal.backend.numpy import '
128+
'tensor_spec')
129+
code = code.replace(
130+
'from tensorflow.python.framework import type_spec',
131+
'from tensorflow_probability.python.internal.backend.numpy '
132+
'import type_spec')
133+
code = code.replace(
134+
'from tensorflow.python.ops import variables',
135+
'from tensorflow_probability.python.internal.backend.numpy '
136+
'import variables')
137+
code = code.replace(
138+
'from tensorflow.python.training.tracking '
139+
'import data_structures',
140+
'from tensorflow_probability.python.internal.backend.numpy '
141+
'import data_structures')
114142
code = re.sub(
115143
r'from tensorflow\.python\.linalg import (\w+)',
116144
'from tensorflow_probability.python.internal.backend.numpy.gen import \\1 '

tensorflow_probability/python/internal/backend/numpy/BUILD

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ py_library(
3333
":composite_tensor",
3434
":config",
3535
":control_flow",
36+
":data_structures",
3637
":debugging",
3738
":deprecation",
3839
":dtype",
@@ -57,9 +58,11 @@ py_library(
5758
":sparse_lib",
5859
":static_rewrites",
5960
":tensor_array_ops",
61+
":tensor_spec",
6062
":test_lib",
6163
":tf_inspect",
6264
":type_spec",
65+
":variables",
6366
],
6467
)
6568

@@ -110,6 +113,12 @@ py_library(
110113
],
111114
)
112115

116+
py_library(
117+
name = "data_structures",
118+
srcs = ["data_structures.py"],
119+
deps = [],
120+
)
121+
113122
py_library(
114123
name = "debugging",
115124
srcs = ["debugging.py"],
@@ -389,11 +398,24 @@ py_library(
389398
deps = [],
390399
)
391400

401+
py_library(
402+
name = "tensor_spec",
403+
srcs = ["tensor_spec.py"],
404+
)
405+
392406
py_library(
393407
name = "type_spec",
394408
srcs = ["type_spec.py"],
395409
)
396410

411+
py_library(
412+
name = "variables",
413+
srcs = ["variables.py"],
414+
deps = [
415+
":ops",
416+
],
417+
)
418+
397419
py_library(
398420
name = "numpy_testlib",
399421
testonly = 1,

tensorflow_probability/python/internal/backend/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from tensorflow_probability.python.internal.backend.numpy.numpy_array import * # pylint: disable=wildcard-import
4747
from tensorflow_probability.python.internal.backend.numpy.numpy_math import * # pylint: disable=wildcard-import
4848
from tensorflow_probability.python.internal.backend.numpy.ops import * # pylint: disable=wildcard-import
49+
from tensorflow_probability.python.internal.backend.numpy.tensor_spec import TensorSpec
4950
from tensorflow_probability.python.internal.backend.numpy.type_spec import BatchableTypeSpec
5051
from tensorflow_probability.python.internal.backend.numpy.type_spec import TypeSpec
5152

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright 2021 The TensorFlow Probability Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
"""Numpy stub for `data_structures`."""
16+
17+
18+
__all__ = ['TrackableDataStructure']
19+
20+
21+
class TrackableDataStructure(object):
22+
pass

tensorflow_probability/python/internal/backend/numpy/gen/linear_operator.py

Lines changed: 203 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,21 +43,28 @@
4343
import numpy as np
4444
import six
4545

46+
from tensorflow_probability.python.internal.backend.numpy import composite_tensor
4647
from tensorflow_probability.python.internal.backend.numpy import dtype as dtypes
4748
from tensorflow_probability.python.internal.backend.numpy import ops
4849
from tensorflow_probability.python.internal.backend.numpy.gen import tensor_shape
50+
from tensorflow_probability.python.internal.backend.numpy import tensor_spec
4951
# from tensorflow.python.framework import tensor_util
52+
from tensorflow_probability.python.internal.backend.numpy import type_spec
5053
from tensorflow_probability.python.internal.backend.numpy import ops as module
5154
from tensorflow_probability.python.internal.backend.numpy import numpy_array as array_ops
5255
from tensorflow_probability.python.internal.backend.numpy import debugging as check_ops
5356
from tensorflow_probability.python.internal.backend.numpy import linalg_impl as linalg_ops
5457
from tensorflow_probability.python.internal.backend.numpy import numpy_math as math_ops
58+
from tensorflow_probability.python.internal.backend.numpy import resource_variable_ops
59+
from tensorflow_probability.python.internal.backend.numpy import variables
5560
from tensorflow_probability.python.internal.backend.numpy import linalg_impl as linalg
5661
from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_algebra
5762
from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_util
5863
from absl import logging as logging
64+
from tensorflow_probability.python.internal.backend.numpy import data_structures
5965
from tensorflow_probability.python.internal.backend.numpy import deprecation
6066
# from tensorflow_probability.python.internal.backend.numpy import dispatch
67+
from tensorflow_probability.python.internal.backend.numpy import nest
6168
# from tensorflow.python.util.tf_export import tf_export
6269

6370
__all__ = ["LinearOperator"]
@@ -66,7 +73,7 @@
6673
# TODO(langmore) Use matrix_solve_ls for singular or non-square matrices.
6774
# @tf_export("linalg.LinearOperator")
6875
@six.add_metaclass(abc.ABCMeta)
69-
class LinearOperator(module.Module):
76+
class LinearOperator(module.Module, composite_tensor.CompositeTensor):
7077
"""Base class defining a [batch of] linear operator[s].
7178
7279
Subclasses of `LinearOperator` provide access to common methods on a
@@ -1176,6 +1183,201 @@ def _set_graph_parents(self, graph_parents):
11761183
raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
11771184
self._graph_parents = graph_parents
11781185

1186+
@property
1187+
def _composite_tensor_fields(self):
1188+
"""A tuple of parameter names to rebuild the `LinearOperator`.
1189+
1190+
The tuple contains the names of kwargs to the `LinearOperator`'s constructor
1191+
that the `TypeSpec` needs to rebuild the `LinearOperator` instance.
1192+
1193+
"is_non_singular", "is_self_adjoint", "is_positive_definite", and
1194+
"is_square" are common to all `LinearOperator` subclasses and may be
1195+
omitted.
1196+
"""
1197+
return ()
1198+
1199+
@property
1200+
def _composite_tensor_prefer_static_fields(self):
1201+
"""A tuple of names referring to parameters that may be treated statically.
1202+
1203+
This is a subset of `_composite_tensor_fields`, and contains the names of
1204+
of `Tensor`-like args to the `LinearOperator`s constructor that may be
1205+
stored as static values, if they are statically known. These are typically
1206+
shapes or axis values.
1207+
"""
1208+
return ()
1209+
1210+
@property
1211+
def _type_spec(self):
1212+
# This property will be overwritten by the `@make_composite_tensor`
1213+
# decorator. However, we need it so that a valid subclass of the `ABCMeta`
1214+
# class `CompositeTensor` can be constructed and passed to the
1215+
# `@make_composite_tensor` decorator.
1216+
pass
1217+
1218+
1219+
class _LinearOperatorSpec(type_spec.TypeSpec):
1220+
"""A tf.TypeSpec for `LinearOperator` objects."""
1221+
1222+
__slots__ = ("_param_specs", "_non_tensor_params", "_prefer_static_fields")
1223+
1224+
def __init__(self, param_specs, non_tensor_params, prefer_static_fields):
1225+
"""Initializes a new `_LinearOperatorSpec`.
1226+
1227+
Args:
1228+
param_specs: Python `dict` of `tf.TypeSpec` instances that describe
1229+
kwargs to the `LinearOperator`'s constructor that are `Tensor`-like or
1230+
`CompositeTensor` subclasses.
1231+
non_tensor_params: Python `dict` containing non-`Tensor` and non-
1232+
`CompositeTensor` kwargs to the `LinearOperator`'s constructor.
1233+
prefer_static_fields: Python `tuple` of strings corresponding to the names
1234+
of `Tensor`-like args to the `LinearOperator`s constructor that may be
1235+
stored as static values, if known. These are typically shapes, indices,
1236+
or axis values.
1237+
"""
1238+
self._param_specs = param_specs
1239+
self._non_tensor_params = non_tensor_params
1240+
self._prefer_static_fields = prefer_static_fields
1241+
1242+
@classmethod
1243+
def from_operator(cls, operator):
1244+
"""Builds a `_LinearOperatorSpec` from a `LinearOperator` instance.
1245+
1246+
Args:
1247+
operator: An instance of `LinearOperator`.
1248+
1249+
Returns:
1250+
linear_operator_spec: An instance of `_LinearOperatorSpec` to be used as
1251+
the `TypeSpec` of `operator`.
1252+
"""
1253+
validation_fields = ("is_non_singular", "is_self_adjoint",
1254+
"is_positive_definite", "is_square")
1255+
kwargs = _extract_attrs(
1256+
operator,
1257+
keys=set(operator._composite_tensor_fields + validation_fields)) # pylint: disable=protected-access
1258+
1259+
non_tensor_params = {}
1260+
param_specs = {}
1261+
for k, v in list(kwargs.items()):
1262+
type_spec_or_v = _extract_type_spec_recursively(v)
1263+
is_tensor = [isinstance(x, type_spec.TypeSpec)
1264+
for x in nest.flatten(type_spec_or_v)]
1265+
if all(is_tensor):
1266+
param_specs[k] = type_spec_or_v
1267+
elif not any(is_tensor):
1268+
non_tensor_params[k] = v
1269+
else:
1270+
raise NotImplementedError(f"Field {k} contains a mix of `Tensor` and "
1271+
f" non-`Tensor` values.")
1272+
1273+
return cls(
1274+
param_specs=param_specs,
1275+
non_tensor_params=non_tensor_params,
1276+
prefer_static_fields=operator._composite_tensor_prefer_static_fields) # pylint: disable=protected-access
1277+
1278+
def _to_components(self, obj):
1279+
return _extract_attrs(obj, keys=list(self._param_specs))
1280+
1281+
def _from_components(self, components):
1282+
kwargs = dict(self._non_tensor_params, **components)
1283+
return self.value_type(**kwargs)
1284+
1285+
@property
1286+
def _component_specs(self):
1287+
return self._param_specs
1288+
1289+
def _serialize(self):
1290+
return (self._param_specs,
1291+
self._non_tensor_params,
1292+
self._prefer_static_fields)
1293+
1294+
1295+
def make_composite_tensor(cls, module_name="tf.linalg"):
1296+
"""Class decorator to convert `LinearOperator`s to `CompositeTensor`."""
1297+
1298+
spec_name = "{}Spec".format(cls.__name__)
1299+
spec_type = type(spec_name, (_LinearOperatorSpec,), {"value_type": cls})
1300+
type_spec.register("{}.{}".format(module_name, spec_name))(spec_type)
1301+
cls._type_spec = property(spec_type.from_operator) # pylint: disable=protected-access
1302+
return cls
1303+
1304+
1305+
def _extract_attrs(op, keys):
1306+
"""Extract constructor kwargs to reconstruct `op`.
1307+
1308+
Args:
1309+
op: A `LinearOperator` instance.
1310+
keys: A Python `tuple` of strings indicating the names of the constructor
1311+
kwargs to extract from `op`.
1312+
1313+
Returns:
1314+
kwargs: A Python `dict` of kwargs to `op`'s constructor, keyed by `keys`.
1315+
"""
1316+
1317+
kwargs = {}
1318+
not_found = object()
1319+
for k in keys:
1320+
srcs = [
1321+
getattr(op, k, not_found), getattr(op, "_" + k, not_found),
1322+
getattr(op, "parameters", {}).get(k, not_found),
1323+
]
1324+
if any(v is not not_found for v in srcs):
1325+
kwargs[k] = [v for v in srcs if v is not not_found][0]
1326+
else:
1327+
raise ValueError(
1328+
f"Could not determine an appropriate value for field `{k}` in object "
1329+
f" `{op}`. Looked for \n"
1330+
f" 1. an attr called `{k}`,\n"
1331+
f" 2. an attr called `_{k}`,\n"
1332+
f" 3. an entry in `op.parameters` with key '{k}'.")
1333+
if k in op._composite_tensor_prefer_static_fields and kwargs[k] is not None: # pylint: disable=protected-access
1334+
if ops.is_tensor(kwargs[k]):
1335+
static_val = (kwargs[k])
1336+
if static_val is not None:
1337+
kwargs[k] = static_val
1338+
if isinstance(kwargs[k], (np.ndarray, np.generic)):
1339+
kwargs[k] = kwargs[k].tolist()
1340+
return kwargs
1341+
1342+
1343+
def _extract_type_spec_recursively(value):
1344+
"""Return (collection of) `TypeSpec`(s) for `value` if it includes `Tensor`s.
1345+
1346+
If `value` is a `Tensor` or `CompositeTensor`, return its `TypeSpec`. If
1347+
`value` is a collection containing `Tensor` values, recursively supplant them
1348+
with their respective `TypeSpec`s in a collection of parallel stucture.
1349+
1350+
If `value` is none of the above, return it unchanged.
1351+
1352+
Args:
1353+
value: a Python `object` to (possibly) turn into a (collection of)
1354+
`tf.TypeSpec`(s).
1355+
1356+
Returns:
1357+
spec: the `TypeSpec` or collection of `TypeSpec`s corresponding to `value`
1358+
or `value`, if no `Tensor`s are found.
1359+
"""
1360+
if isinstance(value, composite_tensor.CompositeTensor):
1361+
return value._type_spec # pylint: disable=protected-access
1362+
if isinstance(value, variables.Variable):
1363+
return resource_variable_ops.VariableSpec(
1364+
tensor_shape.TensorShape(value.shape), dtype=value.dtype, trainable=value.trainable)
1365+
if ops.is_tensor(value):
1366+
return tensor_spec.TensorSpec(tensor_shape.TensorShape(value.shape), value.dtype)
1367+
# Unwrap trackable data structures to comply with `Type_Spec._serialize`
1368+
# requirements. `ListWrapper`s are converted to `list`s, and for other
1369+
# trackable data structures, the `__wrapped__` attribute is used.
1370+
if isinstance(value, list):
1371+
return list(_extract_type_spec_recursively(v) for v in value)
1372+
if isinstance(value, data_structures.TrackableDataStructure):
1373+
return _extract_type_spec_recursively(value.__wrapped__)
1374+
if isinstance(value, tuple):
1375+
return type(value)(_extract_type_spec_recursively(x) for x in value)
1376+
if isinstance(value, dict):
1377+
return type(value)((k, _extract_type_spec_recursively(v))
1378+
for k, v in value.items())
1379+
return value
1380+
11791381

11801382
# Overrides for tf.linalg functions. This allows a LinearOperator to be used in
11811383
# place of a Tensor.

0 commit comments

Comments
 (0)