Skip to content

Commit 27e347c

Browse files
committed
blas: Update layout logic for gemm
We compute A B -> C with matrices A, B, C. The blas (cblas) interface supports matrices that adhere to certain criteria: they should be contiguous on one dimension (stride=1). We glance a little at how numpy does this to try to catch all cases. In short, we accept A, B contiguous on either axis (row or column major). We use the case where C is (weakly) row major, but if it is column major we transpose A, B, C => A^t, B^t, C^t so that we are back to the C row major case. (Weakly = contiguous with stride=1 on that inner dimension, but the stride for the other dimension can be larger; this differentiates it from strictly whole-array contiguous.) Minor change to the gemv function: no functional change, only an update due to the refactoring of the blas layout functions. Fixes #1278
1 parent 2ca801c commit 27e347c

File tree

4 files changed

+278
-139
lines changed

4 files changed

+278
-139
lines changed

Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ rawpointer = { version = "0.2" }
4747
defmac = "0.2"
4848
quickcheck = { workspace = true }
4949
approx = { workspace = true, default-features = true }
50-
itertools = { version = "0.13.0", default-features = false, features = ["use_std"] }
50+
itertools = { workspace = true }
5151

5252
[features]
5353
default = ["std"]
@@ -73,6 +73,7 @@ matrixmultiply-threading = ["matrixmultiply/threading"]
7373

7474
portable-atomic-critical-section = ["portable-atomic/critical-section"]
7575

76+
7677
[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies]
7778
portable-atomic = { version = "1.6.0" }
7879
portable-atomic-util = { version = "0.2.0", features = [ "alloc" ] }
@@ -103,6 +104,7 @@ approx = { version = "0.5", default-features = false }
103104
quickcheck = { version = "1.0", default-features = false }
104105
rand = { version = "0.8.0", features = ["small_rng"] }
105106
rand_distr = { version = "0.4.0" }
107+
itertools = { version = "0.13.0", default-features = false, features = ["use_std"] }
106108

107109
[profile.bench]
108110
debug = true

crates/blas-tests/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ doctest = false
1212

1313
[dependencies]
1414
ndarray = { workspace = true, features = ["approx", "blas"] }
15+
ndarray-gen = { workspace = true }
1516

1617
blas-src = { version = "0.10", optional = true }
1718
openblas-src = { version = "0.10", optional = true }
@@ -23,6 +24,7 @@ defmac = "0.2"
2324
approx = { workspace = true }
2425
num-traits = { workspace = true }
2526
num-complex = { workspace = true }
27+
itertools = { workspace = true }
2628

2729
[features]
2830
# Just for making an example and to help testing, multiple different possible

crates/blas-tests/tests/oper.rs

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,13 @@ use ndarray::prelude::*;
99

1010
use ndarray::linalg::general_mat_mul;
1111
use ndarray::linalg::general_mat_vec_mul;
12+
use ndarray::Order;
1213
use ndarray::{Data, Ix, LinalgScalar};
14+
use ndarray_gen::array_builder::ArrayBuilder;
1315

1416
use approx::assert_relative_eq;
1517
use defmac::defmac;
18+
use itertools::iproduct;
1619
use num_complex::Complex32;
1720
use num_complex::Complex64;
1821

@@ -243,32 +246,56 @@ fn gen_mat_mul()
243246
let sizes = vec![
244247
(4, 4, 4),
245248
(8, 8, 8),
246-
(17, 15, 16),
249+
(10, 10, 10),
250+
(8, 8, 1),
251+
(1, 10, 10),
252+
(10, 1, 10),
253+
(10, 10, 1),
254+
(1, 10, 1),
255+
(10, 1, 1),
256+
(1, 1, 10),
247257
(4, 17, 3),
248258
(17, 3, 22),
249259
(19, 18, 2),
250260
(16, 17, 15),
251261
(15, 16, 17),
252262
(67, 63, 62),
253263
];
254-
// test different strides
255-
for &s1 in &[1, 2, -1, -2] {
256-
for &s2 in &[1, 2, -1, -2] {
257-
for &(m, k, n) in &sizes {
258-
let a = range_mat64(m, k);
259-
let b = range_mat64(k, n);
260-
let mut c = range_mat64(m, n);
264+
let strides = &[1, 2, -1, -2];
265+
let cf_order = [Order::C, Order::F];
266+
267+
// test different strides and memory orders
268+
for (&s1, &s2) in iproduct!(strides, strides) {
269+
for &(m, k, n) in &sizes {
270+
for (ord1, ord2, ord3) in iproduct!(cf_order, cf_order, cf_order) {
271+
println!("Case s1={}, s2={}, orders={:?}, {:?}, {:?}", s1, s2, ord1, ord2, ord3);
272+
let a = ArrayBuilder::new((m, k)).memory_order(ord1).build();
273+
let b = ArrayBuilder::new((k, n)).memory_order(ord2).build();
274+
let mut c = ArrayBuilder::new((m, n)).memory_order(ord3).build();
275+
261276
let mut answer = c.clone();
262277

263278
{
264-
let a = a.slice(s![..;s1, ..;s2]);
265-
let b = b.slice(s![..;s2, ..;s2]);
266-
let mut cv = c.slice_mut(s![..;s1, ..;s2]);
279+
let av;
280+
let bv;
281+
let mut cv;
282+
283+
if s1 != 1 || s2 != 1 {
284+
av = a.slice(s![..;s1, ..;s2]);
285+
bv = b.slice(s![..;s2, ..;s2]);
286+
cv = c.slice_mut(s![..;s1, ..;s2]);
287+
} else {
288+
// different stride cases for slicing versus not sliced (for axes of
289+
// len=1); so test not sliced here.
290+
av = a.view();
291+
bv = b.view();
292+
cv = c.view_mut();
293+
}
267294

268-
let answer_part = alpha * reference_mat_mul(&a, &b) + beta * &cv;
295+
let answer_part = alpha * reference_mat_mul(&av, &bv) + beta * &cv;
269296
answer.slice_mut(s![..;s1, ..;s2]).assign(&answer_part);
270297

271-
general_mat_mul(alpha, &a, &b, beta, &mut cv);
298+
general_mat_mul(alpha, &av, &bv, beta, &mut cv);
272299
}
273300
assert_relative_eq!(c, answer, epsilon = 1e-12, max_relative = 1e-7);
274301
}

0 commit comments

Comments
 (0)