Skip to content

Commit 2aaa57f

Browse files
bruno-oliveira authored and markpollack committed
Ensure matching enum id to value for multilingual cohere embedding model
1 parent 1453198 commit 2aaa57f

File tree

5 files changed: 72 additions and 62 deletions

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java

Lines changed: 68 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616

1717
package org.springframework.ai.bedrock.cohere.api;
1818

19-
// @formatter:off
20-
2119
import java.time.Duration;
2220
import java.util.List;
2321

@@ -33,47 +31,49 @@
3331
import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi.CohereEmbeddingResponse;
3432

3533
/**
36-
* Cohere Embedding API.
37-
* <a href="https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere.html#model-parameters-embed">AWS Bedrock Cohere Embedding API</a>
38-
* Based on the <a href="https://docs.cohere.com/reference/embed">Cohere Embedding API</a>
34+
* Cohere Embedding API. <a href=
35+
* "https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere.html#model-parameters-embed">AWS
36+
* Bedrock Cohere Embedding API</a> Based on the
37+
* <a href="https://docs.cohere.com/reference/embed">Cohere Embedding API</a>
3938
*
4039
* @author Christian Tzolov
4140
* @author Wei Jiang
4241
* @since 0.8.0
4342
*/
44-
public class CohereEmbeddingBedrockApi extends
45-
AbstractBedrockApi<CohereEmbeddingRequest, CohereEmbeddingResponse, CohereEmbeddingResponse> {
43+
public class CohereEmbeddingBedrockApi
44+
extends AbstractBedrockApi<CohereEmbeddingRequest, CohereEmbeddingResponse, CohereEmbeddingResponse> {
4645

4746
/**
48-
* Create a new CohereEmbeddingBedrockApi instance using the default credentials provider chain, the default object
49-
* mapper, default temperature and topP values.
50-
*
51-
* @param modelId The model id to use. See the {@link CohereEmbeddingModel} for the supported models.
47+
* Create a new CohereEmbeddingBedrockApi instance using the default credentials
48+
* provider chain, the default object mapper, default temperature and topP values.
49+
* @param modelId The model id to use. See the {@link CohereEmbeddingModel} for the
50+
* supported models.
5251
* @param region The AWS region to use.
5352
*/
5453
public CohereEmbeddingBedrockApi(String modelId, String region) {
5554
super(modelId, region);
5655
}
5756

5857
/**
59-
* Create a new CohereEmbeddingBedrockApi instance using the provided credentials provider, region and object
60-
* mapper.
61-
*
62-
* @param modelId The model id to use. See the {@link CohereEmbeddingModel} for the supported models.
58+
* Create a new CohereEmbeddingBedrockApi instance using the provided credentials
59+
* provider, region and object mapper.
60+
* @param modelId The model id to use. See the {@link CohereEmbeddingModel} for the
61+
* supported models.
6362
* @param credentialsProvider The credentials provider to connect to AWS.
6463
* @param region The AWS region to use.
65-
* @param objectMapper The object mapper to use for JSON serialization and deserialization.
64+
* @param objectMapper The object mapper to use for JSON serialization and
65+
* deserialization.
6666
*/
6767
public CohereEmbeddingBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, String region,
6868
ObjectMapper objectMapper) {
6969
super(modelId, credentialsProvider, region, objectMapper);
7070
}
7171

7272
/**
73-
* Create a new CohereEmbeddingBedrockApi instance using the default credentials provider chain, the default object
74-
* mapper, default temperature and topP values.
75-
*
76-
* @param modelId The model id to use. See the {@link CohereEmbeddingModel} for the supported models.
73+
* Create a new CohereEmbeddingBedrockApi instance using the default credentials
74+
* provider chain, the default object mapper, default temperature and topP values.
75+
* @param modelId The model id to use. See the {@link CohereEmbeddingModel} for the
76+
* supported models.
7777
* @param region The AWS region to use.
7878
* @param timeout The timeout to use.
7979
*/
@@ -82,13 +82,14 @@ public CohereEmbeddingBedrockApi(String modelId, String region, Duration timeout
8282
}
8383

8484
/**
85-
* Create a new CohereEmbeddingBedrockApi instance using the provided credentials provider, region and object
86-
* mapper.
87-
*
88-
* @param modelId The model id to use. See the {@link CohereEmbeddingModel} for the supported models.
85+
* Create a new CohereEmbeddingBedrockApi instance using the provided credentials
86+
* provider, region and object mapper.
87+
* @param modelId The model id to use. See the {@link CohereEmbeddingModel} for the
88+
* supported models.
8989
* @param credentialsProvider The credentials provider to connect to AWS.
9090
* @param region The AWS region to use.
91-
* @param objectMapper The object mapper to use for JSON serialization and deserialization.
91+
* @param objectMapper The object mapper to use for JSON serialization and
92+
* deserialization.
9293
* @param timeout The timeout to use.
9394
*/
9495
public CohereEmbeddingBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, String region,
@@ -97,13 +98,14 @@ public CohereEmbeddingBedrockApi(String modelId, AwsCredentialsProvider credenti
9798
}
9899

99100
/**
100-
* Create a new CohereEmbeddingBedrockApi instance using the provided credentials provider, region and object
101-
* mapper.
102-
*
103-
* @param modelId The model id to use. See the {@link CohereEmbeddingModel} for the supported models.
101+
* Create a new CohereEmbeddingBedrockApi instance using the provided credentials
102+
* provider, region and object mapper.
103+
* @param modelId The model id to use. See the {@link CohereEmbeddingModel} for the
104+
* supported models.
104105
* @param credentialsProvider The credentials provider to connect to AWS.
105106
* @param region The AWS region to use.
106-
* @param objectMapper The object mapper to use for JSON serialization and deserialization.
107+
* @param objectMapper The object mapper to use for JSON serialization and
108+
* deserialization.
107109
* @param timeout The timeout to use.
108110
*/
109111
public CohereEmbeddingBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region,
@@ -117,13 +119,15 @@ public CohereEmbeddingResponse embedding(CohereEmbeddingRequest request) {
117119
}
118120

119121
/**
120-
* Cohere Embedding model ids. https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
122+
* Cohere Embedding model ids.
123+
* https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
121124
*/
122125
public enum CohereEmbeddingModel {
126+
123127
/**
124128
* cohere.embed-multilingual-v3
125129
*/
126-
COHERE_EMBED_MULTILINGUAL_V1("cohere.embed-multilingual-v3"),
130+
COHERE_EMBED_MULTILINGUAL_V3("cohere.embed-multilingual-v3"),
127131
/**
128132
* cohere.embed-english-v3
129133
*/
@@ -147,29 +151,29 @@ public String id() {
147151
/**
148152
* The Cohere Embed model request.
149153
*
150-
* @param texts An array of strings for the model to embed. For optimal performance, we recommend reducing the
151-
* length of each text to less than 512 tokens. 1 token is about 4 characters.
152-
* @param inputType Prepends special tokens to differentiate each type from one another. You should not mix
153-
* different types together, except when mixing types for search and retrieval. In this case, embed your corpus
154-
* with the search_document type and embedded queries with type search_query type.
155-
* @param truncate Specifies how the API handles inputs longer than the maximum token length. If you specify LEFT or
156-
* RIGHT, the model discards the input until the remaining input is exactly the maximum input token length for the
157-
* model.
154+
* @param texts An array of strings for the model to embed. For optimal performance,
155+
* we recommend reducing the length of each text to less than 512 tokens. 1 token is
156+
* about 4 characters.
157+
* @param inputType Prepends special tokens to differentiate each type from one
158+
* another. You should not mix different types together, except when mixing types for
159+
* search and retrieval. In this case, embed your corpus with the search_document type
160+
* and embed queries with the search_query type.
161+
* @param truncate Specifies how the API handles inputs longer than the maximum token
162+
* length. If you specify START or END, the model discards the input until the
163+
* remaining input is exactly the maximum input token length for the model.
158164
*/
159165
@JsonInclude(Include.NON_NULL)
160-
public record CohereEmbeddingRequest(
161-
@JsonProperty("texts") List<String> texts,
162-
@JsonProperty("input_type") InputType inputType,
163-
@JsonProperty("truncate") Truncate truncate) {
166+
public record CohereEmbeddingRequest(@JsonProperty("texts") List<String> texts,
167+
@JsonProperty("input_type") InputType inputType, @JsonProperty("truncate") Truncate truncate) {
164168

165169
/**
166170
* Cohere Embedding API input types.
167171
*/
168172
public enum InputType {
169173

170174
/**
171-
* In search use-cases, use search_document when you encode documents for embeddings that you store in a
172-
* vector database.
175+
* In search use-cases, use search_document when you encode documents for
176+
* embeddings that you store in a vector database.
173177
*/
174178
@JsonProperty("search_document")
175179
SEARCH_DOCUMENT,
@@ -188,12 +192,17 @@ public enum InputType {
188192
*/
189193
@JsonProperty("clustering")
190194
CLUSTERING
195+
191196
}
192197

193198
/**
194-
* Specifies how the API handles inputs longer than the maximum token length. Passing START will discard the start of the input. END will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
199+
* Specifies how the API handles inputs longer than the maximum token length.
200+
* Passing START will discard the start of the input. END will discard the end of
201+
* the input. In both cases, input is discarded until the remaining input is
202+
* exactly the maximum input token length for the model.
195203
*/
196204
public enum Truncate {
205+
197206
/**
198207
* Returns an error when the input exceeds the maximum input token length.
199208
*/
@@ -206,29 +215,30 @@ public enum Truncate {
206215
* (default) Discards the end of the input.
207216
*/
208217
END
218+
209219
}
210220
}
211221

212222
/**
213223
* Cohere Embedding response.
214224
*
215225
* @param id An identifier for the response.
216-
* @param embeddings An array of embeddings, where each embedding is an array of floats with 1024 elements. The
217-
* length of the embeddings array will be the same as the length of the original texts array.
218-
* @param texts An array containing the text entries for which embeddings were returned.
226+
* @param embeddings An array of embeddings, where each embedding is an array of
227+
* floats with 1024 elements. The length of the embeddings array will be the same as
228+
* the length of the original texts array.
229+
* @param texts An array containing the text entries for which embeddings were
230+
* returned.
219231
* @param responseType The type of the response. The value is always embeddings.
220-
* @param amazonBedrockInvocationMetrics Bedrock invocation metrics. Currently bedrock doesn't return
221-
* invocationMetrics for the cohere embedding model.
232+
* @param amazonBedrockInvocationMetrics Bedrock invocation metrics. Currently bedrock
233+
* doesn't return invocationMetrics for the cohere embedding model.
222234
*/
223235
@JsonInclude(Include.NON_NULL)
224-
public record CohereEmbeddingResponse(
225-
@JsonProperty("id") String id,
226-
@JsonProperty("embeddings") List<float[]> embeddings,
227-
@JsonProperty("texts") List<String> texts,
236+
public record CohereEmbeddingResponse(@JsonProperty("id") String id,
237+
@JsonProperty("embeddings") List<float[]> embeddings, @JsonProperty("texts") List<String> texts,
228238
@JsonProperty("response_type") String responseType,
229-
// For future use: Currently bedrock doesn't return invocationMetrics for the cohere embedding model.
239+
// For future use: Currently bedrock doesn't return invocationMetrics for the
240+
// cohere embedding model.
230241
@JsonProperty("amazon-bedrock-invocationMetrics") AmazonBedrockInvocationMetrics amazonBedrockInvocationMetrics) {
231242
}
232243

233244
}
234-
// @formatter:on

models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModelIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ public static class TestConfiguration {
169169

170170
@Bean
171171
public CohereEmbeddingBedrockApi cohereEmbeddingApi() {
172-
return new CohereEmbeddingBedrockApi(CohereEmbeddingModel.COHERE_EMBED_MULTILINGUAL_V1.id(),
172+
return new CohereEmbeddingBedrockApi(CohereEmbeddingModel.COHERE_EMBED_MULTILINGUAL_V3.id(),
173173
EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper(),
174174
Duration.ofMinutes(2));
175175
}

models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApiIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
public class CohereEmbeddingBedrockApiIT {
4040

4141
CohereEmbeddingBedrockApi api = new CohereEmbeddingBedrockApi(
42-
CohereEmbeddingModel.COHERE_EMBED_MULTILINGUAL_V1.id(), EnvironmentVariableCredentialsProvider.create(),
42+
CohereEmbeddingModel.COHERE_EMBED_MULTILINGUAL_V3.id(), EnvironmentVariableCredentialsProvider.create(),
4343
Region.US_EAST_1.id(), new ObjectMapper(), Duration.ofMinutes(2));
4444

4545
@Test

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingProperties.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public class BedrockCohereEmbeddingProperties {
4343
* Bedrock Cohere Embedding model name. Defaults to
4444
* 'cohere.embed-multilingual-v3'.
4545
*/
46-
private String model = CohereEmbeddingModel.COHERE_EMBED_MULTILINGUAL_V1.id();
46+
private String model = CohereEmbeddingModel.COHERE_EMBED_MULTILINGUAL_V3.id();
4747

4848
@NestedConfigurationProperty
4949
private BedrockCohereEmbeddingOptions options = BedrockCohereEmbeddingOptions.builder()

spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfigurationIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ public class BedrockCohereEmbeddingAutoConfigurationIT {
4444

4545
private final ApplicationContextRunner contextRunner = BedrockTestUtils.getContextRunner()
4646
.withPropertyValues("spring.ai.bedrock.cohere.embedding.enabled=true",
47-
"spring.ai.bedrock.cohere.embedding.model=" + CohereEmbeddingModel.COHERE_EMBED_MULTILINGUAL_V1.id(),
47+
"spring.ai.bedrock.cohere.embedding.model=" + CohereEmbeddingModel.COHERE_EMBED_MULTILINGUAL_V3.id(),
4848
"spring.ai.bedrock.cohere.embedding.options.inputType=SEARCH_DOCUMENT",
4949
"spring.ai.bedrock.cohere.embedding.options.truncate=NONE")
5050
.withConfiguration(AutoConfigurations.of(BedrockCohereEmbeddingAutoConfiguration.class));

0 commit comments

Comments
 (0)