Skip to content

Commit bea4f5c

Browse files
committed
refactor: make matrix multiply to be concurrency with threads
1 parent d44c2c1 commit bea4f5c

File tree

6 files changed

+200
-17
lines changed

6 files changed

+200
-17
lines changed

Cargo.lock

Lines changed: 15 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
[package]
2-
name = "template"
2+
name = "concurrency"
33
version = "0.1.0"
44
edition = "2021"
55

66
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
77

88
[dependencies]
99
anyhow = "1.0.86"
10+
oneshot = "0.1.8"
1011
rand = "0.8.5"

examples/matrix.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
use anyhow::{Ok, Result};
2+
use concurrency::Matrix;
23

34
fn main() -> Result<()> {
4-
println!("f64 default: {}", f64::default());
5+
let a = Matrix::new([1, 2, 3, 4, 5, 6], 2, 3);
6+
let b = Matrix::new([1, 2, 3, 4, 5, 6], 3, 2);
7+
println!("a * b = {}", a * b);
58
Ok(())
69
}

src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
mod matrix;
2-
2+
mod vector;
33
pub use matrix::{multiply, Matrix};
4+
pub use vector::*;

src/matrix.rs

Lines changed: 126 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
1-
use anyhow::{anyhow, Result};
21
use std::{
32
fmt,
43
ops::{Add, AddAssign, Mul},
4+
sync::mpsc,
5+
thread,
56
};
67

8+
use anyhow::{anyhow, Result};
9+
10+
use crate::{dot_product, Vector};
11+
12+
const NUM_THREADS: usize = 4;
13+
714
// 声明一个矩阵的结构
815

916
pub struct Matrix<T> {
@@ -12,30 +19,104 @@ pub struct Matrix<T> {
1219
col: usize,
1320
}
1421

22+
pub struct MsgInput<T> {
23+
idx: usize,
24+
row: Vector<T>,
25+
col: Vector<T>,
26+
}
27+
28+
pub struct MsgOutput<T> {
29+
idx: usize,
30+
value: T,
31+
}
32+
33+
pub struct Msg<T> {
34+
input: MsgInput<T>,
35+
// sender to send the result back
36+
sender: oneshot::Sender<MsgOutput<T>>,
37+
}
38+
1539
pub fn multiply<T>(a: &Matrix<T>, b: &Matrix<T>) -> Result<Matrix<T>>
1640
where
17-
T: Copy + Default + Add<Output = T> + AddAssign + Mul<Output = T>,
41+
T: Copy + Default + Add<Output = T> + AddAssign + Mul<Output = T> + Send + 'static,
1842
{
1943
if a.col != b.row {
2044
return Err(anyhow!("Matrix multiply error: a.col != b.row"));
2145
}
2246

47+
let senders = (0..NUM_THREADS)
48+
.map(|_| {
49+
let (tx, rx) = mpsc::channel::<Msg<T>>();
50+
// 线程下执行的内容
51+
thread::spawn(move || {
52+
for msg in rx {
53+
let value = dot_product(msg.input.row, msg.input.col)?;
54+
if let Err(e) = msg.sender.send(MsgOutput {
55+
idx: msg.input.idx,
56+
value,
57+
}) {
58+
eprintln!("send error: {:?}", e);
59+
}
60+
}
61+
// 这里我犯了一个错误,加了分号导致找了很长时间的 bug
62+
Ok::<_, anyhow::Error>(())
63+
});
64+
tx
65+
})
66+
.collect::<Vec<_>>();
67+
68+
// generate 4 threads which receive msg and do dot product
69+
2370
// 让数据有缺省值
24-
let mut data = vec![T::default(); a.row * b.col];
71+
let matrix_len = a.row * b.col;
72+
// 初始化一个空的矩阵
73+
let mut data = vec![T::default(); matrix_len];
74+
let mut receivers = Vec::with_capacity(matrix_len);
75+
// map reduce map phase
2576
for i in 0..a.row {
2677
for j in 0..b.col {
27-
for k in 0..a.col {
28-
data[i * b.col + j] += a.data[i * a.col + k] * b.data[k * b.col + j]
78+
let row = Vector::new(&a.data[i * a.col..(i + 1) * a.col]);
79+
let col_data = b.data[j..]
80+
.iter()
81+
.step_by(b.col)
82+
.copied()
83+
.collect::<Vec<_>>();
84+
let col = Vector::new(col_data);
85+
let idx = i * b.col + j;
86+
let input = MsgInput::new(idx, row, col);
87+
let (tx, rx) = oneshot::channel();
88+
let msg = Msg::new(input, tx);
89+
if let Err(e) = senders[idx % NUM_THREADS].send(msg) {
90+
eprintln!("send error: {:?}", e);
2991
}
92+
receivers.push(rx);
3093
}
3194
}
95+
96+
// map reduce: reduce phase
97+
for rx in receivers {
98+
let output = rx.recv()?;
99+
data[output.idx] = output.value;
100+
}
32101
Ok(Matrix {
33102
data,
34103
row: a.row,
35104
col: b.col,
36105
})
37106
}
38107

108+
impl<T> MsgInput<T> {
109+
pub fn new(idx: usize, row: Vector<T>, col: Vector<T>) -> Self {
110+
Self { idx, row, col }
111+
}
112+
}
113+
114+
impl<T> Msg<T> {
115+
pub fn new(input: MsgInput<T>, sender: oneshot::Sender<MsgOutput<T>>) -> Self {
116+
Self { input, sender }
117+
}
118+
}
119+
39120
impl<T: fmt::Debug> Matrix<T> {
40121
pub fn new(data: impl Into<Vec<T>>, row: usize, col: usize) -> Self {
41122
Self {
@@ -79,20 +160,59 @@ where
79160
}
80161
}
81162

163+
impl<T> Mul for Matrix<T>
164+
where
165+
T: Copy + Default + Add<Output = T> + AddAssign + Mul<Output = T> + Send + 'static,
166+
{
167+
type Output = Self;
168+
169+
fn mul(self, rhs: Self) -> Self::Output {
170+
multiply(&self, &rhs).expect("Matrix multiply error")
171+
}
172+
}
173+
82174
#[cfg(test)]
83175
mod tests {
176+
use std::vec;
177+
84178
use super::*;
85179

86180
#[test]
87181
fn test_matrix_multiply() -> Result<()> {
88182
let a = Matrix::new([1, 2, 3, 4, 5, 6], 2, 3);
89183
let b = Matrix::new([1, 2, 3, 4, 5, 6], 3, 2);
90-
let c = multiply(&a, &b)?;
184+
let c = a * b;
91185
assert_eq!(c.col, 2);
92186
assert_eq!(c.row, 2);
93187
assert_eq!(c.data, vec![22, 28, 49, 64]);
94188
assert_eq!(format!("{:?}", c), "Matrix(row=2, col=2, {22 28, 49 64})");
95189

96190
Ok(())
97191
}
192+
193+
#[test]
194+
fn test_matrix_display() -> Result<()> {
195+
let a = Matrix::new([1, 2, 3, 4], 2, 2);
196+
let b = Matrix::new([1, 2, 3, 4], 2, 2);
197+
let c = a * b;
198+
assert_eq!(c.data, vec![7, 10, 15, 22]);
199+
assert_eq!(format!("{}", c), "{7 10, 15 22}");
200+
Ok(())
201+
}
202+
203+
#[test]
204+
#[should_panic]
205+
fn test_a_can_not_multiply_b_panic() {
206+
let a = Matrix::new([1, 2, 3, 4, 5, 6], 2, 3);
207+
let b = Matrix::new([1, 2, 3, 4], 2, 2);
208+
let _c = a * b;
209+
}
210+
211+
#[test]
212+
fn test_a_can_not_multiply_b() {
213+
let a = Matrix::new([1, 2, 3, 4, 5, 6], 2, 3);
214+
let b = Matrix::new([1, 2, 3, 4], 2, 2);
215+
let c = multiply(&a, &b);
216+
assert!(c.is_err());
217+
}
98218
}

src/vector.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
use anyhow::{anyhow, Result};
2+
use std::ops::{Add, AddAssign, Deref, Mul};
3+
4+
pub struct Vector<T> {
5+
data: Vec<T>,
6+
}
7+
8+
impl<T> Deref for Vector<T> {
9+
type Target = Vec<T>;
10+
11+
fn deref(&self) -> &Self::Target {
12+
&self.data
13+
}
14+
}
15+
16+
impl<T> Vector<T> {
17+
pub fn new(data: impl Into<Vec<T>>) -> Self {
18+
Self { data: data.into() }
19+
}
20+
21+
// pub fn len(&self) -> usize {
22+
// self.data.len()
23+
// }
24+
25+
// pub fn iter(&self) -> std::slice::Iter<T> {
26+
// self.data.iter()
27+
// }
28+
}
29+
30+
// impl<T> Index<usize> for Vector<T> {
31+
// type Output = T;
32+
// fn index(&self, index: usize) -> &Self::Output {
33+
// &self.data[index]
34+
// }
35+
// }
36+
37+
pub fn dot_product<T>(a: Vector<T>, b: Vector<T>) -> Result<T>
38+
where
39+
T: Copy + Default + Add<Output = T> + AddAssign + Mul<Output = T>,
40+
{
41+
if a.len() != b.len() {
42+
return Err(anyhow!("Dot product error: a.len != b.len"));
43+
}
44+
45+
let mut sum = T::default();
46+
for i in 0..a.len() {
47+
sum += a[i] * b[i];
48+
}
49+
50+
Ok(sum)
51+
}

0 commit comments

Comments
 (0)