Skip to content

Commit 69993f2

Browse files
committed
Implement Kronecker product
1 parent 6e19f04 commit 69993f2

File tree

3 files changed

+92
-0
lines changed

3 files changed

+92
-0
lines changed

src/linalg/impl_linalg.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
// option. This file may not be copied, modified, or distributed
77
// except according to those terms.
88

9+
use crate::OwnedRepr;
910
use crate::imp_prelude::*;
1011
use crate::numeric_util;
1112
#[cfg(feature = "blas")]
@@ -14,6 +15,7 @@ use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr;
1415
use crate::{LinalgScalar, Zip};
1516

1617
use std::any::TypeId;
18+
use std::mem::MaybeUninit;
1719
use alloc::vec::Vec;
1820

1921
#[cfg(feature = "blas")]
@@ -699,6 +701,32 @@ unsafe fn general_mat_vec_mul_impl<A, S1, S2>(
699701
}
700702
}
701703

704+
705+
/// Kronecker product of 2D matrices.
706+
///
707+
/// The kronecker product of a LxN matrix A and a MxR matrix B is a (L*M)x(N*R)
708+
/// matrix K formed by the block multiplication A_ij * B.
709+
pub fn kron<'a, A, S1, S2>(a: &ArrayBase<S1, Ix2>, b: &'a ArrayBase<S2, Ix2>) -> ArrayBase<OwnedRepr<A>, Ix2>
710+
where
711+
S1: Data<Elem = A>,
712+
S2: Data<Elem = A>,
713+
A: LinalgScalar,
714+
A: std::ops::Mul<&'a ArrayBase<S2, Ix2>, Output = ArrayBase<OwnedRepr<A>, Ix2>>,
715+
{
716+
let dimar = a.shape()[0];
717+
let dimac = a.shape()[1];
718+
let dimbr = b.shape()[0];
719+
let dimbc = b.shape()[1];
720+
let mut out: Array2<MaybeUninit<A>> = Array2::uninit((dimar * dimbr, dimac * dimbc));
721+
Zip::from(out.exact_chunks_mut((dimbr, dimbc)))
722+
.and(a)
723+
.for_each(|out, a| {
724+
(*a * b).assign_to(out);
725+
});
726+
unsafe { out.assume_init() }
727+
}
728+
729+
702730
#[inline(always)]
703731
/// Return `true` if `A` and `B` are the same type
704732
fn same_type<A: 'static, B: 'static>() -> bool {

src/linalg/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@
1111
pub use self::impl_linalg::general_mat_mul;
1212
pub use self::impl_linalg::general_mat_vec_mul;
1313
pub use self::impl_linalg::Dot;
14+
pub use self::impl_linalg::kron;
1415

1516
mod impl_linalg;

tests/oper.rs

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
)]
77
#![cfg(feature = "std")]
88
use ndarray::linalg::general_mat_mul;
9+
use ndarray::linalg::kron;
910
use ndarray::prelude::*;
1011
use ndarray::{rcarr1, rcarr2};
1112
use ndarray::{Data, LinalgScalar};
@@ -820,3 +821,65 @@ fn vec_mat_mul() {
820821
}
821822
}
822823
}
824+
825+
#[test]
826+
fn kron_square_f64() {
827+
let a = arr2(&[[1.0, 0.0], [0.0, 1.0]]);
828+
let b = arr2(&[[0.0, 1.0], [1.0, 0.0]]);
829+
830+
assert_eq!(
831+
kron(&a, &b),
832+
arr2(&[
833+
[0.0, 1.0, 0.0, 0.0],
834+
[1.0, 0.0, 0.0, 0.0],
835+
[0.0, 0.0, 0.0, 1.0],
836+
[0.0, 0.0, 1.0, 0.0]
837+
]),
838+
);
839+
840+
assert_eq!(
841+
kron(&b, &a),
842+
arr2(&[
843+
[0.0, 0.0, 1.0, 0.0],
844+
[0.0, 0.0, 0.0, 1.0],
845+
[1.0, 0.0, 0.0, 0.0],
846+
[0.0, 1.0, 0.0, 0.0]
847+
]),
848+
)
849+
}
850+
851+
#[test]
852+
fn kron_square_i64() {
853+
let a = arr2(&[[1, 0], [0, 1]]);
854+
let b = arr2(&[[0, 1], [1, 0]]);
855+
856+
assert_eq!(
857+
kron(&a, &b),
858+
arr2(&[[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]]),
859+
);
860+
861+
assert_eq!(
862+
kron(&b, &a),
863+
arr2(&[[0, 0, 1, 0], [0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0]]),
864+
)
865+
}
866+
867+
#[test]
868+
fn kron_i64() {
869+
let a = arr2(&[[1, 0]]);
870+
let b = arr2(&[[0, 1], [1, 0]]);
871+
let r = arr2(&[[0, 1, 0, 0], [1, 0, 0, 0]]);
872+
assert_eq!(kron(&a, &b), r);
873+
874+
let a = arr2(&[[1, 0], [0, 0], [0, 1]]);
875+
let b = arr2(&[[0, 1], [1, 0]]);
876+
let r = arr2(&[
877+
[0, 1, 0, 0],
878+
[1, 0, 0, 0],
879+
[0, 0, 0, 0],
880+
[0, 0, 0, 0],
881+
[0, 0, 0, 1],
882+
[0, 0, 1, 0],
883+
]);
884+
assert_eq!(kron(&a, &b), r);
885+
}

0 commit comments

Comments
 (0)