1
- use anyhow:: { anyhow, Result } ;
2
1
use std:: {
3
2
fmt,
4
3
ops:: { Add , AddAssign , Mul } ,
4
+ sync:: mpsc,
5
+ thread,
5
6
} ;
6
7
8
+ use anyhow:: { anyhow, Result } ;
9
+
10
+ use crate :: { dot_product, Vector } ;
11
+
12
+ const NUM_THREADS : usize = 4 ;
13
+
7
14
// 声明一个矩阵的结构
8
15
9
16
pub struct Matrix < T > {
@@ -12,30 +19,104 @@ pub struct Matrix<T> {
12
19
col : usize ,
13
20
}
14
21
22
+ pub struct MsgInput < T > {
23
+ idx : usize ,
24
+ row : Vector < T > ,
25
+ col : Vector < T > ,
26
+ }
27
+
28
+ pub struct MsgOutput < T > {
29
+ idx : usize ,
30
+ value : T ,
31
+ }
32
+
33
+ pub struct Msg < T > {
34
+ input : MsgInput < T > ,
35
+ // sender to send the result back
36
+ sender : oneshot:: Sender < MsgOutput < T > > ,
37
+ }
38
+
15
39
pub fn multiply < T > ( a : & Matrix < T > , b : & Matrix < T > ) -> Result < Matrix < T > >
16
40
where
17
- T : Copy + Default + Add < Output = T > + AddAssign + Mul < Output = T > ,
41
+ T : Copy + Default + Add < Output = T > + AddAssign + Mul < Output = T > + Send + ' static ,
18
42
{
19
43
if a. col != b. row {
20
44
return Err ( anyhow ! ( "Matrix multiply error: a.col != b.row" ) ) ;
21
45
}
22
46
47
+ let senders = ( 0 ..NUM_THREADS )
48
+ . map ( |_| {
49
+ let ( tx, rx) = mpsc:: channel :: < Msg < T > > ( ) ;
50
+ // 线程下执行的内容
51
+ thread:: spawn ( move || {
52
+ for msg in rx {
53
+ let value = dot_product ( msg. input . row , msg. input . col ) ?;
54
+ if let Err ( e) = msg. sender . send ( MsgOutput {
55
+ idx : msg. input . idx ,
56
+ value,
57
+ } ) {
58
+ eprintln ! ( "send error: {:?}" , e) ;
59
+ }
60
+ }
61
+ // 这里我犯了一个错误,加了分号导致找了很长时间的 bug
62
+ Ok :: < _ , anyhow:: Error > ( ( ) )
63
+ } ) ;
64
+ tx
65
+ } )
66
+ . collect :: < Vec < _ > > ( ) ;
67
+
68
+ // generate 4 threads which receive msg and do dot product
69
+
23
70
// 让数据有缺省值
24
- let mut data = vec ! [ T :: default ( ) ; a. row * b. col] ;
71
+ let matrix_len = a. row * b. col ;
72
+ // 初始化一个空的矩阵
73
+ let mut data = vec ! [ T :: default ( ) ; matrix_len] ;
74
+ let mut receivers = Vec :: with_capacity ( matrix_len) ;
75
+ // map reduce map phase
25
76
for i in 0 ..a. row {
26
77
for j in 0 ..b. col {
27
- for k in 0 ..a. col {
28
- data[ i * b. col + j] += a. data [ i * a. col + k] * b. data [ k * b. col + j]
78
+ let row = Vector :: new ( & a. data [ i * a. col ..( i + 1 ) * a. col ] ) ;
79
+ let col_data = b. data [ j..]
80
+ . iter ( )
81
+ . step_by ( b. col )
82
+ . copied ( )
83
+ . collect :: < Vec < _ > > ( ) ;
84
+ let col = Vector :: new ( col_data) ;
85
+ let idx = i * b. col + j;
86
+ let input = MsgInput :: new ( idx, row, col) ;
87
+ let ( tx, rx) = oneshot:: channel ( ) ;
88
+ let msg = Msg :: new ( input, tx) ;
89
+ if let Err ( e) = senders[ idx % NUM_THREADS ] . send ( msg) {
90
+ eprintln ! ( "send error: {:?}" , e) ;
29
91
}
92
+ receivers. push ( rx) ;
30
93
}
31
94
}
95
+
96
+ // map reduce: reduce phase
97
+ for rx in receivers {
98
+ let output = rx. recv ( ) ?;
99
+ data[ output. idx ] = output. value ;
100
+ }
32
101
Ok ( Matrix {
33
102
data,
34
103
row : a. row ,
35
104
col : b. col ,
36
105
} )
37
106
}
38
107
108
+ impl < T > MsgInput < T > {
109
+ pub fn new ( idx : usize , row : Vector < T > , col : Vector < T > ) -> Self {
110
+ Self { idx, row, col }
111
+ }
112
+ }
113
+
114
+ impl < T > Msg < T > {
115
+ pub fn new ( input : MsgInput < T > , sender : oneshot:: Sender < MsgOutput < T > > ) -> Self {
116
+ Self { input, sender }
117
+ }
118
+ }
119
+
39
120
impl < T : fmt:: Debug > Matrix < T > {
40
121
pub fn new ( data : impl Into < Vec < T > > , row : usize , col : usize ) -> Self {
41
122
Self {
@@ -79,20 +160,59 @@ where
79
160
}
80
161
}
81
162
163
+ impl < T > Mul for Matrix < T >
164
+ where
165
+ T : Copy + Default + Add < Output = T > + AddAssign + Mul < Output = T > + Send + ' static ,
166
+ {
167
+ type Output = Self ;
168
+
169
+ fn mul ( self , rhs : Self ) -> Self :: Output {
170
+ multiply ( & self , & rhs) . expect ( "Matrix multiply error" )
171
+ }
172
+ }
173
+
82
174
#[ cfg( test) ]
83
175
mod tests {
176
+ use std:: vec;
177
+
84
178
use super :: * ;
85
179
86
180
#[ test]
87
181
fn test_matrix_multiply ( ) -> Result < ( ) > {
88
182
let a = Matrix :: new ( [ 1 , 2 , 3 , 4 , 5 , 6 ] , 2 , 3 ) ;
89
183
let b = Matrix :: new ( [ 1 , 2 , 3 , 4 , 5 , 6 ] , 3 , 2 ) ;
90
- let c = multiply ( & a , & b ) ? ;
184
+ let c = a * b ;
91
185
assert_eq ! ( c. col, 2 ) ;
92
186
assert_eq ! ( c. row, 2 ) ;
93
187
assert_eq ! ( c. data, vec![ 22 , 28 , 49 , 64 ] ) ;
94
188
assert_eq ! ( format!( "{:?}" , c) , "Matrix(row=2, col=2, {22 28, 49 64})" ) ;
95
189
96
190
Ok ( ( ) )
97
191
}
192
+
193
+ #[ test]
194
+ fn test_matrix_display ( ) -> Result < ( ) > {
195
+ let a = Matrix :: new ( [ 1 , 2 , 3 , 4 ] , 2 , 2 ) ;
196
+ let b = Matrix :: new ( [ 1 , 2 , 3 , 4 ] , 2 , 2 ) ;
197
+ let c = a * b;
198
+ assert_eq ! ( c. data, vec![ 7 , 10 , 15 , 22 ] ) ;
199
+ assert_eq ! ( format!( "{}" , c) , "{7 10, 15 22}" ) ;
200
+ Ok ( ( ) )
201
+ }
202
+
203
+ #[ test]
204
+ #[ should_panic]
205
+ fn test_a_can_not_multiply_b_panic ( ) {
206
+ let a = Matrix :: new ( [ 1 , 2 , 3 , 4 , 5 , 6 ] , 2 , 3 ) ;
207
+ let b = Matrix :: new ( [ 1 , 2 , 3 , 4 ] , 2 , 2 ) ;
208
+ let _c = a * b;
209
+ }
210
+
211
+ #[ test]
212
+ fn test_a_can_not_multiply_b ( ) {
213
+ let a = Matrix :: new ( [ 1 , 2 , 3 , 4 , 5 , 6 ] , 2 , 3 ) ;
214
+ let b = Matrix :: new ( [ 1 , 2 , 3 , 4 ] , 2 , 2 ) ;
215
+ let c = multiply ( & a, & b) ;
216
+ assert ! ( c. is_err( ) ) ;
217
+ }
98
218
}
0 commit comments