|
23 | 23 | import org.apache.lucene.search.join.BitSetProducer; |
24 | 24 | import org.apache.lucene.search.join.QueryBitSetProducer; |
25 | 25 | import org.apache.lucene.search.join.ScoreMode; |
26 | | -import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest; |
27 | 26 | import org.elasticsearch.cluster.metadata.IndexMetadata; |
28 | 27 | import org.elasticsearch.common.CheckedBiConsumer; |
29 | 28 | import org.elasticsearch.common.CheckedBiFunction; |
|
73 | 72 |
|
74 | 73 | import java.io.IOException; |
75 | 74 | import java.util.Collection; |
| 75 | +import java.util.HashMap; |
76 | 76 | import java.util.HashSet; |
77 | 77 | import java.util.LinkedHashMap; |
78 | 78 | import java.util.List; |
@@ -879,17 +879,29 @@ private MapperService mapperServiceForFieldWithModelSettings( |
879 | 879 | String searchInferenceId, |
880 | 880 | MinimalServiceSettings modelSettings |
881 | 881 | ) throws IOException { |
882 | | - String mappingParams = "type=semantic_text,inference_id=" + inferenceId; |
| 882 | + return mapperServiceForFieldWithModelSettingsAndIndexOptions(fieldName, inferenceId, searchInferenceId, modelSettings, null); |
| 883 | + } |
| 884 | + |
| 885 | + private MapperService mapperServiceForFieldWithModelSettingsAndIndexOptions( |
| 886 | + String fieldName, |
| 887 | + String inferenceId, |
| 888 | + String searchInferenceId, |
| 889 | + MinimalServiceSettings modelSettings, |
| 890 | + DenseVectorFieldMapper.IndexOptions indexOptions |
| 891 | + ) throws IOException { |
| 892 | + XContentBuilder mappingBuilder = JsonXContent.contentBuilder().startObject(); |
| 893 | + mappingBuilder.startObject("properties").startObject(fieldName).field("type", "semantic_text").field("inference_id", inferenceId); |
883 | 894 | if (searchInferenceId != null) { |
884 | | - mappingParams += ",search_inference_id=" + searchInferenceId; |
| 895 | + mappingBuilder.field("search_inference_id", searchInferenceId); |
885 | 896 | } |
| 897 | + if (indexOptions != null) { |
| 898 | + mappingBuilder.field("index_options", indexOptions); |
| 899 | + } |
| 900 | + |
| 901 | + mappingBuilder.endObject().endObject().endObject(); |
886 | 902 |
|
887 | 903 | MapperService mapperService = createMapperService(mapping(b -> {}), useLegacyFormat); |
888 | | - mapperService.merge( |
889 | | - "_doc", |
890 | | - new CompressedXContent(Strings.toString(PutMappingRequest.simpleMapping(fieldName, mappingParams))), |
891 | | - MapperService.MergeReason.MAPPING_UPDATE |
892 | | - ); |
| 904 | + mapperService.merge("_doc", new CompressedXContent(Strings.toString(mappingBuilder)), MapperService.MergeReason.MAPPING_UPDATE); |
893 | 905 |
|
894 | 906 | SemanticTextField semanticTextField = new SemanticTextField( |
895 | 907 | useLegacyFormat, |
@@ -951,6 +963,105 @@ public void testExistsQueryDenseVector() throws IOException { |
951 | 963 | assertThat(existsQuery, instanceOf(ESToParentBlockJoinQuery.class)); |
952 | 964 | } |
953 | 965 |
|
| 966 | + public void testDenseVectorIndexOptions() throws IOException { |
| 967 | + final String fieldName = "field"; |
| 968 | + final String inferenceId = "test_service"; |
| 969 | + |
| 970 | + List<DenseVectorFieldMapper.IndexOptions> indexOptionsList = List.of( |
| 971 | + DenseVectorFieldMapper.parseIndexOptions(fieldName, new HashMap<>(Map.of("type", "hnsw"))), |
| 972 | + DenseVectorFieldMapper.parseIndexOptions(fieldName, new HashMap<>(Map.of("type", "int8_hnsw"))), |
| 973 | + DenseVectorFieldMapper.parseIndexOptions(fieldName, new HashMap<>(Map.of("type", "int4_hnsw"))), |
| 974 | + DenseVectorFieldMapper.parseIndexOptions(fieldName, new HashMap<>(Map.of("type", "bbq_hnsw"))), |
| 975 | + DenseVectorFieldMapper.parseIndexOptions(fieldName, new HashMap<>(Map.of("type", "flat"))), |
| 976 | + DenseVectorFieldMapper.parseIndexOptions(fieldName, new HashMap<>(Map.of("type", "int8_flat"))), |
| 977 | + DenseVectorFieldMapper.parseIndexOptions(fieldName, new HashMap<>(Map.of("type", "int4_flat"))), |
| 978 | + DenseVectorFieldMapper.parseIndexOptions(fieldName, new HashMap<>(Map.of("type", "bbq_flat"))), |
| 979 | + DenseVectorFieldMapper.parseIndexOptions(fieldName, new HashMap<>(Map.of("type", "hnsw", "m", 32, "ef_construction", 200))) |
| 980 | + ); |
| 981 | + |
| 982 | + for (DenseVectorFieldMapper.IndexOptions indexOptions : indexOptionsList) { |
| 983 | + BiConsumer<MapperService, DenseVectorFieldMapper.IndexOptions> assertMapperService = (m, e) -> { |
| 984 | + Mapper mapper = m.mappingLookup().getMapper(fieldName); |
| 985 | + assertThat(mapper, instanceOf(SemanticTextFieldMapper.class)); |
| 986 | + SemanticTextFieldMapper semanticTextFieldMapper = (SemanticTextFieldMapper) mapper; |
| 987 | + |
| 988 | + FieldMapper fieldMapper = semanticTextFieldMapper.fieldType().getEmbeddingsField(); |
| 989 | + assertThat(fieldMapper, instanceOf(DenseVectorFieldMapper.class)); |
| 990 | + DenseVectorFieldMapper denseVectorFieldMapper = (DenseVectorFieldMapper) fieldMapper; |
| 991 | + |
| 992 | + assertThat(denseVectorFieldMapper.indexOptions(), equalTo(e)); |
| 993 | + }; |
| 994 | + |
| 995 | + MapperService floatMapperService = mapperServiceForFieldWithModelSettingsAndIndexOptions( |
| 996 | + fieldName, |
| 997 | + inferenceId, |
| 998 | + inferenceId, |
| 999 | + new MinimalServiceSettings( |
| 1000 | + TaskType.TEXT_EMBEDDING, |
| 1001 | + 1024, |
| 1002 | + SimilarityMeasure.COSINE, |
| 1003 | + DenseVectorFieldMapper.ElementType.FLOAT |
| 1004 | + ), |
| 1005 | + indexOptions |
| 1006 | + ); |
| 1007 | + assertMapperService.accept(floatMapperService, indexOptions); |
| 1008 | + } |
| 1009 | + } |
| 1010 | + |
| 1011 | + public void testDenseVectorIndexOptionsVaild() { |
| 1012 | + final String fieldName = "field"; |
| 1013 | + final String inferenceId = "test_service"; |
| 1014 | + |
| 1015 | + { |
| 1016 | + DenseVectorFieldMapper.IndexOptions indexOptions = DenseVectorFieldMapper.parseIndexOptions( |
| 1017 | + fieldName, |
| 1018 | + new HashMap<>(Map.of("type", "int8_hnsw")) |
| 1019 | + ); |
| 1020 | + MinimalServiceSettings invalidSettings = new MinimalServiceSettings( |
| 1021 | + TaskType.TEXT_EMBEDDING, |
| 1022 | + 1024, |
| 1023 | + SimilarityMeasure.L2_NORM, |
| 1024 | + DenseVectorFieldMapper.ElementType.BYTE |
| 1025 | + ); |
| 1026 | + |
| 1027 | + Exception e = expectThrows( |
| 1028 | + DocumentParsingException.class, |
| 1029 | + () -> mapperServiceForFieldWithModelSettingsAndIndexOptions( |
| 1030 | + fieldName, |
| 1031 | + inferenceId, |
| 1032 | + inferenceId, |
| 1033 | + invalidSettings, |
| 1034 | + indexOptions |
| 1035 | + ) |
| 1036 | + ); |
| 1037 | + assertThat(e.getCause().getMessage(), containsString("cannot be [byte] when using index type [int8_hnsw]")); |
| 1038 | + } |
| 1039 | + |
| 1040 | + { |
| 1041 | + DenseVectorFieldMapper.IndexOptions indexOptions = DenseVectorFieldMapper.parseIndexOptions( |
| 1042 | + fieldName, |
| 1043 | + new HashMap<>(Map.of("type", "bbq_hnsw")) |
| 1044 | + ); |
| 1045 | + MinimalServiceSettings invalidSettings = new MinimalServiceSettings( |
| 1046 | + TaskType.TEXT_EMBEDDING, |
| 1047 | + 10, |
| 1048 | + SimilarityMeasure.COSINE, |
| 1049 | + DenseVectorFieldMapper.ElementType.BYTE |
| 1050 | + ); |
| 1051 | + Exception e = expectThrows( |
| 1052 | + DocumentParsingException.class, |
| 1053 | + () -> mapperServiceForFieldWithModelSettingsAndIndexOptions( |
| 1054 | + fieldName, |
| 1055 | + inferenceId, |
| 1056 | + inferenceId, |
| 1057 | + invalidSettings, |
| 1058 | + indexOptions |
| 1059 | + ) |
| 1060 | + ); |
| 1061 | + assertThat(e.getCause().getMessage(), containsString("bbq_hnsw does not support dimensions fewer than 64")); |
| 1062 | + } |
| 1063 | + } |
| 1064 | + |
954 | 1065 | @Override |
955 | 1066 | protected void assertExistsQuery(MappedFieldType fieldType, Query query, LuceneDocument fields) { |
956 | 1067 | // Until a doc is indexed, the query is rewritten as match no docs |
|
0 commit comments