Skip to content

Commit 037ddaa

Browse files
authored
[ML] Remove SageMaker Elastic updates (elastic#131301)
Rather than silently drop the payload, throw a validation error when Users try to send task settings in the update payload for SageMaker inference with the Elastic API.
1 parent 0cf275e commit 037ddaa

File tree

7 files changed

+44
-16
lines changed

7 files changed

+44
-16
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ public SageMakerModel override(Map<String, Object> taskSettingsOverride) {
116116
getConfigurations(),
117117
getSecrets(),
118118
serviceSettings,
119-
taskSettings.updatedTaskSettings(taskSettingsOverride),
119+
taskSettings.override(taskSettingsOverride),
120120
awsSecretSettings
121121
);
122122
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerTaskSettings.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,21 @@ public boolean isEmpty() {
7171
@Override
7272
public SageMakerTaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
7373
var validationException = new ValidationException();
74-
7574
var updateTaskSettings = fromMap(newSettings, apiTaskSettings.updatedTaskSettings(newSettings), validationException);
75+
validationException.throwIfValidationErrorsExist();
76+
77+
return override(updateTaskSettings);
78+
}
7679

80+
public SageMakerTaskSettings override(Map<String, Object> newSettings) {
81+
var validationException = new ValidationException();
82+
var updateTaskSettings = fromMap(newSettings, apiTaskSettings.override(newSettings), validationException);
7783
validationException.throwIfValidationErrorsExist();
7884

85+
return override(updateTaskSettings);
86+
}
87+
88+
private SageMakerTaskSettings override(SageMakerTaskSettings updateTaskSettings) {
7989
var updatedExtraTaskSettings = updateTaskSettings.apiTaskSettings().equals(SageMakerStoredTaskSchema.NO_OP)
8090
? apiTaskSettings
8191
: updateTaskSettings.apiTaskSettings();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredTaskSchema.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,8 @@ default boolean isFragment() {
6868

6969
@Override
7070
SageMakerStoredTaskSchema updatedTaskSettings(Map<String, Object> newSettings);
71+
72+
default SageMakerStoredTaskSchema override(Map<String, Object> newSettings) {
73+
return updatedTaskSettings(newSettings);
74+
}
7175
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticPayload.java

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,6 @@ default SdkBytes requestBytes(SageMakerModel model, SageMakerInferenceRequest re
8888

8989
@Override
9090
default SageMakerElasticTaskSettings apiTaskSettings(Map<String, Object> taskSettings, ValidationException validationException) {
91-
if (taskSettings != null && (taskSettings.isEmpty() == false)) {
92-
validationException.addValidationError(
93-
InferenceAction.Request.TASK_SETTINGS.getPreferredName()
94-
+ " is only supported during the inference request and cannot be stored in the inference endpoint."
95-
);
96-
}
9791
return SageMakerElasticTaskSettings.empty();
9892
}
9993

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/SageMakerElasticTaskSettings.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99

1010
import org.elasticsearch.TransportVersion;
1111
import org.elasticsearch.TransportVersions;
12+
import org.elasticsearch.common.ValidationException;
1213
import org.elasticsearch.common.io.stream.StreamInput;
1314
import org.elasticsearch.common.io.stream.StreamOutput;
1415
import org.elasticsearch.core.Nullable;
1516
import org.elasticsearch.xcontent.XContentBuilder;
17+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
1618
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema;
1719

1820
import java.io.IOException;
@@ -40,6 +42,16 @@ public boolean isEmpty() {
4042

4143
@Override
4244
public SageMakerStoredTaskSchema updatedTaskSettings(Map<String, Object> newSettings) {
45+
var validationException = new ValidationException();
46+
validationException.addValidationError(
47+
InferenceAction.Request.TASK_SETTINGS.getPreferredName()
48+
+ " is only supported during the inference request and cannot be stored in the inference endpoint."
49+
);
50+
throw validationException;
51+
}
52+
53+
@Override
54+
public SageMakerStoredTaskSchema override(Map<String, Object> newSettings) {
4355
return new SageMakerElasticTaskSettings(newSettings);
4456
}
4557

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemaPayloadTestCase.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ public final void testWithUnknownApiTaskSettings() {
119119
}
120120
}
121121

122-
public final void testUpdate() throws IOException {
122+
public void testUpdate() throws IOException {
123123
var taskSettings = randomApiTaskSettings();
124124
if (taskSettings != SageMakerStoredTaskSchema.NO_OP) {
125125
var otherTaskSettings = randomValueOtherThan(taskSettings, this::randomApiTaskSettings);

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticPayloadTestCase.java

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
import java.util.List;
1919
import java.util.Map;
2020

21-
import static org.hamcrest.Matchers.equalTo;
22-
import static org.hamcrest.Matchers.is;
21+
import static org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase.toMap;
22+
import static org.hamcrest.Matchers.containsString;
2323
import static org.mockito.Mockito.mock;
2424
import static org.mockito.Mockito.when;
2525

@@ -50,6 +50,7 @@ protected SageMakerModel mockModel(SageMakerElasticTaskSettings taskSettings) {
5050
return model;
5151
}
5252

53+
@Override
5354
public void testApiTaskSettings() {
5455
{
5556
var validationException = new ValidationException();
@@ -67,14 +68,21 @@ public void testApiTaskSettings() {
6768
var validationException = new ValidationException();
6869
var actualApiTaskSettings = payload.apiTaskSettings(Map.of("hello", "world"), validationException);
6970
assertTrue(actualApiTaskSettings.isEmpty());
70-
assertFalse(validationException.validationErrors().isEmpty());
71-
assertThat(
72-
validationException.validationErrors().get(0),
73-
is(equalTo("task_settings is only supported during the inference request and cannot be stored in the inference endpoint."))
74-
);
71+
assertTrue(validationException.validationErrors().isEmpty());
7572
}
7673
}
7774

75+
@Override
76+
public void testUpdate() {
77+
var taskSettings = randomApiTaskSettings();
78+
var otherTaskSettings = randomValueOtherThan(taskSettings, this::randomApiTaskSettings);
79+
var e = assertThrows(ValidationException.class, () -> taskSettings.updatedTaskSettings(toMap(otherTaskSettings)));
80+
assertThat(
81+
e.getMessage(),
82+
containsString("task_settings is only supported during the inference request and cannot be stored in the inference endpoint")
83+
);
84+
}
85+
7886
public void testRequestWithRequiredFields() throws Exception {
7987
var request = new SageMakerInferenceRequest(null, null, null, List.of("hello"), false, InputType.UNSPECIFIED);
8088
var sdkByes = payload.requestBytes(mockModel(), request);

0 commit comments

Comments
 (0)