Skip to content

Commit c1aeaf7

Browse files
committed
make multithread-mm a feature (dis by default)
1 parent 33645c1 commit c1aeaf7

File tree

5 files changed

+18
-2
lines changed

5 files changed

+18
-2
lines changed

cli/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,4 @@ pulse = [ "tract-pulse", "tract-pulse-opl" ]
6363
tf = [ "tract-tensorflow", "tract-libcli/hir" ]
6464
tflite = [ "tract-tflite" ]
6565
conform = [ "tract-tensorflow/conform" ]
66+
multithread-mm = [ "tract-linalg/multithread-mm" ]

cli/src/main.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,11 +706,16 @@ fn handle(matches: clap::ArgMatches, probe: Option<&Probe>) -> TractResult<()> {
706706

707707
let mut need_optimisations = false;
708708

709+
#[cfg(feature = "multithread-mm")]
709710
if let Some(threads) = matches.value_of("threads") {
710711
let threads: usize = threads.parse()?;
711712
let threads = if threads == 0 { num_cpus::get_physical() } else { threads };
712713
multithread::set_default_executor(multithread::Executor::multithread(threads));
713714
}
715+
#[cfg(not(feature = "multithread-mm"))]
716+
if let Some(_) = matches.value_of("threads") {
717+
bail!("tract is compiled without multithread support")
718+
}
714719

715720
match matches.subcommand() {
716721
Some(("bench", m)) => {

linalg/Cargo.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ lazy_static.workspace = true
2323
log.workspace = true
2424
num-traits.workspace = true
2525
paste.workspace = true
26-
rayon.workspace = true
26+
rayon = { workspace = true, optional = true }
2727
scan_fmt.workspace = true
2828
tract-data.workspace = true
2929

@@ -61,7 +61,8 @@ proptest = { version = "1.0.0", default-features = false, features = ["std", "bi
6161
# preferred.
6262
no_fp16 = []
6363
apple-amx-ios = []
64-
default = []
64+
default = [ ]
65+
multithread-mm = [ "rayon" ]
6566
complex = [ "tract-data/complex" ]
6667

6768
[[bench]]

linalg/src/frame/mmm/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ mod storage;
1616
pub mod tests;
1717

1818
use crate::multithread::Executor;
19+
#[cfg(feature = "multithread-mm")]
1920
use rayon::prelude::*;
2021
use std::borrow::Cow;
2122
use std::cmp::Ordering;
@@ -223,6 +224,7 @@ unsafe fn run_with_scratch_space_vec<K: MatMatMulKer>(
223224
}
224225
Ok(())
225226
}
227+
#[cfg(feature = "multithread-mm")]
226228
Executor::MultiThread(pool) => pool.install(|| {
227229
(0..m.div_ceil(ker.mr()))
228230
.into_par_iter()
@@ -247,6 +249,7 @@ unsafe fn run_with_scratch_space_col_outer<K: MatMatMulKer>(
247249
}
248250
Ok(())
249251
}
252+
#[cfg(feature = "multithread-mm")]
250253
Executor::MultiThread(pool) => pool.install(|| {
251254
(0..n.div_ceil(ker.nr())).into_par_iter().try_for_each(|ib| {
252255
for ia in 0..m.divceil(ker.mr()) {
@@ -274,6 +277,7 @@ unsafe fn run_with_scratch_space_row_outer<K: MatMatMulKer>(
274277
}
275278
Ok(())
276279
}
280+
#[cfg(feature = "multithread-mm")]
277281
Executor::MultiThread(pool) => pool.install(|| {
278282
pool.install(|| {
279283
(0..m.div_ceil(ker.mr())).into_par_iter().try_for_each(|ia| {

linalg/src/multithread.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,25 @@
11
use std::cell::RefCell;
2+
#[allow(unused_imports)]
23
use std::sync::{Arc, Mutex};
34

5+
#[cfg(feature = "multithread-mm")]
46
use rayon::{ThreadPool, ThreadPoolBuilder};
57

68
#[derive(Debug, Clone, Default)]
79
pub enum Executor {
810
#[default]
911
SingleThread,
12+
#[cfg(feature = "multithread-mm")]
1013
MultiThread(Arc<ThreadPool>),
1114
}
1215

1316
impl Executor {
17+
#[cfg(feature = "multithread-mm")]
1418
pub fn multithread(n: usize) -> Executor {
1519
Executor::multithread_with_name(n, "tract-default")
1620
}
1721

22+
#[cfg(feature = "multithread-mm")]
1823
pub fn multithread_with_name(n: usize, name: &str) -> Executor {
1924
let name = name.to_string();
2025
let pool = ThreadPoolBuilder::new()

0 commit comments

Comments
 (0)