@@ -355,17 +355,37 @@ class docstring for definition of shape compatibility.
355
355
as `self`, or if `x` is blockwise, a list of `Tensor`s with shapes that
356
356
concatenate to `[..., M, R]`.
357
357
"""
358
+ def _check_operators_agree (r , l , message ):
359
+ if (r .range_dimension is not None and
360
+ l .domain_dimension is not None and
361
+ r .range_dimension != l .domain_dimension ):
362
+ raise ValueError (message )
363
+
358
364
if isinstance (x , linear_operator .LinearOperator ):
359
365
left_operator = self .adjoint () if adjoint else self
360
366
right_operator = x .adjoint () if adjoint_arg else x
361
367
362
- if (right_operator .range_dimension is not None and
363
- left_operator .domain_dimension is not None and
364
- right_operator .range_dimension != left_operator .domain_dimension ):
365
- raise ValueError (
366
- "Operators are incompatible. Expected `x` to have dimension"
367
- " {} but got {}." .format (
368
- left_operator .domain_dimension , right_operator .range_dimension ))
368
+ _check_operators_agree (
369
+ right_operator , left_operator ,
370
+ "Operators are incompatible. Expected `x` to have dimension"
371
+ " {} but got {}." .format (
372
+ left_operator .domain_dimension , right_operator .range_dimension ))
373
+
374
+ # We can efficiently multiply BlockDiag LinearOperators if the number of
375
+ # blocks agree.
376
+ if isinstance (x , LinearOperatorBlockDiag ):
377
+ if len (left_operator .operators ) != len (right_operator .operators ):
378
+ raise ValueError (
379
+ "Can not efficiently multiply two `LinearOperatorBlockDiag`s "
380
+ "together when number of blocks differ." )
381
+
382
+ for o1 , o2 in zip (left_operator .operators , right_operator .operators ):
383
+ _check_operators_agree (
384
+ o2 , o1 ,
385
+ "Blocks are incompatible. Expected `x` to have dimension"
386
+ " {} but got {}." .format (
387
+ o1 .domain_dimension , o2 .range_dimension ))
388
+
369
389
with self ._name_scope (name ): # pylint: disable=not-callable
370
390
return linear_operator_algebra .matmul (left_operator , right_operator )
371
391
@@ -531,22 +551,38 @@ def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
531
551
raise NotImplementedError (
532
552
"Exact solve not implemented for an operator that is expected to "
533
553
"not be square." )
534
- if not all (operator .is_square for operator in self .operators ):
535
- raise NotImplementedError (
536
- "Exact solve not implemented for an operator whose blocks are not "
537
- "square." )
554
+
555
+ def _check_operators_agree (r , l , message ):
556
+ if (r .range_dimension is not None and
557
+ l .domain_dimension is not None and
558
+ r .range_dimension != l .domain_dimension ):
559
+ raise ValueError (message )
538
560
539
561
if isinstance (rhs , linear_operator .LinearOperator ):
540
562
left_operator = self .adjoint () if adjoint else self
541
563
right_operator = rhs .adjoint () if adjoint_arg else rhs
542
564
543
- if (right_operator .range_dimension is not None and
544
- left_operator .domain_dimension is not None and
545
- right_operator .range_dimension != left_operator .domain_dimension ):
546
- raise ValueError (
547
- "Operators are incompatible. Expected `rhs` to have dimension"
548
- " {} but got {}." .format (
549
- left_operator .domain_dimension , right_operator .range_dimension ))
565
+ _check_operators_agree (
566
+ right_operator , left_operator ,
567
+ "Operators are incompatible. Expected `x` to have dimension"
568
+ " {} but got {}." .format (
569
+ left_operator .domain_dimension , right_operator .range_dimension ))
570
+
571
+ # We can efficiently solve BlockDiag LinearOperators if the number of
572
+ # blocks agree.
573
+ if isinstance (right_operator , LinearOperatorBlockDiag ):
574
+ if len (left_operator .operators ) != len (right_operator .operators ):
575
+ raise ValueError (
576
+ "Can not efficiently solve `LinearOperatorBlockDiag` when "
577
+ "number of blocks differ." )
578
+
579
+ for o1 , o2 in zip (left_operator .operators , right_operator .operators ):
580
+ _check_operators_agree (
581
+ o2 , o1 ,
582
+ "Blocks are incompatible. Expected `x` to have dimension"
583
+ " {} but got {}." .format (
584
+ o1 .domain_dimension , o2 .range_dimension ))
585
+
550
586
with self ._name_scope (name ): # pylint: disable=not-callable
551
587
return linear_operator_algebra .solve (left_operator , right_operator )
552
588
0 commit comments