Skip to content

Commit 4cf85ec

Browse files
committed
Almost finished
1 parent 74827de commit 4cf85ec

File tree

2 files changed

+77
-78
lines changed

2 files changed

+77
-78
lines changed

CSharp/ML/LegoColorIdentifier2/LegoColorIdentifier.csproj renamed to CSharp/ML/LegoColorIdentifier2/LegoColorIdentifier2.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
<PackageReference Include="Microsoft.ML" Version="1.5.2" />
1010
<PackageReference Include="Microsoft.ML.ImageAnalytics" Version="1.5.2" />
1111
<PackageReference Include="Microsoft.ML.Vision" Version="1.5.2" />
12-
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.1" />
12+
<PackageReference Include="SciSharp.TensorFlow.Redist-Windows-GPU" Version="2.3.1" />
1313
</ItemGroup>
1414

1515
</Project>

CSharp/ML/LegoColorIdentifier2/Program.cs

Lines changed: 76 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ public class Program
1111
static readonly string inputDataDirectoryPath = Path.Combine(Environment.CurrentDirectory, "..", "pieces");
1212
static readonly string outputModelFilePath = Path.Combine(Environment.CurrentDirectory, "model.zip");
1313
static MLContext mlContext = new MLContext(seed: 1);
14-
static ITransformer mlModel;
14+
private static TextWriter outBack;
15+
private static TextWriter errBack;
1516

1617
public class ModelInput
1718
{
@@ -24,11 +25,9 @@ public class ModelOutput
2425
public String PredictedLabel { get; set; }
2526
}
2627

27-
static void TrainModel(ImageClassificationTrainer.Architecture architecture)
28+
static (ITransformer mlModel, IReadOnlyList<TrainCatalogBase.CrossValidationResult<MulticlassClassificationMetrics>> evaluation) TrainModel(ImageClassificationTrainer.Architecture architecture, int epoch)
2829
{
2930
// To suppress errors from the TensorFlow library, set $env:TF_CPP_MIN_LOG_LEVEL = 2
30-
31-
// Create the input dataset
3231
var inputs = new List<ModelInput>();
3332
foreach (var subDir in Directory.GetDirectories(inputDataDirectoryPath))
3433
{
@@ -38,72 +37,53 @@ static void TrainModel(ImageClassificationTrainer.Architecture architecture)
3837
}
3938
}
4039
var trainingDataView = mlContext.Data.LoadFromEnumerable<ModelInput>(inputs);
41-
// Create training pipeline
4240
var dataProcessPipeline = mlContext.Transforms.Conversion.MapValueToKey("Label", "Label")
4341
.Append(mlContext.Transforms.LoadRawImageBytes("ImageSource_featurized", null, "ImageSource"))
4442
.Append(mlContext.Transforms.CopyColumns("Features", "ImageSource_featurized"));
4543
var trainer = mlContext.MulticlassClassification.Trainers.ImageClassification(
4644
new ImageClassificationTrainer.Options()
4745
{
4846
Arch = architecture,
49-
LabelColumnName = "Label",
47+
Epoch = epoch,
5048
FeatureColumnName = "Features",
49+
LabelColumnName = "Label",
5150
})
5251
.Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel", "PredictedLabel"));
5352
IEstimator<ITransformer> trainingPipeline = dataProcessPipeline.Append(trainer);
54-
// Create the model
55-
mlModel = trainingPipeline.Fit(trainingDataView);
56-
Evaluate(mlContext, trainingDataView, trainingPipeline);
53+
var mlModel = trainingPipeline.Fit(trainingDataView);
54+
var evaluation = mlContext.MulticlassClassification.CrossValidate(trainingDataView, trainingPipeline, numberOfFolds: 5, labelColumnName: "Label");
55+
return (mlModel, evaluation);
5756
}
5857

59-
static ModelOutput Classify(string filePath)
58+
static ModelOutput Classify(PredictionEngine<ModelInput, ModelOutput> predEngine, string filePath)
6059
{
61-
// Create input to classify
6260
ModelInput input = new ModelInput() { ImageSource = filePath };
63-
// Load model and predict
64-
var predEngine = mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(mlModel);
6561
return predEngine.Predict(input);
6662
}
6763

68-
static void Evaluate(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
64+
static Dictionary<string, (double Avg, double StdDev)> CalculateAndPrintAverageMetrics(IEnumerable<TrainCatalogBase.CrossValidationResult<MulticlassClassificationMetrics>> crossValResults)
6965
{
70-
Console.WriteLine("=============== Cross-validating to get model's accuracy metrics ===============");
71-
var crossValidationResults = mlContext.MulticlassClassification.CrossValidate(trainingDataView, trainingPipeline, numberOfFolds: 5, labelColumnName: "Label");
72-
PrintMulticlassClassificationFoldsAverageMetrics(crossValidationResults);
66+
var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics);
67+
68+
var retVal = new Dictionary<string, (double Avg, double StdDev)>();
69+
70+
retVal["MicroAccuracy"] = CalculateAverageMetrics(metricsInMultipleFolds.Select(m => m.MicroAccuracy));
71+
retVal["MacroAccuracy"] = CalculateAverageMetrics(metricsInMultipleFolds.Select(m => m.MacroAccuracy));
72+
retVal["LogLoss"] = CalculateAverageMetrics(metricsInMultipleFolds.Select(m => m.LogLoss));
73+
retVal["LogLossReduction"] = CalculateAverageMetrics(metricsInMultipleFolds.Select(m => m.LogLossReduction));
74+
75+
Console.WriteLine($"Avg. MicroAccuracy (Std. Dev): {retVal["MicroAccuracy"].Avg:0.###} ({retVal["MicroAccuracy"].StdDev:#.###})");
76+
Console.WriteLine($"Avg. MacroAccuracy (Std. Dev): {retVal["MacroAccuracy"].Avg:0.###} ({retVal["MacroAccuracy"].StdDev:#.###})");
77+
Console.WriteLine($"Avg. LogLoss (Std. Dev): {retVal["LogLoss"].Avg:#.###} ({retVal["LogLoss"].StdDev:#.###})");
78+
Console.WriteLine($"Avg. LogLossReduction (Std. Dev): {retVal["LogLossReduction"].Avg:#.###} ({retVal["LogLossReduction"].StdDev:#.###})");
79+
80+
return retVal;
7381
}
7482

75-
static void PrintMulticlassClassificationFoldsAverageMetrics(IEnumerable<TrainCatalogBase.CrossValidationResult<MulticlassClassificationMetrics>> crossValResults)
83+
static (double, double) CalculateAverageMetrics(IEnumerable<double> metricValues)
7684
{
77-
var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics);
78-
79-
var microAccuracyValues = metricsInMultipleFolds.Select(m => m.MicroAccuracy);
80-
var microAccuracyAverage = microAccuracyValues.Average();
81-
var microAccuraciesStdDeviation = CalculateStandardDeviation(microAccuracyValues);
82-
var microAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(microAccuracyValues);
83-
84-
var macroAccuracyValues = metricsInMultipleFolds.Select(m => m.MacroAccuracy);
85-
var macroAccuracyAverage = macroAccuracyValues.Average();
86-
var macroAccuraciesStdDeviation = CalculateStandardDeviation(macroAccuracyValues);
87-
var macroAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(macroAccuracyValues);
88-
89-
var logLossValues = metricsInMultipleFolds.Select(m => m.LogLoss);
90-
var logLossAverage = logLossValues.Average();
91-
var logLossStdDeviation = CalculateStandardDeviation(logLossValues);
92-
var logLossConfidenceInterval95 = CalculateConfidenceInterval95(logLossValues);
93-
94-
var logLossReductionValues = metricsInMultipleFolds.Select(m => m.LogLossReduction);
95-
var logLossReductionAverage = logLossReductionValues.Average();
96-
var logLossReductionStdDeviation = CalculateStandardDeviation(logLossReductionValues);
97-
var logLossReductionConfidenceInterval95 = CalculateConfidenceInterval95(logLossReductionValues);
98-
99-
Console.WriteLine($"*************************************************************************************************************");
100-
Console.WriteLine($"* Metrics for Multi-class Classification model ");
101-
Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
102-
Console.WriteLine($"* Average MicroAccuracy: {microAccuracyAverage:0.###} - Standard deviation: ({microAccuraciesStdDeviation:#.###}) - Confidence Interval 95%: ({microAccuraciesConfidenceInterval95:#.###})");
103-
Console.WriteLine($"* Average MacroAccuracy: {macroAccuracyAverage:0.###} - Standard deviation: ({macroAccuraciesStdDeviation:#.###}) - Confidence Interval 95%: ({macroAccuraciesConfidenceInterval95:#.###})");
104-
Console.WriteLine($"* Average LogLoss: {logLossAverage:#.###} - Standard deviation: ({logLossStdDeviation:#.###}) - Confidence Interval 95%: ({logLossConfidenceInterval95:#.###})");
105-
Console.WriteLine($"* Average LogLossReduction: {logLossReductionAverage:#.###} - Standard deviation: ({logLossReductionStdDeviation:#.###}) - Confidence Interval 95%: ({logLossReductionConfidenceInterval95:#.###})");
106-
Console.WriteLine($"*************************************************************************************************************");
85+
return (metricValues.Average(),
86+
CalculateStandardDeviation(metricValues));
10787
}
10888

10989
static double CalculateStandardDeviation(IEnumerable<double> values)
@@ -114,45 +94,64 @@ static double CalculateStandardDeviation(IEnumerable<double> values)
11494
return standardDeviation;
11595
}
11696

117-
static double CalculateConfidenceInterval95(IEnumerable<double> values)
97+
static void TestClassifier(ITransformer model)
11898
{
119-
double confidenceInterval95 = 1.96 * CalculateStandardDeviation(values) / Math.Sqrt((values.Count() - 1));
120-
return confidenceInterval95;
121-
}
99+
var predEngine = mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(model);
122100

123-
static void TestClassifier()
124-
{
125-
var result = Classify(Path.Combine(Environment.CurrentDirectory, "Black.jpg"));
101+
var result = Classify(predEngine, Path.Combine(Environment.CurrentDirectory, "Black.jpg"));
126102
Console.WriteLine($"Testing with black piece. Prediction: {result.PredictedLabel}.");
127-
result = Classify(Path.Combine(Environment.CurrentDirectory, "Blue.jpg"));
103+
result = Classify(predEngine, Path.Combine(Environment.CurrentDirectory, "Blue.jpg"));
128104
Console.WriteLine($"Testing with blue piece. Prediction: {result.PredictedLabel}.");
129-
result = Classify(Path.Combine(Environment.CurrentDirectory, "Green.jpg"));
105+
result = Classify(predEngine, Path.Combine(Environment.CurrentDirectory, "Green.jpg"));
130106
Console.WriteLine($"Testing with green piece. Prediction: {result.PredictedLabel}.");
131-
result = Classify(Path.Combine(Environment.CurrentDirectory, "Yellow.jpg"));
107+
result = Classify(predEngine, Path.Combine(Environment.CurrentDirectory, "Yellow.jpg"));
132108
Console.WriteLine($"Testing with yellow piece. Prediction: {result.PredictedLabel}.");
133109
}
134110

135111
static void Main()
136112
{
137-
var architecture = ImageClassificationTrainer.Architecture.InceptionV3;
138-
Console.WriteLine($"Using algorithm {architecture}");
139-
TrainModel(architecture);
140-
TestClassifier();
141-
142-
architecture = ImageClassificationTrainer.Architecture.MobilenetV2;
143-
Console.WriteLine($"Using algorithm {architecture}");
144-
TrainModel(architecture);
145-
TestClassifier();
146-
147-
architecture = ImageClassificationTrainer.Architecture.ResnetV2101;
148-
Console.WriteLine($"Using algorithm {architecture}");
149-
TrainModel(architecture);
150-
TestClassifier();
151-
152-
architecture = ImageClassificationTrainer.Architecture.ResnetV250;
153-
Console.WriteLine($"Using algorithm {architecture}");
154-
TrainModel(architecture);
155-
TestClassifier();
113+
var architectures = new []{ ImageClassificationTrainer.Architecture.InceptionV3, ImageClassificationTrainer.Architecture.MobilenetV2, ImageClassificationTrainer.Architecture.ResnetV2101, ImageClassificationTrainer.Architecture.ResnetV250 };
114+
var epochs = new[] { 50, 100, 200, 400 };
115+
116+
var results = new Dictionary<(ImageClassificationTrainer.Architecture arch, int epoch), (ITransformer model, IReadOnlyList<TrainCatalogBase.CrossValidationResult<MulticlassClassificationMetrics>> metrics)>();
117+
118+
foreach(var arch in architectures)
119+
{
120+
foreach(var epoch in epochs)
121+
{
122+
Console.WriteLine($"Using architecture {arch}, epochs {epoch}.");
123+
StopAllOutput();
124+
results[(arch, epoch)] = TrainModel(arch, epoch);
125+
RestoreAllOutput();
126+
TestClassifier(results[(arch, epoch)].model);
127+
}
128+
}
129+
130+
foreach (var arch in architectures)
131+
{
132+
foreach (var epoch in epochs)
133+
{
134+
Console.WriteLine($"Using architecture {arch}, epochs {epoch}.");
135+
CalculateAndPrintAverageMetrics(results[(arch, epoch)].metrics);
136+
TestClassifier(results[(arch, epoch)].model);
137+
138+
139+
}
140+
}
141+
}
142+
143+
static void StopAllOutput()
144+
{
145+
outBack = Console.Out;
146+
Console.SetOut(TextWriter.Null);
147+
errBack = Console.Error;
148+
Console.SetError(TextWriter.Null);
149+
}
150+
151+
static void RestoreAllOutput()
152+
{
153+
Console.SetOut(outBack);
154+
Console.SetError(errBack);
156155
}
157156
}
158157

0 commit comments

Comments
 (0)