|
43 | 43 | import numpy as np
|
44 | 44 | import six
|
45 | 45 |
|
| 46 | +from tensorflow_probability.python.internal.backend.numpy import composite_tensor |
46 | 47 | from tensorflow_probability.python.internal.backend.numpy import dtype as dtypes
|
47 | 48 | from tensorflow_probability.python.internal.backend.numpy import ops
|
48 | 49 | from tensorflow_probability.python.internal.backend.numpy.gen import tensor_shape
|
| 50 | +from tensorflow_probability.python.internal.backend.numpy import tensor_spec |
49 | 51 | # from tensorflow.python.framework import tensor_util
|
| 52 | +from tensorflow_probability.python.internal.backend.numpy import type_spec |
50 | 53 | from tensorflow_probability.python.internal.backend.numpy import ops as module
|
51 | 54 | from tensorflow_probability.python.internal.backend.numpy import numpy_array as array_ops
|
52 | 55 | from tensorflow_probability.python.internal.backend.numpy import debugging as check_ops
|
53 | 56 | from tensorflow_probability.python.internal.backend.numpy import linalg_impl as linalg_ops
|
54 | 57 | from tensorflow_probability.python.internal.backend.numpy import numpy_math as math_ops
|
| 58 | +from tensorflow_probability.python.internal.backend.numpy import resource_variable_ops |
| 59 | +from tensorflow_probability.python.internal.backend.numpy import variables |
55 | 60 | from tensorflow_probability.python.internal.backend.numpy import linalg_impl as linalg
|
56 | 61 | from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_algebra
|
57 | 62 | from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_util
|
58 | 63 | from absl import logging as logging
|
| 64 | +from tensorflow_probability.python.internal.backend.numpy import data_structures |
59 | 65 | from tensorflow_probability.python.internal.backend.numpy import deprecation
|
60 | 66 | # from tensorflow_probability.python.internal.backend.numpy import dispatch
|
| 67 | +from tensorflow_probability.python.internal.backend.numpy import nest |
61 | 68 | # from tensorflow.python.util.tf_export import tf_export
|
62 | 69 |
|
63 | 70 | __all__ = ["LinearOperator"]
|
|
66 | 73 | # TODO(langmore) Use matrix_solve_ls for singular or non-square matrices.
|
67 | 74 | # @tf_export("linalg.LinearOperator")
|
68 | 75 | @six.add_metaclass(abc.ABCMeta)
|
69 |
| -class LinearOperator(module.Module): |
| 76 | +class LinearOperator(module.Module, composite_tensor.CompositeTensor): |
70 | 77 | """Base class defining a [batch of] linear operator[s].
|
71 | 78 |
|
72 | 79 | Subclasses of `LinearOperator` provide access to common methods on a
|
@@ -1176,6 +1183,201 @@ def _set_graph_parents(self, graph_parents):
|
1176 | 1183 | raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
|
1177 | 1184 | self._graph_parents = graph_parents
|
1178 | 1185 |
|
| 1186 | + @property |
| 1187 | + def _composite_tensor_fields(self): |
| 1188 | + """A tuple of parameter names to rebuild the `LinearOperator`. |
| 1189 | +
|
| 1190 | + The tuple contains the names of kwargs to the `LinearOperator`'s constructor |
| 1191 | + that the `TypeSpec` needs to rebuild the `LinearOperator` instance. |
| 1192 | +
|
| 1193 | + "is_non_singular", "is_self_adjoint", "is_positive_definite", and |
| 1194 | + "is_square" are common to all `LinearOperator` subclasses and may be |
| 1195 | + omitted. |
| 1196 | + """ |
| 1197 | + return () |
| 1198 | + |
| 1199 | + @property |
| 1200 | + def _composite_tensor_prefer_static_fields(self): |
| 1201 | + """A tuple of names referring to parameters that may be treated statically. |
| 1202 | +
|
| 1203 | + This is a subset of `_composite_tensor_fields`, and contains the names of |
| 1204 | + of `Tensor`-like args to the `LinearOperator`s constructor that may be |
| 1205 | + stored as static values, if they are statically known. These are typically |
| 1206 | + shapes or axis values. |
| 1207 | + """ |
| 1208 | + return () |
| 1209 | + |
| 1210 | + @property |
| 1211 | + def _type_spec(self): |
| 1212 | + # This property will be overwritten by the `@make_composite_tensor` |
| 1213 | + # decorator. However, we need it so that a valid subclass of the `ABCMeta` |
| 1214 | + # class `CompositeTensor` can be constructed and passed to the |
| 1215 | + # `@make_composite_tensor` decorator. |
| 1216 | + pass |
| 1217 | + |
| 1218 | + |
| 1219 | +class _LinearOperatorSpec(type_spec.TypeSpec): |
| 1220 | + """A tf.TypeSpec for `LinearOperator` objects.""" |
| 1221 | + |
| 1222 | + __slots__ = ("_param_specs", "_non_tensor_params", "_prefer_static_fields") |
| 1223 | + |
| 1224 | + def __init__(self, param_specs, non_tensor_params, prefer_static_fields): |
| 1225 | + """Initializes a new `_LinearOperatorSpec`. |
| 1226 | +
|
| 1227 | + Args: |
| 1228 | + param_specs: Python `dict` of `tf.TypeSpec` instances that describe |
| 1229 | + kwargs to the `LinearOperator`'s constructor that are `Tensor`-like or |
| 1230 | + `CompositeTensor` subclasses. |
| 1231 | + non_tensor_params: Python `dict` containing non-`Tensor` and non- |
| 1232 | + `CompositeTensor` kwargs to the `LinearOperator`'s constructor. |
| 1233 | + prefer_static_fields: Python `tuple` of strings corresponding to the names |
| 1234 | + of `Tensor`-like args to the `LinearOperator`s constructor that may be |
| 1235 | + stored as static values, if known. These are typically shapes, indices, |
| 1236 | + or axis values. |
| 1237 | + """ |
| 1238 | + self._param_specs = param_specs |
| 1239 | + self._non_tensor_params = non_tensor_params |
| 1240 | + self._prefer_static_fields = prefer_static_fields |
| 1241 | + |
| 1242 | + @classmethod |
| 1243 | + def from_operator(cls, operator): |
| 1244 | + """Builds a `_LinearOperatorSpec` from a `LinearOperator` instance. |
| 1245 | +
|
| 1246 | + Args: |
| 1247 | + operator: An instance of `LinearOperator`. |
| 1248 | +
|
| 1249 | + Returns: |
| 1250 | + linear_operator_spec: An instance of `_LinearOperatorSpec` to be used as |
| 1251 | + the `TypeSpec` of `operator`. |
| 1252 | + """ |
| 1253 | + validation_fields = ("is_non_singular", "is_self_adjoint", |
| 1254 | + "is_positive_definite", "is_square") |
| 1255 | + kwargs = _extract_attrs( |
| 1256 | + operator, |
| 1257 | + keys=set(operator._composite_tensor_fields + validation_fields)) # pylint: disable=protected-access |
| 1258 | + |
| 1259 | + non_tensor_params = {} |
| 1260 | + param_specs = {} |
| 1261 | + for k, v in list(kwargs.items()): |
| 1262 | + type_spec_or_v = _extract_type_spec_recursively(v) |
| 1263 | + is_tensor = [isinstance(x, type_spec.TypeSpec) |
| 1264 | + for x in nest.flatten(type_spec_or_v)] |
| 1265 | + if all(is_tensor): |
| 1266 | + param_specs[k] = type_spec_or_v |
| 1267 | + elif not any(is_tensor): |
| 1268 | + non_tensor_params[k] = v |
| 1269 | + else: |
| 1270 | + raise NotImplementedError(f"Field {k} contains a mix of `Tensor` and " |
| 1271 | + f" non-`Tensor` values.") |
| 1272 | + |
| 1273 | + return cls( |
| 1274 | + param_specs=param_specs, |
| 1275 | + non_tensor_params=non_tensor_params, |
| 1276 | + prefer_static_fields=operator._composite_tensor_prefer_static_fields) # pylint: disable=protected-access |
| 1277 | + |
| 1278 | + def _to_components(self, obj): |
| 1279 | + return _extract_attrs(obj, keys=list(self._param_specs)) |
| 1280 | + |
| 1281 | + def _from_components(self, components): |
| 1282 | + kwargs = dict(self._non_tensor_params, **components) |
| 1283 | + return self.value_type(**kwargs) |
| 1284 | + |
| 1285 | + @property |
| 1286 | + def _component_specs(self): |
| 1287 | + return self._param_specs |
| 1288 | + |
| 1289 | + def _serialize(self): |
| 1290 | + return (self._param_specs, |
| 1291 | + self._non_tensor_params, |
| 1292 | + self._prefer_static_fields) |
| 1293 | + |
| 1294 | + |
| 1295 | +def make_composite_tensor(cls, module_name="tf.linalg"): |
| 1296 | + """Class decorator to convert `LinearOperator`s to `CompositeTensor`.""" |
| 1297 | + |
| 1298 | + spec_name = "{}Spec".format(cls.__name__) |
| 1299 | + spec_type = type(spec_name, (_LinearOperatorSpec,), {"value_type": cls}) |
| 1300 | + type_spec.register("{}.{}".format(module_name, spec_name))(spec_type) |
| 1301 | + cls._type_spec = property(spec_type.from_operator) # pylint: disable=protected-access |
| 1302 | + return cls |
| 1303 | + |
| 1304 | + |
| 1305 | +def _extract_attrs(op, keys): |
| 1306 | + """Extract constructor kwargs to reconstruct `op`. |
| 1307 | +
|
| 1308 | + Args: |
| 1309 | + op: A `LinearOperator` instance. |
| 1310 | + keys: A Python `tuple` of strings indicating the names of the constructor |
| 1311 | + kwargs to extract from `op`. |
| 1312 | +
|
| 1313 | + Returns: |
| 1314 | + kwargs: A Python `dict` of kwargs to `op`'s constructor, keyed by `keys`. |
| 1315 | + """ |
| 1316 | + |
| 1317 | + kwargs = {} |
| 1318 | + not_found = object() |
| 1319 | + for k in keys: |
| 1320 | + srcs = [ |
| 1321 | + getattr(op, k, not_found), getattr(op, "_" + k, not_found), |
| 1322 | + getattr(op, "parameters", {}).get(k, not_found), |
| 1323 | + ] |
| 1324 | + if any(v is not not_found for v in srcs): |
| 1325 | + kwargs[k] = [v for v in srcs if v is not not_found][0] |
| 1326 | + else: |
| 1327 | + raise ValueError( |
| 1328 | + f"Could not determine an appropriate value for field `{k}` in object " |
| 1329 | + f" `{op}`. Looked for \n" |
| 1330 | + f" 1. an attr called `{k}`,\n" |
| 1331 | + f" 2. an attr called `_{k}`,\n" |
| 1332 | + f" 3. an entry in `op.parameters` with key '{k}'.") |
| 1333 | + if k in op._composite_tensor_prefer_static_fields and kwargs[k] is not None: # pylint: disable=protected-access |
| 1334 | + if ops.is_tensor(kwargs[k]): |
| 1335 | + static_val = (kwargs[k]) |
| 1336 | + if static_val is not None: |
| 1337 | + kwargs[k] = static_val |
| 1338 | + if isinstance(kwargs[k], (np.ndarray, np.generic)): |
| 1339 | + kwargs[k] = kwargs[k].tolist() |
| 1340 | + return kwargs |
| 1341 | + |
| 1342 | + |
| 1343 | +def _extract_type_spec_recursively(value): |
| 1344 | + """Return (collection of) `TypeSpec`(s) for `value` if it includes `Tensor`s. |
| 1345 | +
|
| 1346 | + If `value` is a `Tensor` or `CompositeTensor`, return its `TypeSpec`. If |
| 1347 | + `value` is a collection containing `Tensor` values, recursively supplant them |
| 1348 | + with their respective `TypeSpec`s in a collection of parallel stucture. |
| 1349 | +
|
| 1350 | + If `value` is none of the above, return it unchanged. |
| 1351 | +
|
| 1352 | + Args: |
| 1353 | + value: a Python `object` to (possibly) turn into a (collection of) |
| 1354 | + `tf.TypeSpec`(s). |
| 1355 | +
|
| 1356 | + Returns: |
| 1357 | + spec: the `TypeSpec` or collection of `TypeSpec`s corresponding to `value` |
| 1358 | + or `value`, if no `Tensor`s are found. |
| 1359 | + """ |
| 1360 | + if isinstance(value, composite_tensor.CompositeTensor): |
| 1361 | + return value._type_spec # pylint: disable=protected-access |
| 1362 | + if isinstance(value, variables.Variable): |
| 1363 | + return resource_variable_ops.VariableSpec( |
| 1364 | + tensor_shape.TensorShape(value.shape), dtype=value.dtype, trainable=value.trainable) |
| 1365 | + if ops.is_tensor(value): |
| 1366 | + return tensor_spec.TensorSpec(tensor_shape.TensorShape(value.shape), value.dtype) |
| 1367 | + # Unwrap trackable data structures to comply with `Type_Spec._serialize` |
| 1368 | + # requirements. `ListWrapper`s are converted to `list`s, and for other |
| 1369 | + # trackable data structures, the `__wrapped__` attribute is used. |
| 1370 | + if isinstance(value, list): |
| 1371 | + return list(_extract_type_spec_recursively(v) for v in value) |
| 1372 | + if isinstance(value, data_structures.TrackableDataStructure): |
| 1373 | + return _extract_type_spec_recursively(value.__wrapped__) |
| 1374 | + if isinstance(value, tuple): |
| 1375 | + return type(value)(_extract_type_spec_recursively(x) for x in value) |
| 1376 | + if isinstance(value, dict): |
| 1377 | + return type(value)((k, _extract_type_spec_recursively(v)) |
| 1378 | + for k, v in value.items()) |
| 1379 | + return value |
| 1380 | + |
1179 | 1381 |
|
1180 | 1382 | # Overrides for tf.linalg functions. This allows a LinearOperator to be used in
|
1181 | 1383 | # place of a Tensor.
|
|
0 commit comments