Skip to content

Commit 5dfc3fa

Browse files
committed
Final release for publishing
1 parent 4cf85ec commit 5dfc3fa

File tree

1 file changed

+2
-18
lines changed

1 file changed

+2
-18
lines changed

CSharp/ML/LegoColorIdentifier2/Program.cs

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
public class Program
1010
{
1111
static readonly string inputDataDirectoryPath = Path.Combine(Environment.CurrentDirectory, "..", "pieces");
12-
static readonly string outputModelFilePath = Path.Combine(Environment.CurrentDirectory, "model.zip");
1312
static MLContext mlContext = new MLContext(seed: 1);
1413
private static TextWriter outBack;
1514
private static TextWriter errBack;
@@ -61,7 +60,7 @@ static ModelOutput Classify(PredictionEngine<ModelInput, ModelOutput> predEngine
6160
return predEngine.Predict(input);
6261
}
6362

64-
static Dictionary<string, (double Avg, double StdDev)> CalculateAndPrintAverageMetrics(IEnumerable<TrainCatalogBase.CrossValidationResult<MulticlassClassificationMetrics>> crossValResults)
63+
static void CalculateAndPrintAverageMetrics(IEnumerable<TrainCatalogBase.CrossValidationResult<MulticlassClassificationMetrics>> crossValResults)
6564
{
6665
var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics);
6766

@@ -76,14 +75,11 @@ static ModelOutput Classify(PredictionEngine<ModelInput, ModelOutput> predEngine
7675
Console.WriteLine($"Avg. MacroAccuracy (Std. Dev): {retVal["MacroAccuracy"].Avg:0.###} ({retVal["MacroAccuracy"].StdDev:#.###})");
7776
Console.WriteLine($"Avg. LogLoss (Std. Dev): {retVal["LogLoss"].Avg:#.###} ({retVal["LogLoss"].StdDev:#.###})");
7877
Console.WriteLine($"Avg. LogLossReduction (Std. Dev): {retVal["LogLossReduction"].Avg:#.###} ({retVal["LogLossReduction"].StdDev:#.###})");
79-
80-
return retVal;
8178
}
8279

8380
static (double, double) CalculateAverageMetrics(IEnumerable<double> metricValues)
8481
{
85-
return (metricValues.Average(),
86-
CalculateStandardDeviation(metricValues));
82+
return (metricValues.Average(), CalculateStandardDeviation(metricValues));
8783
}
8884

8985
static double CalculateStandardDeviation(IEnumerable<double> values)
@@ -123,23 +119,11 @@ static void Main()
123119
StopAllOutput();
124120
results[(arch, epoch)] = TrainModel(arch, epoch);
125121
RestoreAllOutput();
126-
TestClassifier(results[(arch, epoch)].model);
127-
}
128-
}
129-
130-
foreach (var arch in architectures)
131-
{
132-
foreach (var epoch in epochs)
133-
{
134-
Console.WriteLine($"Using architecture {arch}, epochs {epoch}.");
135122
CalculateAndPrintAverageMetrics(results[(arch, epoch)].metrics);
136123
TestClassifier(results[(arch, epoch)].model);
137-
138-
139124
}
140125
}
141126
}
142-
143127
static void StopAllOutput()
144128
{
145129
outBack = Console.Out;

0 commit comments

Comments
 (0)