|
36 | 36 | import org.elasticsearch.inference.ModelConfigurations; |
37 | 37 | import org.elasticsearch.inference.SimilarityMeasure; |
38 | 38 | import org.elasticsearch.inference.TaskType; |
| 39 | +import org.elasticsearch.rest.RestStatus; |
39 | 40 | import org.elasticsearch.test.ESTestCase; |
40 | 41 | import org.elasticsearch.threadpool.ThreadPool; |
41 | 42 | import org.elasticsearch.xcontent.ParseField; |
|
47 | 48 | import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; |
48 | 49 | import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; |
49 | 50 | import org.elasticsearch.xpack.core.ml.MachineLearningField; |
| 51 | +import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; |
50 | 52 | import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction; |
51 | 53 | import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; |
52 | 54 | import org.elasticsearch.xpack.core.ml.action.InferModelAction; |
53 | 55 | import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; |
54 | 56 | import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction; |
| 57 | +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; |
55 | 58 | import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; |
56 | 59 | import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; |
| 60 | +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; |
57 | 61 | import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats; |
58 | 62 | import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; |
59 | 63 | import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; |
@@ -1870,6 +1874,49 @@ public void testUpdateWithMlEnabled() throws IOException, InterruptedException { |
1870 | 1874 | } |
1871 | 1875 | } |
1872 | 1876 |
|
| 1877 | + public void testStart_OnFailure_WhenTimeoutOccurs() throws IOException { |
| 1878 | + var model = new ElserInternalModel( |
| 1879 | + "inference_id", |
| 1880 | + TaskType.SPARSE_EMBEDDING, |
| 1881 | + "elasticsearch", |
| 1882 | + new ElserInternalServiceSettings( |
| 1883 | + new ElasticsearchInternalServiceSettings(1, 1, "id", new AdaptiveAllocationsSettings(false, 0, 0), null) |
| 1884 | + ), |
| 1885 | + new ElserMlNodeTaskSettings(), |
| 1886 | + null |
| 1887 | + ); |
| 1888 | + |
| 1889 | + var client = mock(Client.class); |
| 1890 | + when(client.threadPool()).thenReturn(threadPool); |
| 1891 | + |
| 1892 | + doAnswer(invocationOnMock -> { |
| 1893 | + ActionListener<GetTrainedModelsAction.Response> listener = invocationOnMock.getArgument(2); |
| 1894 | + var builder = GetTrainedModelsAction.Response.builder(); |
| 1895 | + builder.setModels(List.of(mock(TrainedModelConfig.class))); |
| 1896 | + builder.setTotalCount(1); |
| 1897 | + |
| 1898 | + listener.onResponse(builder.build()); |
| 1899 | + return Void.TYPE; |
| 1900 | + }).when(client).execute(eq(GetTrainedModelsAction.INSTANCE), any(), any()); |
| 1901 | + |
| 1902 | + doAnswer(invocationOnMock -> { |
| 1903 | + ActionListener<CreateTrainedModelAssignmentAction.Response> listener = invocationOnMock.getArgument(2); |
| 1904 | + listener.onFailure(new ElasticsearchStatusException("failed", RestStatus.GATEWAY_TIMEOUT)); |
| 1905 | + return Void.TYPE; |
| 1906 | + }).when(client).execute(eq(StartTrainedModelDeploymentAction.INSTANCE), any(), any()); |
| 1907 | + |
| 1908 | + try (var service = createService(client)) { |
| 1909 | + var actionListener = new PlainActionFuture<Boolean>(); |
| 1910 | + service.start(model, TimeValue.timeValueSeconds(30), actionListener); |
| 1911 | + var exception = expectThrows( |
| 1912 | + ElasticsearchStatusException.class, |
| 1913 | + () -> actionListener.actionGet(TimeValue.timeValueSeconds(30)) |
| 1914 | + ); |
| 1915 | + |
| 1916 | + assertThat(exception.getMessage(), is("failed")); |
| 1917 | + } |
| 1918 | + } |
| 1919 | + |
1873 | 1920 | private ElasticsearchInternalService createService(Client client) { |
1874 | 1921 | var cs = mock(ClusterService.class); |
1875 | 1922 | var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES)); |
|
0 commit comments