Skip to content

Commit f467086

Browse files
committed
MetadataModel moved to AbstractEmbeddingModel.
Embedding models, that didn't have support for a MetadataModel got an additional constructor with MetadataModel as the second argument. I did not include it in all the other constructors since this would break code. AzureOpenAiEmbeddingModel had special handling for empty results, that got lost in the change. Not sure if that is ok. TransformersEmbeddingModel has as default, which differs from the others. Should that be changed to , which seems the most reasonable default to me and is used in the other models?
1 parent 923e09a commit f467086

File tree

16 files changed

+170
-119
lines changed

16 files changed

+170
-119
lines changed

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,6 @@ public class AzureOpenAiEmbeddingModel extends AbstractEmbeddingModel {
6363

6464
private final AzureOpenAiEmbeddingOptions defaultOptions;
6565

66-
private final MetadataMode metadataMode;
67-
6866
/**
6967
* Observation registry used for instrumentation.
7068
*/
@@ -92,30 +90,16 @@ public AzureOpenAiEmbeddingModel(OpenAIClient azureOpenAiClient, MetadataMode me
9290
public AzureOpenAiEmbeddingModel(OpenAIClient azureOpenAiClient, MetadataMode metadataMode,
9391
AzureOpenAiEmbeddingOptions options, ObservationRegistry observationRegistry) {
9492

93+
super(metadataMode);
94+
9595
Assert.notNull(azureOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null");
96-
Assert.notNull(metadataMode, "Metadata mode must not be null");
9796
Assert.notNull(options, "Options must not be null");
9897
Assert.notNull(observationRegistry, "Observation registry must not be null");
9998
this.azureOpenAiClient = azureOpenAiClient;
100-
this.metadataMode = metadataMode;
10199
this.defaultOptions = options;
102100
this.observationRegistry = observationRegistry;
103101
}
104102

105-
@Override
106-
public float[] embed(Document document) {
107-
logger.debug("Retrieving embeddings");
108-
109-
EmbeddingResponse response = this
110-
.call(new EmbeddingRequest(List.of(document.getFormattedContent(this.metadataMode)), null));
111-
logger.debug("Embeddings retrieved");
112-
113-
if (CollectionUtils.isEmpty(response.getResults())) {
114-
return new float[0];
115-
}
116-
return response.getResults().get(0).getOutput();
117-
}
118-
119103
@Override
120104
public EmbeddingResponse call(EmbeddingRequest embeddingRequest) {
121105
logger.debug("Retrieving embeddings");

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModel.java

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi.CohereEmbeddingRequest;
2525
import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi.CohereEmbeddingResponse;
2626
import org.springframework.ai.document.Document;
27+
import org.springframework.ai.document.MetadataMode;
2728
import org.springframework.ai.embedding.AbstractEmbeddingModel;
2829
import org.springframework.ai.embedding.Embedding;
2930
import org.springframework.ai.embedding.EmbeddingOptions;
@@ -65,17 +66,21 @@ public BedrockCohereEmbeddingModel(CohereEmbeddingBedrockApi cohereEmbeddingBedr
6566

6667
public BedrockCohereEmbeddingModel(CohereEmbeddingBedrockApi cohereEmbeddingBedrockApi,
6768
BedrockCohereEmbeddingOptions options) {
69+
70+
this(cohereEmbeddingBedrockApi, MetadataMode.EMBED, options);
71+
}
72+
73+
public BedrockCohereEmbeddingModel(CohereEmbeddingBedrockApi cohereEmbeddingBedrockApi, MetadataMode metadataMode,
74+
BedrockCohereEmbeddingOptions options) {
75+
76+
super(metadataMode);
77+
6878
Assert.notNull(cohereEmbeddingBedrockApi, "CohereEmbeddingBedrockApi must not be null");
6979
Assert.notNull(options, "BedrockCohereEmbeddingOptions must not be null");
7080
this.embeddingApi = cohereEmbeddingBedrockApi;
7181
this.defaultOptions = options;
7282
}
7383

74-
@Override
75-
public float[] embed(Document document) {
76-
return embed(document.getText());
77-
}
78-
7984
@Override
8085
public EmbeddingResponse call(EmbeddingRequest request) {
8186

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModel.java

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingRequest;
2828
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingResponse;
2929
import org.springframework.ai.document.Document;
30+
import org.springframework.ai.document.MetadataMode;
3031
import org.springframework.ai.embedding.AbstractEmbeddingModel;
3132
import org.springframework.ai.embedding.Embedding;
3233
import org.springframework.ai.embedding.EmbeddingOptions;
@@ -57,6 +58,13 @@ public class BedrockTitanEmbeddingModel extends AbstractEmbeddingModel {
5758
private InputType inputType = InputType.TEXT;
5859

5960
public BedrockTitanEmbeddingModel(TitanEmbeddingBedrockApi titanEmbeddingBedrockApi) {
61+
this(titanEmbeddingBedrockApi, MetadataMode.EMBED);
62+
}
63+
64+
public BedrockTitanEmbeddingModel(TitanEmbeddingBedrockApi titanEmbeddingBedrockApi, MetadataMode metadataMode) {
65+
66+
super(metadataMode);
67+
6068
this.embeddingApi = titanEmbeddingBedrockApi;
6169
}
6270

@@ -69,11 +77,6 @@ public BedrockTitanEmbeddingModel withInputType(InputType inputType) {
6977
return this;
7078
}
7179

72-
@Override
73-
public float[] embed(Document document) {
74-
return embed(document.getText());
75-
}
76-
7780
@Override
7881
public EmbeddingResponse call(EmbeddingRequest request) {
7982
Assert.notEmpty(request.getInstructions(), "At least one text is required!");

models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingModel.java

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,6 @@ public class MiniMaxEmbeddingModel extends AbstractEmbeddingModel {
6363

6464
private final MiniMaxApi miniMaxApi;
6565

66-
private final MetadataMode metadataMode;
67-
6866
/**
6967
* Observation registry used for instrumentation.
7068
*/
@@ -128,25 +126,20 @@ public MiniMaxEmbeddingModel(MiniMaxApi miniMaxApi, MetadataMode metadataMode,
128126
*/
129127
public MiniMaxEmbeddingModel(MiniMaxApi miniMaxApi, MetadataMode metadataMode, MiniMaxEmbeddingOptions options,
130128
RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {
129+
130+
super(metadataMode);
131+
131132
Assert.notNull(miniMaxApi, "MiniMaxApi must not be null");
132-
Assert.notNull(metadataMode, "metadataMode must not be null");
133133
Assert.notNull(options, "options must not be null");
134134
Assert.notNull(retryTemplate, "retryTemplate must not be null");
135135
Assert.notNull(observationRegistry, "observationRegistry must not be null");
136136

137137
this.miniMaxApi = miniMaxApi;
138-
this.metadataMode = metadataMode;
139138
this.defaultOptions = options;
140139
this.retryTemplate = retryTemplate;
141140
this.observationRegistry = observationRegistry;
142141
}
143142

144-
@Override
145-
public float[] embed(Document document) {
146-
Assert.notNull(document, "Document must not be null");
147-
return this.embed(document.getFormattedContent(this.metadataMode));
148-
}
149-
150143
@Override
151144
public EmbeddingResponse call(EmbeddingRequest request) {
152145
MiniMaxEmbeddingOptions requestOptions = mergeOptions(request.getOptions(), this.defaultOptions);

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,6 @@ public class MistralAiEmbeddingModel extends AbstractEmbeddingModel {
5757

5858
private final MistralAiEmbeddingOptions defaultOptions;
5959

60-
private final MetadataMode metadataMode;
61-
6260
private final MistralAiApi mistralAiApi;
6361

6462
private final RetryTemplate retryTemplate;
@@ -94,14 +92,16 @@ public MistralAiEmbeddingModel(MistralAiApi mistralAiApi, MetadataMode metadataM
9492

9593
public MistralAiEmbeddingModel(MistralAiApi mistralAiApi, MetadataMode metadataMode,
9694
MistralAiEmbeddingOptions options, RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {
95+
96+
super(metadataMode);
97+
9798
Assert.notNull(mistralAiApi, "mistralAiApi must not be null");
9899
Assert.notNull(metadataMode, "metadataMode must not be null");
99100
Assert.notNull(options, "options must not be null");
100101
Assert.notNull(retryTemplate, "retryTemplate must not be null");
101102
Assert.notNull(observationRegistry, "observationRegistry must not be null");
102103

103104
this.mistralAiApi = mistralAiApi;
104-
this.metadataMode = metadataMode;
105105
this.defaultOptions = options;
106106
this.retryTemplate = retryTemplate;
107107
this.observationRegistry = observationRegistry;
@@ -174,12 +174,6 @@ private MistralAiApi.EmbeddingRequest<List<String>> createRequest(EmbeddingReque
174174
requestOptions.getEncodingFormat());
175175
}
176176

177-
@Override
178-
public float[] embed(Document document) {
179-
Assert.notNull(document, "Document must not be null");
180-
return this.embed(document.getFormattedContent(this.metadataMode));
181-
}
182-
183177
/**
184178
* Use the provided convention for reporting observation data
185179
* @param observationConvention The provided convention

models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/OCIEmbeddingModel.java

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
import org.springframework.ai.chat.metadata.EmptyUsage;
3232
import org.springframework.ai.document.Document;
33+
import org.springframework.ai.document.MetadataMode;
3334
import org.springframework.ai.embedding.AbstractEmbeddingModel;
3435
import org.springframework.ai.embedding.Embedding;
3536
import org.springframework.ai.embedding.EmbeddingOptions;
@@ -72,6 +73,14 @@ public OCIEmbeddingModel(GenerativeAiInference genAi, OCIEmbeddingOptions option
7273

7374
public OCIEmbeddingModel(GenerativeAiInference genAi, OCIEmbeddingOptions options,
7475
ObservationRegistry observationRegistry) {
76+
this(genAi, MetadataMode.EMBED, options, observationRegistry);
77+
}
78+
79+
public OCIEmbeddingModel(GenerativeAiInference genAi, MetadataMode metadataMode, OCIEmbeddingOptions options,
80+
ObservationRegistry observationRegistry) {
81+
82+
super(metadataMode);
83+
7584
Assert.notNull(genAi, "com.oracle.bmc.generativeaiinference.GenerativeAiInferenceClient must not be null");
7685
Assert.notNull(options, "options must not be null");
7786
Assert.notNull(observationRegistry, "observationRegistry must not be null");
@@ -98,11 +107,6 @@ public EmbeddingResponse call(EmbeddingRequest request) {
98107
.observe(() -> embedAllWithContext(embedTextRequests, context));
99108
}
100109

101-
@Override
102-
public float[] embed(Document document) {
103-
return embed(document.getText());
104-
}
105-
106110
private EmbeddingResponse embedAllWithContext(List<EmbedTextRequest> embedTextRequests,
107111
EmbeddingModelObservationContext context) {
108112
String modelId = null;

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
import org.springframework.ai.chat.metadata.DefaultUsage;
2929
import org.springframework.ai.document.Document;
30+
import org.springframework.ai.document.MetadataMode;
3031
import org.springframework.ai.embedding.AbstractEmbeddingModel;
3132
import org.springframework.ai.embedding.Embedding;
3233
import org.springframework.ai.embedding.EmbeddingModel;
@@ -78,6 +79,14 @@ public class OllamaEmbeddingModel extends AbstractEmbeddingModel {
7879

7980
public OllamaEmbeddingModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
8081
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
82+
this(ollamaApi, MetadataMode.EMBED, defaultOptions, observationRegistry, modelManagementOptions);
83+
}
84+
85+
public OllamaEmbeddingModel(OllamaApi ollamaApi, MetadataMode metadataMode, OllamaOptions defaultOptions,
86+
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
87+
88+
super(metadataMode);
89+
8190
Assert.notNull(ollamaApi, "ollamaApi must not be null");
8291
Assert.notNull(defaultOptions, "options must not be null");
8392
Assert.notNull(observationRegistry, "observationRegistry must not be null");
@@ -95,11 +104,6 @@ public static Builder builder() {
95104
return new Builder();
96105
}
97106

98-
@Override
99-
public float[] embed(Document document) {
100-
return embed(document.getText());
101-
}
102-
103107
@Override
104108
public EmbeddingResponse call(EmbeddingRequest request) {
105109
Assert.notEmpty(request.getInstructions(), "At least one text is required!");

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,6 @@ public class OpenAiEmbeddingModel extends AbstractEmbeddingModel {
6565

6666
private final OpenAiApi openAiApi;
6767

68-
private final MetadataMode metadataMode;
69-
7068
/**
7169
* Observation registry used for instrumentation.
7270
*/
@@ -128,25 +126,20 @@ public OpenAiEmbeddingModel(OpenAiApi openAiApi, MetadataMode metadataMode, Open
128126
*/
129127
public OpenAiEmbeddingModel(OpenAiApi openAiApi, MetadataMode metadataMode, OpenAiEmbeddingOptions options,
130128
RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {
129+
130+
super(metadataMode);
131+
131132
Assert.notNull(openAiApi, "openAiApi must not be null");
132-
Assert.notNull(metadataMode, "metadataMode must not be null");
133133
Assert.notNull(options, "options must not be null");
134134
Assert.notNull(retryTemplate, "retryTemplate must not be null");
135135
Assert.notNull(observationRegistry, "observationRegistry must not be null");
136136

137137
this.openAiApi = openAiApi;
138-
this.metadataMode = metadataMode;
139138
this.defaultOptions = options;
140139
this.retryTemplate = retryTemplate;
141140
this.observationRegistry = observationRegistry;
142141
}
143142

144-
@Override
145-
public float[] embed(Document document) {
146-
Assert.notNull(document, "Document must not be null");
147-
return this.embed(document.getFormattedContent(this.metadataMode));
148-
}
149-
150143
@Override
151144
public EmbeddingResponse call(EmbeddingRequest request) {
152145
// Before moving any further, build the final request EmbeddingRequest,

models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingModel.java

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
import org.springframework.ai.chat.metadata.EmptyUsage;
2727
import org.springframework.ai.document.Document;
28+
import org.springframework.ai.document.MetadataMode;
2829
import org.springframework.ai.embedding.AbstractEmbeddingModel;
2930
import org.springframework.ai.embedding.Embedding;
3031
import org.springframework.ai.embedding.EmbeddingOptions;
@@ -75,6 +76,21 @@ public PostgresMlEmbeddingModel(JdbcTemplate jdbcTemplate, PostgresMlEmbeddingOp
7576
*/
7677
public PostgresMlEmbeddingModel(JdbcTemplate jdbcTemplate, PostgresMlEmbeddingOptions options,
7778
boolean createExtension) {
79+
this(jdbcTemplate, MetadataMode.EMBED, options, createExtension);
80+
}
81+
82+
/**
83+
* a PostgresMlEmbeddingModel constructor
84+
* @param jdbcTemplate JdbcTemplate to use to interact with the database.
85+
* @param metadataMode MetadataMode describing what metadata values are included in
86+
* the embedding.
87+
* @param options PostgresMlEmbeddingOptions to configure the client.
88+
*/
89+
public PostgresMlEmbeddingModel(JdbcTemplate jdbcTemplate, MetadataMode metadataMode,
90+
PostgresMlEmbeddingOptions options, boolean createExtension) {
91+
92+
super(metadataMode);
93+
7894
Assert.notNull(jdbcTemplate, "jdbc template must not be null.");
7995
Assert.notNull(options, "options must not be null.");
8096
Assert.notNull(options.getTransformer(), "transformer must not be null.");
@@ -96,11 +112,6 @@ public float[] embed(String text) {
96112
ModelOptionsUtils.toJsonString(this.defaultOptions.getKwargs()));
97113
}
98114

99-
@Override
100-
public float[] embed(Document document) {
101-
return this.embed(document.getFormattedContent(this.defaultOptions.getMetadataMode()));
102-
}
103-
104115
@SuppressWarnings("null")
105116
@Override
106117
public EmbeddingResponse call(EmbeddingRequest request) {

models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingModel.java

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,6 @@ public class QianFanEmbeddingModel extends AbstractEmbeddingModel {
6363

6464
private final QianFanApi qianFanApi;
6565

66-
private final MetadataMode metadataMode;
67-
6866
/**
6967
* Observation registry used for instrumentation.
7068
*/
@@ -126,25 +124,20 @@ public QianFanEmbeddingModel(QianFanApi qianFanApi, MetadataMode metadataMode,
126124
*/
127125
public QianFanEmbeddingModel(QianFanApi qianFanApi, MetadataMode metadataMode, QianFanEmbeddingOptions options,
128126
RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {
127+
128+
super(metadataMode);
129+
129130
Assert.notNull(qianFanApi, "QianFanApi must not be null");
130-
Assert.notNull(metadataMode, "metadataMode must not be null");
131131
Assert.notNull(options, "options must not be null");
132132
Assert.notNull(retryTemplate, "retryTemplate must not be null");
133133
Assert.notNull(observationRegistry, "observationRegistry must not be null");
134134

135135
this.qianFanApi = qianFanApi;
136-
this.metadataMode = metadataMode;
137136
this.defaultOptions = options;
138137
this.retryTemplate = retryTemplate;
139138
this.observationRegistry = observationRegistry;
140139
}
141140

142-
@Override
143-
public float[] embed(Document document) {
144-
Assert.notNull(document, "Document must not be null");
145-
return this.embed(document.getFormattedContent(this.metadataMode));
146-
}
147-
148141
@Override
149142
public EmbeddingResponse call(EmbeddingRequest request) {
150143
QianFanEmbeddingOptions requestOptions = mergeOptions(request.getOptions(), this.defaultOptions);

0 commit comments

Comments
 (0)