Skip to content

Commit 9c6cf90

Browse files
Enable force inference endpoint deleting for invalid models and after stopping model deployment fails (elastic#129090)
* Enable force inference endpoint deleting for invalid models and after stopping model deployment fails
* Update docs/changelog/129090.yaml

---------

Co-authored-by: Elastic Machine <[email protected]>
1 parent 037ddaa commit 9c6cf90

File tree

3 files changed

+255
-3
lines changed

3 files changed

+255
-3
lines changed

docs/changelog/129090.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 129090
2+
summary: Enable force inference endpoint deleting for invalid models and after stopping
3+
model deployment fails
4+
area: Machine Learning
5+
type: enhancement
6+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.elasticsearch.common.Strings;
2424
import org.elasticsearch.common.util.concurrent.EsExecutors;
2525
import org.elasticsearch.inference.InferenceServiceRegistry;
26+
import org.elasticsearch.inference.Model;
2627
import org.elasticsearch.inference.UnparsedModel;
2728
import org.elasticsearch.injection.guice.Inject;
2829
import org.elasticsearch.rest.RestStatus;
@@ -128,10 +129,38 @@ private void doExecuteForked(
128129
}
129130

130131
var service = serviceRegistry.getService(unparsedModel.service());
132+
Model model;
131133
if (service.isPresent()) {
132-
var model = service.get()
133-
.parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
134-
service.get().stop(model, listener);
134+
try {
135+
model = service.get()
136+
.parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
137+
} catch (Exception e) {
138+
if (request.isForceDelete()) {
139+
listener.onResponse(true);
140+
return;
141+
} else {
142+
listener.onFailure(
143+
new ElasticsearchStatusException(
144+
Strings.format(
145+
"Failed to parse model configuration for inference endpoint [%s]",
146+
request.getInferenceEndpointId()
147+
),
148+
RestStatus.INTERNAL_SERVER_ERROR,
149+
e
150+
)
151+
);
152+
return;
153+
}
154+
}
155+
service.get().stop(model, listener.delegateResponse((l, e) -> {
156+
if (request.isForceDelete()) {
157+
l.onResponse(true);
158+
} else {
159+
l.onFailure(e);
160+
}
161+
}));
162+
} else if (request.isForceDelete()) {
163+
listener.onResponse(true);
135164
} else {
136165
listener.onFailure(
137166
new ElasticsearchStatusException(

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointActionTests.java

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
import org.elasticsearch.core.TimeValue;
1818
import org.elasticsearch.inference.InferenceService;
1919
import org.elasticsearch.inference.InferenceServiceRegistry;
20+
import org.elasticsearch.inference.Model;
2021
import org.elasticsearch.inference.TaskType;
2122
import org.elasticsearch.inference.UnparsedModel;
23+
import org.elasticsearch.rest.RestStatus;
2224
import org.elasticsearch.tasks.Task;
2325
import org.elasticsearch.test.ESTestCase;
2426
import org.elasticsearch.threadpool.ThreadPool;
@@ -32,11 +34,17 @@
3234
import java.util.Optional;
3335

3436
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
37+
import static org.hamcrest.Matchers.containsString;
3538
import static org.hamcrest.Matchers.is;
3639
import static org.mockito.ArgumentMatchers.any;
3740
import static org.mockito.ArgumentMatchers.anyString;
41+
import static org.mockito.ArgumentMatchers.eq;
3842
import static org.mockito.Mockito.doAnswer;
43+
import static org.mockito.Mockito.doReturn;
44+
import static org.mockito.Mockito.doThrow;
3945
import static org.mockito.Mockito.mock;
46+
import static org.mockito.Mockito.verify;
47+
import static org.mockito.Mockito.verifyNoMoreInteractions;
4048
import static org.mockito.Mockito.when;
4149

4250
public class TransportDeleteInferenceEndpointActionTests extends ESTestCase {
@@ -130,4 +138,213 @@ public void testDeletesDefaultEndpoint_WhenForceIsTrue() {
130138

131139
assertTrue(response.isAcknowledged());
132140
}
141+
142+
/**
 * Verifies that a delete request fails when the stored endpoint configuration cannot be parsed.
 * The mocked service throws from {@code parsePersistedConfig}, and the test expects an
 * {@link ElasticsearchStatusException} whose message mentions the parse failure.
 */
public void testFailsToDeleteUnparsableEndpoint_WhenForceIsFalse() {
143+
var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
144+
var serviceName = randomAlphanumericOfLength(10);
145+
var taskType = randomFrom(TaskType.values());
146+
var mockService = mock(InferenceService.class);
147+
mockUnparsableModel(inferenceEndpointId, serviceName, taskType, mockService);
148+
when(mockModelRegistry.containsDefaultConfigId(inferenceEndpointId)).thenReturn(false);
149+
150+
var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
151+
action.masterOperation(
152+
mock(Task.class),
153+
// NOTE(review): the third argument is presumably the force flag (false here, true in the *_WhenForceIsTrue tests) - confirm against Request.
new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, false, false),
154+
ClusterState.EMPTY_STATE,
155+
listener
156+
);
157+
158+
var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
159+
assertThat(exception.getMessage(), containsString("Failed to parse model configuration for inference endpoint"));
160+
161+
// Only the lookups below are expected; in particular deleteModel must not have been called.
verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
162+
verify(mockInferenceServiceRegistry).getService(eq(serviceName));
163+
verify(mockModelRegistry).containsDefaultConfigId(eq(inferenceEndpointId));
164+
verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
165+
verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService);
166+
}
167+
168+
/**
 * Verifies that a delete request is acknowledged even though the stored endpoint
 * configuration cannot be parsed: the mocked service throws from {@code parsePersistedConfig},
 * yet {@code deleteModel} is still invoked on the model registry.
 */
public void testDeletesUnparsableEndpoint_WhenForceIsTrue() {
169+
var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
170+
var serviceName = randomAlphanumericOfLength(10);
171+
var taskType = randomFrom(TaskType.values());
172+
var mockService = mock(InferenceService.class);
173+
mockUnparsableModel(inferenceEndpointId, serviceName, taskType, mockService);
174+
// Make the registry report a successful deletion to whatever listener it is handed.
doAnswer(invocationOnMock -> {
175+
ActionListener<Boolean> listener = invocationOnMock.getArgument(1);
176+
listener.onResponse(true);
177+
return Void.TYPE;
178+
}).when(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any());
179+
180+
var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
181+
182+
action.masterOperation(
183+
mock(Task.class),
184+
// NOTE(review): the third argument is presumably the force flag (true here) - confirm against Request.
new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, true, false),
185+
ClusterState.EMPTY_STATE,
186+
listener
187+
);
188+
189+
var response = listener.actionGet(TIMEOUT);
190+
assertTrue(response.isAcknowledged());
191+
192+
// containsDefaultConfigId is not stubbed or verified here, so the final check also asserts it was never called.
verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
193+
verify(mockInferenceServiceRegistry).getService(eq(serviceName));
194+
verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
195+
verify(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any());
196+
verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService);
197+
}
198+
199+
/**
 * Sets up the mocks for an endpoint whose persisted configuration cannot be parsed.
 * <ul>
 *   <li>{@code mockModelRegistry.getModel} answers with an {@link UnparsedModel} built from the
 *       given id, task type and service name, with two empty maps.</li>
 *   <li>{@code mockService.parsePersistedConfig} throws an {@link ElasticsearchStatusException}
 *       with a random message and {@code INTERNAL_SERVER_ERROR} status.</li>
 *   <li>{@code mockInferenceServiceRegistry.getService(serviceName)} returns {@code mockService}.</li>
 * </ul>
 *
 * @param inferenceEndpointId id the registry lookup and the parse call are matched against
 * @param serviceName         service name stored in the returned {@link UnparsedModel}
 * @param taskType            task type stored in the returned {@link UnparsedModel}
 * @param mockService         mock that is registered under {@code serviceName} and made to throw on parse
 */
private void mockUnparsableModel(String inferenceEndpointId, String serviceName, TaskType taskType, InferenceService mockService) {
200+
doAnswer(invocationOnMock -> {
201+
// Argument index 1 is the listener passed to getModel; complete it immediately with the unparsed model.
ActionListener<UnparsedModel> listener = invocationOnMock.getArgument(1);
202+
listener.onResponse(new UnparsedModel(inferenceEndpointId, taskType, serviceName, Map.of(), Map.of()));
203+
return Void.TYPE;
204+
}).when(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
205+
doThrow(new ElasticsearchStatusException(randomAlphanumericOfLength(10), RestStatus.INTERNAL_SERVER_ERROR)).when(mockService)
206+
.parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
207+
when(mockInferenceServiceRegistry.getService(serviceName)).thenReturn(Optional.of(mockService));
208+
}
209+
210+
/**
 * Verifies that a delete request is acknowledged when the service registry has no service for
 * the endpoint's service name: {@code deleteModel} is still invoked on the model registry.
 */
public void testDeletesEndpointWithNoService_WhenForceIsTrue() {
211+
var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
212+
var serviceName = randomAlphanumericOfLength(10);
213+
var taskType = randomFrom(TaskType.values());
214+
mockNoService(inferenceEndpointId, serviceName, taskType);
215+
// Stubbed for any id (anyString); the exact id is checked by the verify(...) call further down.
doAnswer(invocationOnMock -> {
216+
ActionListener<Boolean> listener = invocationOnMock.getArgument(1);
217+
listener.onResponse(true);
218+
return Void.TYPE;
219+
}).when(mockModelRegistry).deleteModel(anyString(), any());
220+
221+
var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
222+
223+
action.masterOperation(
224+
mock(Task.class),
225+
// NOTE(review): the third argument is presumably the force flag (true here) - confirm against Request.
new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, true, false),
226+
ClusterState.EMPTY_STATE,
227+
listener
228+
);
229+
230+
var response = listener.actionGet(TIMEOUT);
231+
assertTrue(response.isAcknowledged());
232+
verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
233+
verify(mockInferenceServiceRegistry).getService(eq(serviceName));
234+
verify(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any());
235+
verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry);
236+
}
237+
238+
/**
 * Verifies that a delete request fails with an {@link ElasticsearchStatusException} mentioning
 * "No service found for this inference endpoint" when the service registry has no service for
 * the endpoint's service name, and that {@code deleteModel} is never invoked.
 */
public void testFailsToDeleteEndpointWithNoService_WhenForceIsFalse() {
239+
var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
240+
var serviceName = randomAlphanumericOfLength(10);
241+
var taskType = randomFrom(TaskType.values());
242+
mockNoService(inferenceEndpointId, serviceName, taskType);
243+
when(mockModelRegistry.containsDefaultConfigId(inferenceEndpointId)).thenReturn(false);
244+
245+
var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
246+
247+
action.masterOperation(
248+
mock(Task.class),
249+
// NOTE(review): the third argument is presumably the force flag (false here) - confirm against Request.
new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, false, false),
250+
ClusterState.EMPTY_STATE,
251+
listener
252+
);
253+
254+
var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
255+
assertThat(exception.getMessage(), containsString("No service found for this inference endpoint"));
256+
verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
257+
verify(mockInferenceServiceRegistry).getService(eq(serviceName));
258+
verify(mockModelRegistry).containsDefaultConfigId(eq(inferenceEndpointId));
259+
verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry);
260+
}
261+
262+
/**
 * Sets up the mocks for an endpoint whose service is not registered:
 * {@code mockModelRegistry.getModel} answers with an {@link UnparsedModel} built from the given
 * values (with two empty maps), and {@code mockInferenceServiceRegistry.getService(serviceName)}
 * returns an empty {@link Optional}.
 *
 * @param inferenceEndpointId id the registry lookup is matched against
 * @param serviceName         service name stored in the model and reported as missing
 * @param taskType            task type stored in the returned {@link UnparsedModel}
 */
private void mockNoService(String inferenceEndpointId, String serviceName, TaskType taskType) {
263+
doAnswer(invocationOnMock -> {
264+
// Argument index 1 is the listener passed to getModel; complete it immediately with the unparsed model.
ActionListener<UnparsedModel> listener = invocationOnMock.getArgument(1);
265+
listener.onResponse(new UnparsedModel(inferenceEndpointId, taskType, serviceName, Map.of(), Map.of()));
266+
return Void.TYPE;
267+
}).when(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
268+
when(mockInferenceServiceRegistry.getService(serviceName)).thenReturn(Optional.empty());
269+
}
270+
271+
/**
 * Verifies that a delete request fails when the service's {@code stop} call reports a failure:
 * the test expects an {@link ElasticsearchStatusException} carrying the
 * "Failed to stop model deployment" message and no call to {@code deleteModel}.
 */
public void testFailsToDeleteEndpointIfModelDeploymentStopFails_WhenForceIsFalse() {
272+
var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
273+
var serviceName = randomAlphanumericOfLength(10);
274+
var taskType = randomFrom(TaskType.values());
275+
var mockService = mock(InferenceService.class);
276+
var mockModel = mock(Model.class);
277+
mockStopDeploymentFails(inferenceEndpointId, serviceName, taskType, mockService, mockModel);
278+
when(mockModelRegistry.containsDefaultConfigId(inferenceEndpointId)).thenReturn(false);
279+
280+
var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
281+
action.masterOperation(
282+
mock(Task.class),
283+
// NOTE(review): the third argument is presumably the force flag (false here) - confirm against Request.
new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, false, false),
284+
ClusterState.EMPTY_STATE,
285+
listener
286+
);
287+
288+
var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
289+
assertThat(exception.getMessage(), containsString("Failed to stop model deployment"));
290+
verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
291+
verify(mockInferenceServiceRegistry).getService(eq(serviceName));
292+
verify(mockModelRegistry).containsDefaultConfigId(eq(inferenceEndpointId));
293+
verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
294+
verify(mockService).stop(eq(mockModel), any());
295+
verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService, mockModel);
296+
}
297+
298+
/**
 * Verifies that a delete request is acknowledged even though the service's {@code stop} call
 * reports a failure: {@code deleteModel} is still invoked on the model registry.
 */
public void testDeletesEndpointIfModelDeploymentStopFails_WhenForceIsTrue() {
299+
var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
300+
var serviceName = randomAlphanumericOfLength(10);
301+
var taskType = randomFrom(TaskType.values());
302+
var mockService = mock(InferenceService.class);
303+
var mockModel = mock(Model.class);
304+
mockStopDeploymentFails(inferenceEndpointId, serviceName, taskType, mockService, mockModel);
305+
// Make the registry report a successful deletion to whatever listener it is handed.
doAnswer(invocationOnMock -> {
306+
ActionListener<Boolean> listener = invocationOnMock.getArgument(1);
307+
listener.onResponse(true);
308+
return Void.TYPE;
309+
}).when(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any());
310+
311+
var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
312+
action.masterOperation(
313+
mock(Task.class),
314+
// NOTE(review): the third argument is presumably the force flag (true here) - confirm against Request.
new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, true, false),
315+
ClusterState.EMPTY_STATE,
316+
listener
317+
);
318+
319+
var response = listener.actionGet(TIMEOUT);
320+
assertTrue(response.isAcknowledged());
321+
verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
322+
verify(mockInferenceServiceRegistry).getService(eq(serviceName));
323+
verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
324+
verify(mockService).stop(eq(mockModel), any());
325+
verify(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any());
326+
verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService, mockModel);
327+
}
328+
329+
/**
 * Sets up the mocks for an endpoint that parses successfully but whose deployment cannot be stopped.
 * <ul>
 *   <li>{@code mockModelRegistry.getModel} answers with an {@link UnparsedModel} built from the
 *       given id, task type and service name, with two empty maps.</li>
 *   <li>{@code mockInferenceServiceRegistry.getService(serviceName)} returns {@code mockService}.</li>
 *   <li>{@code mockService.parsePersistedConfig} returns {@code mockModel}.</li>
 *   <li>{@code mockService.stop(mockModel, ...)} fails its listener with an
 *       {@link ElasticsearchStatusException} ("Failed to stop model deployment",
 *       {@code INTERNAL_SERVER_ERROR}).</li>
 * </ul>
 *
 * @param inferenceEndpointId id the registry lookup and the parse call are matched against
 * @param serviceName         service name under which {@code mockService} is registered
 * @param taskType            task type stored in the returned {@link UnparsedModel}
 * @param mockService         service mock whose parse succeeds and whose stop fails
 * @param mockModel           model returned by the parse stub and expected by the stop stub
 */
private void mockStopDeploymentFails(
330+
String inferenceEndpointId,
331+
String serviceName,
332+
TaskType taskType,
333+
InferenceService mockService,
334+
Model mockModel
335+
) {
336+
doAnswer(invocationOnMock -> {
337+
// Argument index 1 is the listener passed to getModel; complete it immediately with the unparsed model.
ActionListener<UnparsedModel> listener = invocationOnMock.getArgument(1);
338+
listener.onResponse(new UnparsedModel(inferenceEndpointId, taskType, serviceName, Map.of(), Map.of()));
339+
return Void.TYPE;
340+
}).when(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
341+
when(mockInferenceServiceRegistry.getService(serviceName)).thenReturn(Optional.of(mockService));
342+
doReturn(mockModel).when(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
343+
doAnswer(invocationOnMock -> {
344+
// Simulate the stop failure asynchronously reported through the listener rather than thrown.
ActionListener<Boolean> listener = invocationOnMock.getArgument(1);
345+
listener.onFailure(new ElasticsearchStatusException("Failed to stop model deployment", RestStatus.INTERNAL_SERVER_ERROR));
346+
return Void.TYPE;
347+
}).when(mockService).stop(eq(mockModel), any());
348+
}
349+
133350
}

0 commit comments

Comments
 (0)