1
+ #!/usr/bin/env python
2
+
3
+ #######################################################
4
+ # Copyright (c) 2019, ArrayFire
5
+ # All rights reserved.
6
+ #
7
+ # This file is distributed under 3-clause BSD license.
8
+ # The complete license agreement can be obtained at:
9
+ # http://arrayfire.com/licenses/BSD-3-Clause
10
+ ########################################################
11
+
12
+ from mnist_common import display_results , setup_mnist
13
+
14
+ import sys
15
+ import time
16
+
17
+ import arrayfire as af
18
+
19
def accuracy(predicted, target):
    """Return the percentage of entries whose predicted label equals the target label.

    A label is the second element of the pair returned by ``af.imax`` along
    axis 1, taken independently for ``target`` and for ``predicted``.
    """
    _, target_labels = af.imax(target, axis=1)
    _, predicted_labels = af.imax(predicted, axis=1)

    # Count the positions where both label arrays agree.
    num_matches = af.count(predicted_labels == target_labels)
    return 100 * num_matches / target_labels.size
23
+
24
+
25
def abserr(predicted, target):
    """Return 100 times the summed absolute difference divided by ``predicted.size``."""
    total_abs_diff = af.sum(af.abs(predicted - target))
    return 100 * total_abs_diff / predicted.size
27
+
28
+
29
# Predict (probability) based on given parameters
def predict_prob(X, Weights):
    """Return the sigmoid of the matrix product of ``X`` and ``Weights``."""
    return af.sigmoid(af.matmul(X, Weights))
33
+
34
+
35
# Predict (log probability) based on given parameters
def predict_log_prob(X, Weights):
    """Return the element-wise ``af.log`` of ``predict_prob(X, Weights)``."""
    probabilities = predict_prob(X, Weights)
    return af.log(probabilities)
38
+
39
+
40
# Give most likely class based on given parameters
def predict_class(X, Weights):
    """Return, for each sample, the index of its highest predicted probability.

    The probabilities come from ``predict_prob(X, Weights)``; the returned
    value is the second element of the pair produced by ``af.imax`` along
    axis 1.
    """
    probs = predict_prob(X, Weights)
    # Pass the axis by keyword, matching the other imax calls in this file
    # (see ``accuracy``), instead of relying on positional acceptance.
    _, classes = af.imax(probs, axis=1)
    return classes
45
+
46
+
47
def cost(Weights, X, Y, lambda_param=1.0):
    """Compute the regularized cost and its gradient for the given weights.

    Args:
        Weights: current parameters; row 0 is treated as the bias row and is
            excluded from regularization.
        X: input features, multiplied with ``Weights`` by ``predict_prob``.
        Y: targets; ``Y.shape[0]`` is used as the number of samples.
        lambda_param: regularization constant applied to non-bias weights.

    Returns:
        A ``(J, dJ)`` pair: the total cost (misprediction plus regularization,
        each summed along axis 0 and divided by the sample count) and the
        gradient of the cost with respect to ``Weights``.
    """
    # Number of samples
    num_samples = Y.shape[0]

    # Shape of Weights, padded with trailing ones up to four dimensions.
    weight_shape = Weights.shape
    lambda_shape = (weight_shape[0],) + tuple(
        weight_shape[axis] if len(weight_shape) > axis else 1
        for axis in (1, 2, 3)
    )

    # Per-weight regularization strength; the bias row gets none.
    penalty = af.constant(lambda_param, lambda_shape)
    penalty[0, :] = 0

    # Get the prediction
    H = predict_prob(X, Weights)

    # Cost of misprediction
    log_likelihood = Y * af.log(H) + (1 - Y) * af.log(1 - H)
    Jerr = -1 * af.sum(log_likelihood, axis=0)

    # Regularization cost
    Jreg = 0.5 * af.sum(penalty * Weights * Weights, axis=0)

    # Total cost
    J = (Jerr + Jreg) / num_samples

    # Gradient of the cost; the TRANS option is passed for the X operand.
    residual = H - Y
    weighted_residual = af.matmul(X, residual, af.MatProp.TRANS)
    dJ = (weighted_residual + penalty * Weights) / num_samples

    return J, dJ
78
+
79
+
80
def train(X, Y, alpha=0.1, lambda_param=1.0, maxerr=0.01, maxiter=1000, verbose=False):
    """Fit weights by gradient descent, starting from all zeros.

    Args:
        X: training features.
        Y: training targets.
        alpha: step size applied to the gradient on each update.
        lambda_param: regularization constant forwarded to ``cost``.
        maxerr: training stops early once the largest absolute cost drops
            below this value.
        maxiter: maximum number of gradient-descent iterations.
        verbose: when true, report the error every 10th iteration and report
            when the iteration limit is reached.

    Returns:
        The weights at convergence, or after ``maxiter`` updates.
    """
    progress_format = 'Iteration {0:4d} Err: {1:4f}'

    # Initialize parameters to 0
    Weights = af.constant(0, (X.shape[1], Y.shape[1]))

    for step in range(1, maxiter + 1):
        # Get the cost and gradient
        J, dJ = cost(Weights, X, Y, lambda_param)

        err = af.max(af.abs(J))
        if err < maxerr:
            # Convergence is always reported, regardless of verbose.
            print(progress_format.format(step, err))
            print('Training converged')
            return Weights

        if verbose and step % 10 == 0:
            print(progress_format.format(step, err))

        # Update the parameters via gradient descent
        Weights = Weights - alpha * dJ

    if verbose:
        print('Training stopped after {0:d} iterations'.format(maxiter))

    return Weights
104
+
105
+
106
def benchmark_logistic_regression(train_feats, train_targets, test_feats):
    """Time one training run and the average of 100 prediction runs.

    Prints the wall-clock training time, then the mean time per call of
    ``predict_prob`` on ``test_feats``. Nothing is returned.
    """
    # Training: evaluate and synchronize before stopping the clock.
    start = time.time()
    Weights = train(train_feats, train_targets, 0.1, 1.0, 0.01, 1000)
    af.eval(Weights)
    af.sync(-1)
    elapsed = time.time() - start
    print('Training time: {0:4.4f} s'.format(elapsed))

    # Prediction: repeat the call and report the per-iteration average.
    iters = 100
    start = time.time()
    for _ in range(iters):
        test_outputs = predict_prob(test_feats, Weights)
        af.eval(test_outputs)
    af.sync(-1)
    elapsed = time.time() - start
    print('Prediction time: {0:4.4f} s'.format(elapsed / iters))
124
+
125
+
126
# Demo of one vs all logistic regression
def logit_demo(console, perc):
    """Train on MNIST, report accuracy and timings, and optionally show results.

    Args:
        console: when false, ``display_results`` is called on the test
            predictions after the benchmark.
        perc: percentage value converted to a fraction and passed to
            ``setup_mnist``.
    """
    # Load mnist data (index 0 of the returned sequence is not needed here)
    mnist_data = setup_mnist(float(perc) / 100.0, True)
    num_train, num_test = mnist_data[1], mnist_data[2]
    train_images, test_images = mnist_data[3], mnist_data[4]
    train_targets, test_targets = mnist_data[5], mnist_data[6]

    # Reshape images into feature vectors
    feature_length = int(train_images.size / num_train)
    train_columns = af.moddims(train_images, (feature_length, num_train))
    test_columns = af.moddims(test_images, (feature_length, num_test))
    train_feats = af.transpose(train_columns)
    test_feats = af.transpose(test_columns)

    train_targets = af.transpose(train_targets)
    test_targets = af.transpose(test_targets)

    # Add a bias that is always 1, as a leading column of each feature matrix
    train_bias = af.constant(1, (train_feats.shape[0], 1))
    test_bias = af.constant(1, (test_feats.shape[0], 1))
    train_feats = af.join(1, train_bias, train_feats)
    test_feats = af.join(1, test_bias, test_feats)

    # Train logistic regression parameters
    learning_rate = 0.1
    regularization = 1.0
    max_error = 0.01
    max_iters = 1000
    Weights = train(train_feats, train_targets,
                    learning_rate, regularization, max_error, max_iters,
                    True)  # verbose mode
    af.eval(Weights)
    af.sync(-1)

    # Predict the results
    train_outputs = predict_prob(train_feats, Weights)
    test_outputs = predict_prob(test_feats, Weights)

    train_accuracy = accuracy(train_outputs, train_targets)
    print('Accuracy on training data: {0:2.2f}'.format(train_accuracy))
    test_accuracy = accuracy(test_outputs, test_targets)
    print('Accuracy on testing data: {0:2.2f}'.format(test_accuracy))
    test_error = abserr(test_outputs, test_targets)
    print('Maximum error on testing data: {0:2.2f}'.format(test_error))

    benchmark_logistic_regression(train_feats, train_targets, test_feats)

    if console:
        return

    # Get 20 random test images
    test_outputs = af.transpose(test_outputs)
    display_results(test_images, test_outputs, af.transpose(test_targets), 20, True)
185
+
186
def main():
    """Parse the command line and run the logistic regression demo.

    Positional arguments (all optional):
        1. device index passed to ``af.set_device`` (default 0)
        2. any value starting with ``-`` enables console mode (default off)
        3. percentage passed to ``logit_demo`` (default 60)

    Any exception raised while selecting the device or running the demo is
    reported on standard output instead of propagating.
    """
    argc = len(sys.argv)

    device = int(sys.argv[1]) if argc > 1 else 0
    # startswith() rather than indexing [0]: an empty second argument must
    # read as "not console" instead of raising IndexError before the try.
    console = sys.argv[2].startswith('-') if argc > 2 else False
    perc = int(sys.argv[3]) if argc > 3 else 60

    try:
        af.set_device(device)
        af.info()
        logit_demo(console, perc)
    except Exception as e:
        print('Error: ', str(e))


if __name__ == '__main__':
    main()
0 commit comments