@@ -11,7 +11,8 @@ public class Program
1111 static readonly string inputDataDirectoryPath = Path . Combine ( Environment . CurrentDirectory , ".." , "pieces" ) ;
1212 static readonly string outputModelFilePath = Path . Combine ( Environment . CurrentDirectory , "model.zip" ) ;
1313 static MLContext mlContext = new MLContext ( seed : 1 ) ;
14- static ITransformer mlModel ;
14+ private static TextWriter outBack ;
15+ private static TextWriter errBack ;
1516
1617 public class ModelInput
1718 {
@@ -24,11 +25,9 @@ public class ModelOutput
2425 public String PredictedLabel { get ; set ; }
2526 }
2627
27- static void TrainModel ( ImageClassificationTrainer . Architecture architecture )
28+ static ( ITransformer mlModel , IReadOnlyList < TrainCatalogBase . CrossValidationResult < MulticlassClassificationMetrics > > evaluation ) TrainModel ( ImageClassificationTrainer . Architecture architecture , int epoch )
2829 {
2930 // To suppress errors from the TensorFlow library, set $env:TF_CPP_MIN_LOG_LEVEL = 2
30-
31- // Create the input dataset
3231 var inputs = new List < ModelInput > ( ) ;
3332 foreach ( var subDir in Directory . GetDirectories ( inputDataDirectoryPath ) )
3433 {
@@ -38,72 +37,53 @@ static void TrainModel(ImageClassificationTrainer.Architecture architecture)
3837 }
3938 }
4039 var trainingDataView = mlContext . Data . LoadFromEnumerable < ModelInput > ( inputs ) ;
41- // Create training pipeline
4240 var dataProcessPipeline = mlContext . Transforms . Conversion . MapValueToKey ( "Label" , "Label" )
4341 . Append ( mlContext . Transforms . LoadRawImageBytes ( "ImageSource_featurized" , null , "ImageSource" ) )
4442 . Append ( mlContext . Transforms . CopyColumns ( "Features" , "ImageSource_featurized" ) ) ;
4543 var trainer = mlContext . MulticlassClassification . Trainers . ImageClassification (
4644 new ImageClassificationTrainer . Options ( )
4745 {
4846 Arch = architecture ,
49- LabelColumnName = "Label" ,
47+ Epoch = epoch ,
5048 FeatureColumnName = "Features" ,
49+ LabelColumnName = "Label" ,
5150 } )
5251 . Append ( mlContext . Transforms . Conversion . MapKeyToValue ( "PredictedLabel" , "PredictedLabel" ) ) ;
5352 IEstimator < ITransformer > trainingPipeline = dataProcessPipeline . Append ( trainer ) ;
54- // Create the model
55- mlModel = trainingPipeline . Fit ( trainingDataView ) ;
56- Evaluate ( mlContext , trainingDataView , trainingPipeline ) ;
53+ var mlModel = trainingPipeline . Fit ( trainingDataView ) ;
54+ var evaluation = mlContext . MulticlassClassification . CrossValidate ( trainingDataView , trainingPipeline , numberOfFolds : 5 , labelColumnName : "Label" ) ;
55+ return ( mlModel , evaluation ) ;
5756 }
5857
59- static ModelOutput Classify ( string filePath )
58+ static ModelOutput Classify ( PredictionEngine < ModelInput , ModelOutput > predEngine , string filePath )
6059 {
61- // Create input to classify
6260 ModelInput input = new ModelInput ( ) { ImageSource = filePath } ;
63- // Load model and predict
64- var predEngine = mlContext . Model . CreatePredictionEngine < ModelInput , ModelOutput > ( mlModel ) ;
6561 return predEngine . Predict ( input ) ;
6662 }
6763
68- static void Evaluate ( MLContext mlContext , IDataView trainingDataView , IEstimator < ITransformer > trainingPipeline )
64+ static Dictionary < string , ( double Avg , double StdDev ) > CalculateAndPrintAverageMetrics ( IEnumerable < TrainCatalogBase . CrossValidationResult < MulticlassClassificationMetrics > > crossValResults )
6965 {
70- Console . WriteLine ( "=============== Cross-validating to get model's accuracy metrics ===============" ) ;
71- var crossValidationResults = mlContext . MulticlassClassification . CrossValidate ( trainingDataView , trainingPipeline , numberOfFolds : 5 , labelColumnName : "Label" ) ;
72- PrintMulticlassClassificationFoldsAverageMetrics ( crossValidationResults ) ;
66+ var metricsInMultipleFolds = crossValResults . Select ( r => r . Metrics ) ;
67+
68+ var retVal = new Dictionary < string , ( double Avg , double StdDev ) > ( ) ;
69+
70+ retVal [ "MicroAccuracy" ] = CalculateAverageMetrics ( metricsInMultipleFolds . Select ( m => m . MicroAccuracy ) ) ;
71+ retVal [ "MacroAccuracy" ] = CalculateAverageMetrics ( metricsInMultipleFolds . Select ( m => m . MacroAccuracy ) ) ;
72+ retVal [ "LogLoss" ] = CalculateAverageMetrics ( metricsInMultipleFolds . Select ( m => m . LogLoss ) ) ;
73+ retVal [ "LogLossReduction" ] = CalculateAverageMetrics ( metricsInMultipleFolds . Select ( m => m . LogLossReduction ) ) ;
74+
75+ Console . WriteLine ( $ "Avg. MicroAccuracy (Std. Dev): { retVal [ "MicroAccuracy" ] . Avg : 0.###} ({ retVal [ "MicroAccuracy" ] . StdDev : #.###} )") ;
76+ Console . WriteLine ( $ "Avg. MacroAccuracy (Std. Dev): { retVal [ "MacroAccuracy" ] . Avg : 0.###} ({ retVal [ "MacroAccuracy" ] . StdDev : #.###} )") ;
77+ Console . WriteLine ( $ "Avg. LogLoss (Std. Dev): { retVal [ "LogLoss" ] . Avg : #.###} ({ retVal [ "LogLoss" ] . StdDev : #.###} )") ;
78+ Console . WriteLine ( $ "Avg. LogLossReduction (Std. Dev): { retVal [ "LogLossReduction" ] . Avg : #.###} ({ retVal [ "LogLossReduction" ] . StdDev : #.###} )") ;
79+
80+ return retVal ;
7381 }
7482
75- static void PrintMulticlassClassificationFoldsAverageMetrics ( IEnumerable < TrainCatalogBase . CrossValidationResult < MulticlassClassificationMetrics > > crossValResults )
83+ static ( double , double ) CalculateAverageMetrics ( IEnumerable < double > metricValues )
7684 {
77- var metricsInMultipleFolds = crossValResults . Select ( r => r . Metrics ) ;
78-
79- var microAccuracyValues = metricsInMultipleFolds . Select ( m => m . MicroAccuracy ) ;
80- var microAccuracyAverage = microAccuracyValues . Average ( ) ;
81- var microAccuraciesStdDeviation = CalculateStandardDeviation ( microAccuracyValues ) ;
82- var microAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95 ( microAccuracyValues ) ;
83-
84- var macroAccuracyValues = metricsInMultipleFolds . Select ( m => m . MacroAccuracy ) ;
85- var macroAccuracyAverage = macroAccuracyValues . Average ( ) ;
86- var macroAccuraciesStdDeviation = CalculateStandardDeviation ( macroAccuracyValues ) ;
87- var macroAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95 ( macroAccuracyValues ) ;
88-
89- var logLossValues = metricsInMultipleFolds . Select ( m => m . LogLoss ) ;
90- var logLossAverage = logLossValues . Average ( ) ;
91- var logLossStdDeviation = CalculateStandardDeviation ( logLossValues ) ;
92- var logLossConfidenceInterval95 = CalculateConfidenceInterval95 ( logLossValues ) ;
93-
94- var logLossReductionValues = metricsInMultipleFolds . Select ( m => m . LogLossReduction ) ;
95- var logLossReductionAverage = logLossReductionValues . Average ( ) ;
96- var logLossReductionStdDeviation = CalculateStandardDeviation ( logLossReductionValues ) ;
97- var logLossReductionConfidenceInterval95 = CalculateConfidenceInterval95 ( logLossReductionValues ) ;
98-
99- Console . WriteLine ( $ "*************************************************************************************************************") ;
100- Console . WriteLine ( $ "* Metrics for Multi-class Classification model ") ;
101- Console . WriteLine ( $ "*------------------------------------------------------------------------------------------------------------") ;
102- Console . WriteLine ( $ "* Average MicroAccuracy: { microAccuracyAverage : 0.###} - Standard deviation: ({ microAccuraciesStdDeviation : #.###} ) - Confidence Interval 95%: ({ microAccuraciesConfidenceInterval95 : #.###} )") ;
103- Console . WriteLine ( $ "* Average MacroAccuracy: { macroAccuracyAverage : 0.###} - Standard deviation: ({ macroAccuraciesStdDeviation : #.###} ) - Confidence Interval 95%: ({ macroAccuraciesConfidenceInterval95 : #.###} )") ;
104- Console . WriteLine ( $ "* Average LogLoss: { logLossAverage : #.###} - Standard deviation: ({ logLossStdDeviation : #.###} ) - Confidence Interval 95%: ({ logLossConfidenceInterval95 : #.###} )") ;
105- Console . WriteLine ( $ "* Average LogLossReduction: { logLossReductionAverage : #.###} - Standard deviation: ({ logLossReductionStdDeviation : #.###} ) - Confidence Interval 95%: ({ logLossReductionConfidenceInterval95 : #.###} )") ;
106- Console . WriteLine ( $ "*************************************************************************************************************") ;
85+ return ( metricValues . Average ( ) ,
86+ CalculateStandardDeviation ( metricValues ) ) ;
10787 }
10888
10989 static double CalculateStandardDeviation ( IEnumerable < double > values )
@@ -114,45 +94,64 @@ static double CalculateStandardDeviation(IEnumerable<double> values)
11494 return standardDeviation ;
11595 }
11696
117- static double CalculateConfidenceInterval95 ( IEnumerable < double > values )
97+ static void TestClassifier ( ITransformer model )
11898 {
119- double confidenceInterval95 = 1.96 * CalculateStandardDeviation ( values ) / Math . Sqrt ( ( values . Count ( ) - 1 ) ) ;
120- return confidenceInterval95 ;
121- }
99+ var predEngine = mlContext . Model . CreatePredictionEngine < ModelInput , ModelOutput > ( model ) ;
122100
123- static void TestClassifier ( )
124- {
125- var result = Classify ( Path . Combine ( Environment . CurrentDirectory , "Black.jpg" ) ) ;
101+ var result = Classify ( predEngine , Path . Combine ( Environment . CurrentDirectory , "Black.jpg" ) ) ;
126102 Console . WriteLine ( $ "Testing with black piece. Prediction: { result . PredictedLabel } .") ;
127- result = Classify ( Path . Combine ( Environment . CurrentDirectory , "Blue.jpg" ) ) ;
103+ result = Classify ( predEngine , Path . Combine ( Environment . CurrentDirectory , "Blue.jpg" ) ) ;
128104 Console . WriteLine ( $ "Testing with blue piece. Prediction: { result . PredictedLabel } .") ;
129- result = Classify ( Path . Combine ( Environment . CurrentDirectory , "Green.jpg" ) ) ;
105+ result = Classify ( predEngine , Path . Combine ( Environment . CurrentDirectory , "Green.jpg" ) ) ;
130106 Console . WriteLine ( $ "Testing with green piece. Prediction: { result . PredictedLabel } .") ;
131- result = Classify ( Path . Combine ( Environment . CurrentDirectory , "Yellow.jpg" ) ) ;
107+ result = Classify ( predEngine , Path . Combine ( Environment . CurrentDirectory , "Yellow.jpg" ) ) ;
132108 Console . WriteLine ( $ "Testing with yellow piece. Prediction: { result . PredictedLabel } .") ;
133109 }
134110
135111 static void Main ( )
136112 {
137- var architecture = ImageClassificationTrainer . Architecture . InceptionV3 ;
138- Console . WriteLine ( $ "Using algorithm { architecture } ") ;
139- TrainModel ( architecture ) ;
140- TestClassifier ( ) ;
141-
142- architecture = ImageClassificationTrainer . Architecture . MobilenetV2 ;
143- Console . WriteLine ( $ "Using algorithm { architecture } ") ;
144- TrainModel ( architecture ) ;
145- TestClassifier ( ) ;
146-
147- architecture = ImageClassificationTrainer . Architecture . ResnetV2101 ;
148- Console . WriteLine ( $ "Using algorithm { architecture } ") ;
149- TrainModel ( architecture ) ;
150- TestClassifier ( ) ;
151-
152- architecture = ImageClassificationTrainer . Architecture . ResnetV250 ;
153- Console . WriteLine ( $ "Using algorithm { architecture } ") ;
154- TrainModel ( architecture ) ;
155- TestClassifier ( ) ;
113+ var architectures = new [ ] { ImageClassificationTrainer . Architecture . InceptionV3 , ImageClassificationTrainer . Architecture . MobilenetV2 , ImageClassificationTrainer . Architecture . ResnetV2101 , ImageClassificationTrainer . Architecture . ResnetV250 } ;
114+ var epochs = new [ ] { 50 , 100 , 200 , 400 } ;
115+
116+ var results = new Dictionary < ( ImageClassificationTrainer . Architecture arch , int epoch ) , ( ITransformer model , IReadOnlyList < TrainCatalogBase . CrossValidationResult < MulticlassClassificationMetrics > > metrics ) > ( ) ;
117+
118+ foreach ( var arch in architectures )
119+ {
120+ foreach ( var epoch in epochs )
121+ {
122+ Console . WriteLine ( $ "Using architecture { arch } , epochs { epoch } .") ;
123+ StopAllOutput ( ) ;
124+ results [ ( arch , epoch ) ] = TrainModel ( arch , epoch ) ;
125+ RestoreAllOutput ( ) ;
126+ TestClassifier ( results [ ( arch , epoch ) ] . model ) ;
127+ }
128+ }
129+
130+ foreach ( var arch in architectures )
131+ {
132+ foreach ( var epoch in epochs )
133+ {
134+ Console . WriteLine ( $ "Using architecture { arch } , epochs { epoch } .") ;
135+ CalculateAndPrintAverageMetrics ( results [ ( arch , epoch ) ] . metrics ) ;
136+ TestClassifier ( results [ ( arch , epoch ) ] . model ) ;
137+
138+
139+ }
140+ }
141+ }
142+
143+ static void StopAllOutput ( )
144+ {
145+ outBack = Console . Out ;
146+ Console . SetOut ( TextWriter . Null ) ;
147+ errBack = Console . Error ;
148+ Console . SetError ( TextWriter . Null ) ;
149+ }
150+
151+ static void RestoreAllOutput ( )
152+ {
153+ Console . SetOut ( outBack ) ;
154+ Console . SetError ( errBack ) ;
156155 }
157156}
158157
0 commit comments