Skip to content

Commit c99c86e

Browse files
axchtensorflower-gardener
authored andcommitted
Update mechanically-generated numpy backend to upstream changes.
PiperOrigin-RevId: 386465712
1 parent 14bb570 commit c99c86e

File tree

3 files changed

+99
-18
lines changed

3 files changed

+99
-18
lines changed

tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_block_diag.py

Lines changed: 54 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -355,17 +355,37 @@ class docstring for definition of shape compatibility.
355355
as `self`, or if `x` is blockwise, a list of `Tensor`s with shapes that
356356
concatenate to `[..., M, R]`.
357357
"""
358+
def _check_operators_agree(r, l, message):
359+
if (r.range_dimension is not None and
360+
l.domain_dimension is not None and
361+
r.range_dimension != l.domain_dimension):
362+
raise ValueError(message)
363+
358364
if isinstance(x, linear_operator.LinearOperator):
359365
left_operator = self.adjoint() if adjoint else self
360366
right_operator = x.adjoint() if adjoint_arg else x
361367

362-
if (right_operator.range_dimension is not None and
363-
left_operator.domain_dimension is not None and
364-
right_operator.range_dimension != left_operator.domain_dimension):
365-
raise ValueError(
366-
"Operators are incompatible. Expected `x` to have dimension"
367-
" {} but got {}.".format(
368-
left_operator.domain_dimension, right_operator.range_dimension))
368+
_check_operators_agree(
369+
right_operator, left_operator,
370+
"Operators are incompatible. Expected `x` to have dimension"
371+
" {} but got {}.".format(
372+
left_operator.domain_dimension, right_operator.range_dimension))
373+
374+
# We can efficiently multiply BlockDiag LinearOperators if the number of
375+
# blocks agree.
376+
if isinstance(x, LinearOperatorBlockDiag):
377+
if len(left_operator.operators) != len(right_operator.operators):
378+
raise ValueError(
379+
"Can not efficiently multiply two `LinearOperatorBlockDiag`s "
380+
"together when number of blocks differ.")
381+
382+
for o1, o2 in zip(left_operator.operators, right_operator.operators):
383+
_check_operators_agree(
384+
o2, o1,
385+
"Blocks are incompatible. Expected `x` to have dimension"
386+
" {} but got {}.".format(
387+
o1.domain_dimension, o2.range_dimension))
388+
369389
with self._name_scope(name): # pylint: disable=not-callable
370390
return linear_operator_algebra.matmul(left_operator, right_operator)
371391

@@ -531,22 +551,38 @@ def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
531551
raise NotImplementedError(
532552
"Exact solve not implemented for an operator that is expected to "
533553
"not be square.")
534-
if not all(operator.is_square for operator in self.operators):
535-
raise NotImplementedError(
536-
"Exact solve not implemented for an operator whose blocks are not "
537-
"square.")
554+
555+
def _check_operators_agree(r, l, message):
556+
if (r.range_dimension is not None and
557+
l.domain_dimension is not None and
558+
r.range_dimension != l.domain_dimension):
559+
raise ValueError(message)
538560

539561
if isinstance(rhs, linear_operator.LinearOperator):
540562
left_operator = self.adjoint() if adjoint else self
541563
right_operator = rhs.adjoint() if adjoint_arg else rhs
542564

543-
if (right_operator.range_dimension is not None and
544-
left_operator.domain_dimension is not None and
545-
right_operator.range_dimension != left_operator.domain_dimension):
546-
raise ValueError(
547-
"Operators are incompatible. Expected `rhs` to have dimension"
548-
" {} but got {}.".format(
549-
left_operator.domain_dimension, right_operator.range_dimension))
565+
_check_operators_agree(
566+
right_operator, left_operator,
567+
"Operators are incompatible. Expected `x` to have dimension"
568+
" {} but got {}.".format(
569+
left_operator.domain_dimension, right_operator.range_dimension))
570+
571+
# We can efficiently solve BlockDiag LinearOperators if the number of
572+
# blocks agree.
573+
if isinstance(right_operator, LinearOperatorBlockDiag):
574+
if len(left_operator.operators) != len(right_operator.operators):
575+
raise ValueError(
576+
"Can not efficiently solve `LinearOperatorBlockDiag` when "
577+
"number of blocks differ.")
578+
579+
for o1, o2 in zip(left_operator.operators, right_operator.operators):
580+
_check_operators_agree(
581+
o2, o1,
582+
"Blocks are incompatible. Expected `x` to have dimension"
583+
" {} but got {}.".format(
584+
o1.domain_dimension, o2.range_dimension))
585+
550586
with self._name_scope(name): # pylint: disable=not-callable
551587
return linear_operator_algebra.solve(left_operator, right_operator)
552588

tensorflow_probability/python/internal/backend/numpy/gen/matmul_registrations.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040
from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator
4141
from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_algebra
42+
from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_block_diag
4243
from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_circulant
4344
from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_composition
4445
from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_diag
@@ -236,6 +237,27 @@ def _matmul_linear_operator_circulant_circulant(linop_a, linop_b):
236237
linop_a, linop_b)),
237238
is_square=True)
238239

240+
# Block Diag
241+
242+
243+
@linear_operator_algebra.RegisterMatmul(
244+
linear_operator_block_diag.LinearOperatorBlockDiag,
245+
linear_operator_block_diag.LinearOperatorBlockDiag)
246+
def _matmul_linear_operator_block_diag_block_diag(linop_a, linop_b):
247+
return linear_operator_block_diag.LinearOperatorBlockDiag(
248+
operators=[
249+
o1.matmul(o2) for o1, o2 in zip(
250+
linop_a.operators, linop_b.operators)],
251+
is_non_singular=registrations_util.combined_non_singular_hint(
252+
linop_a, linop_b),
253+
# In general, a product of self-adjoint positive-definite block diagonal
254+
# matrices is not self = self - adjoint.
255+
is_self_adjoint=None,
256+
# In general, a product of positive-definite block diagonal matrices is
257+
# not positive-definite.
258+
is_positive_definite=None,
259+
is_square=True)
260+
239261
import numpy as np
240262
from tensorflow_probability.python.internal.backend.numpy import linalg_impl as _linalg
241263
from tensorflow_probability.python.internal.backend.numpy import ops as _ops

tensorflow_probability/python/internal/backend/numpy/gen/solve_registrations.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040
from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator
4141
from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_algebra
42+
from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_block_diag
4243
from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_circulant
4344
from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_composition
4445
from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_diag
@@ -208,6 +209,28 @@ def _solve_linear_operator_circulant_circulant(linop_a, linop_b):
208209
linop_a, linop_b)),
209210
is_square=True)
210211

212+
213+
# Block Diag
214+
215+
216+
@linear_operator_algebra.RegisterSolve(
217+
linear_operator_block_diag.LinearOperatorBlockDiag,
218+
linear_operator_block_diag.LinearOperatorBlockDiag)
219+
def _solve_linear_operator_block_diag_block_diag(linop_a, linop_b):
220+
return linear_operator_block_diag.LinearOperatorBlockDiag(
221+
operators=[
222+
o1.solve(o2) for o1, o2 in zip(
223+
linop_a.operators, linop_b.operators)],
224+
is_non_singular=registrations_util.combined_non_singular_hint(
225+
linop_a, linop_b),
226+
# In general, a solve of self-adjoint positive-definite block diagonal
227+
# matrices is not self = self - adjoint.
228+
is_self_adjoint=None,
229+
# In general, a solve of positive-definite block diagonal matrices is
230+
# not positive-definite.
231+
is_positive_definite=None,
232+
is_square=True)
233+
211234
import numpy as np
212235
from tensorflow_probability.python.internal.backend.numpy import linalg_impl as _linalg
213236
from tensorflow_probability.python.internal.backend.numpy import ops as _ops

0 commit comments

Comments
 (0)