Skip to content

Commit 3b7e0ca

Browse files
committed
[ML] RegressionIT: Fix hyperparameters for regression tests and unmute the test (elastic#135541)
This PR fixes the flaky test muted in elastic#93228 by pinning the hyperparameters to values that are known to work. Since the test exercises alias fields rather than the training algorithm, fixing the hyperparameters is safe. Closes elastic#93228
1 parent 7f9ba0f commit 3b7e0ca

File tree

1 file changed

+25
-28
lines changed
  • x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration

1 file changed

+25
-28
lines changed

x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,8 @@
2727
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
2828
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
2929
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
30-
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
3130
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
3231
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
33-
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
34-
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.Hyperparameters;
35-
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata;
3632
import org.hamcrest.Matchers;
3733
import org.junit.After;
3834

@@ -540,7 +536,6 @@ public void testWithDatastream() throws Exception {
540536
);
541537
}
542538

543-
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/93228")
544539
public void testAliasFields() throws Exception {
545540
// The goal of this test is to assert alias fields are included in the analytics job.
546541
// We have a simple dataset with two integer fields: field_1 and field_2.
@@ -585,10 +580,32 @@ public void testAliasFields() throws Exception {
585580
// Very infrequently this test may fail as the algorithm underestimates the
586581
// required number of trees for this simple problem. This failure is irrelevant
587582
// for non-trivial real-world problems and improving estimation of the number of trees
588-
// would introduce unnecessary overhead. Hence, to reduce the noise from this test we fix the seed.
583+
// would introduce unnecessary overhead. Hence, to reduce the noise from this test we fix the seed
584+
// and use the hyperparameters that are known to work.
589585
long seed = 1000L; // fix seed
590586

591-
Regression regression = new Regression("field_2", BoostedTreeParams.builder().build(), null, 90.0, seed, null, null, null, null);
587+
Regression regression = new Regression(
588+
"field_2",
589+
BoostedTreeParams.builder()
590+
.setDownsampleFactor(0.7520841625652861)
591+
.setAlpha(547.9095715556235)
592+
.setLambda(3.3008189603590044)
593+
.setGamma(1.6082763366825203)
594+
.setSoftTreeDepthLimit(4.733224114945455)
595+
.setSoftTreeDepthTolerance(0.15)
596+
.setEta(0.12371209659057758)
597+
.setEtaGrowthRatePerTree(1.0618560482952888)
598+
.setMaxTrees(30)
599+
.setFeatureBagFraction(0.8)
600+
.build(),
601+
null,
602+
90.0,
603+
seed,
604+
null,
605+
null,
606+
null,
607+
null
608+
);
592609
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder().setId(jobId)
593610
.setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex }, null, null, Collections.emptyMap()))
594611
.setDest(new DataFrameAnalyticsDest(destIndex, null))
@@ -604,19 +621,6 @@ public void testAliasFields() throws Exception {
604621

605622
waitUntilAnalyticsIsStopped(jobId);
606623

607-
// obtain additional information for investigation of #90599
608-
String modelId = getModelId(jobId);
609-
TrainedModelMetadata modelMetadata = getModelMetadata(modelId);
610-
assertThat(modelMetadata.getHyperparameters().size(), greaterThan(0));
611-
StringBuilder hyperparameters = new StringBuilder(); // used to investigate #90599
612-
for (Hyperparameters hyperparameter : modelMetadata.getHyperparameters()) {
613-
hyperparameters.append(hyperparameter.hyperparameterName).append(": ").append(hyperparameter.value).append("\n");
614-
}
615-
TrainedModelDefinition modelDefinition = getModelDefinition(modelId);
616-
Ensemble ensemble = (Ensemble) modelDefinition.getTrainedModel();
617-
int numberTrees = ensemble.getModels().size();
618-
619-
StringBuilder targetsPredictions = new StringBuilder(); // used to investigate #90599
620624
assertResponse(prepareSearch(sourceIndex).setSize(totalDocCount), sourceData -> {
621625
double predictionErrorSum = 0.0;
622626
for (SearchHit hit : sourceData.getHits()) {
@@ -629,19 +633,12 @@ public void testAliasFields() throws Exception {
629633
int featureValue = (int) destDoc.get("field_1");
630634
double predictionValue = (double) resultsObject.get(predictionField);
631635
predictionErrorSum += Math.abs(predictionValue - 2 * featureValue);
632-
633-
// collect the log of targets and predictions for debugging #90599
634-
targetsPredictions.append(2 * featureValue).append(", ").append(predictionValue).append("\n");
635636
}
636637
// We assert on the mean prediction error in order to reduce the probability
637638
// the test fails compared to asserting on the prediction of each individual doc.
638639
double meanPredictionError = predictionErrorSum / sourceData.getHits().getHits().length;
639640
String str = "Failure: failed for seed %d inferenceEntityId %s numberTrees %d\n";
640-
assertThat(
641-
Strings.format(str, seed, modelId, numberTrees) + targetsPredictions + hyperparameters,
642-
meanPredictionError,
643-
lessThanOrEqualTo(3.0)
644-
);
641+
assertThat(meanPredictionError, lessThanOrEqualTo(3.0));
645642
});
646643

647644
assertProgressComplete(jobId);

0 commit comments

Comments
 (0)