Skip to content

Commit 263fe2f

Browse files
ThomasVitalemarkpollack
authored andcommitted
Modular RAG - Query Analysis
Query Analysis * Introduce Query Analysis Module * Define QueryTransformer API and TranslationQueryTransformer implementation * Define QueryExpander API and MultiQueryExpander implementation * Support QueryTransformer in RetrievalAugmentationAdvisor (support for QueryExpander will be in the next PR together with the needed DocumentFuser API). Improvements * Refine Retrieval and Augmentation Modules for increased robustness * Expand test coverage for both modules * Define clone() method for ChatClient.Builder Tests * Introduce “spring-ai-integration-tests” for full-fledged integration tests * Add integration tests for RAG modules * Add integration tests for RAG advisor Query Analysis * Introduce Query Analysis Module * Define QueryTransformer API and TranslationQueryTransformer implementation * Define QueryExpander API and MultiQueryExpander implementation * Support QueryTransformer in RetrievalAugmentationAdvisor (support for QueryExpander will be in the next PR together with the needed DocumentFuser API). Improvements * Refine Retrieval and Augmentation Modules for increased robustness * Expand test coverage for both modules * Define clone() method for ChatClient.Builder Tests * Introduce “spring-ai-integration-tests” for full-fledged integration tests * Add integration tests for RAG modules * Add integration tests for RAG advisor Relates to #gh-1603 Signed-off-by: Thomas Vitale <[email protected]>
1 parent b4e0a45 commit 263fe2f

File tree

38 files changed

+1710
-72
lines changed

38 files changed

+1710
-72
lines changed
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.openai.chat
18+
19+
import org.assertj.core.api.Assertions.assertThat
20+
import org.junit.jupiter.api.Test
21+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable
22+
import org.slf4j.LoggerFactory
23+
import org.springframework.ai.chat.messages.UserMessage
24+
import org.springframework.ai.chat.prompt.Prompt
25+
import org.springframework.ai.model.function.FunctionCallback
26+
import org.springframework.ai.model.function.FunctionCallbackWrapper
27+
import org.springframework.ai.openai.OpenAiChatModel
28+
import org.springframework.ai.openai.OpenAiChatOptions
29+
import org.springframework.ai.openai.api.OpenAiApi
30+
import org.springframework.beans.factory.annotation.Autowired
31+
import org.springframework.boot.SpringBootConfiguration
32+
import org.springframework.boot.autoconfigure.AutoConfigurations
33+
import org.springframework.boot.test.context.SpringBootTest
34+
import org.springframework.boot.test.context.runner.ApplicationContextRunner
35+
import org.springframework.context.annotation.Bean
36+
37+
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
38+
class FunctionCallbackWrapperKotlinIT {
39+
40+
private val logger = LoggerFactory.getLogger(FunctionCallbackWrapperKotlinIT::class.java)
41+
42+
43+
private val contextRunner = ApplicationContextRunner()
44+
.withUserConfiguration(Config::class.java)
45+
46+
47+
@Test
48+
fun functionCallTest() {
49+
this.contextRunner.run {context ->
50+
51+
val chatModel = context.getBean(OpenAiChatModel::class.java)
52+
assertThat(chatModel).isNotNull
53+
54+
val userMessage = UserMessage(
55+
"What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.")
56+
57+
val response = chatModel
58+
.call(Prompt(listOf(userMessage), OpenAiChatOptions.builder().withFunction("WeatherInfo").build()))
59+
60+
logger.info("Response: " + response)
61+
62+
assertThat(response.getResult().output.content).contains("30", "10", "15")
63+
}
64+
}
65+
66+
67+
@SpringBootConfiguration
68+
open class Config {
69+
70+
@Bean
71+
open fun chatCompletionApi(): OpenAiApi {
72+
return OpenAiApi(System.getenv("OPENAI_API_KEY"))
73+
}
74+
75+
@Bean
76+
open fun openAiClient(openAiApi: OpenAiApi): OpenAiChatModel {
77+
return OpenAiChatModel(openAiApi)
78+
}
79+
80+
@Bean
81+
open fun weatherFunctionInfo(): FunctionCallback {
82+
83+
return FunctionCallbackWrapper.builder(MockKotlinWeatherService())
84+
.withName("WeatherInfo")
85+
.withInputType(KotlinRequest::class.java)
86+
.withDescription(
87+
"Find the weather conditions, forecasts, and temperatures for a location, like a city or state.")
88+
.build();
89+
}
90+
}
91+
}
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.openai.chat
18+
19+
import com.fasterxml.jackson.annotation.JsonClassDescription
20+
import com.fasterxml.jackson.annotation.JsonInclude
21+
import com.fasterxml.jackson.annotation.JsonInclude.Include
22+
import com.fasterxml.jackson.annotation.JsonProperty
23+
import com.fasterxml.jackson.annotation.JsonPropertyDescription
24+
25+
class MockKotlinWeatherService : Function1<KotlinRequest, KotlinResponse> {
26+
27+
override fun invoke(kotlinRequest: KotlinRequest): KotlinResponse {
28+
var temperature = 10.0
29+
if (kotlinRequest.location.contains("Paris")) {
30+
temperature = 15.0
31+
}
32+
else if (kotlinRequest.location.contains("Tokyo")) {
33+
temperature = 10.0
34+
}
35+
else if (kotlinRequest.location.contains("San Francisco")) {
36+
temperature = 30.0
37+
}
38+
39+
return KotlinResponse(temperature, 15.0, 20.0, 2.0, 53, 45, Unit.C);
40+
}
41+
}
42+
43+
/**
44+
* Temperature units.
45+
*/
46+
enum class Unit(val unitName: String) {
47+
48+
/**
49+
* Celsius.
50+
*/
51+
C("metric"),
52+
/**
53+
* Fahrenheit.
54+
*/
55+
F("imperial");
56+
}
57+
58+
/**
59+
* Weather Function request.
60+
*/
61+
@JsonInclude(Include.NON_NULL)
62+
@JsonClassDescription("Weather API request")
63+
data class KotlinRequest(
64+
@get:JsonProperty(required = true, value = "location")
65+
@get:JsonPropertyDescription("The city and state e.g. San Francisco, CA")
66+
val location: String = "",
67+
68+
@get:JsonProperty(required = true, value = "lat")
69+
@get:JsonPropertyDescription("The city latitude")
70+
val lat: Double = 0.0,
71+
72+
@get:JsonProperty(required = true, value = "lon")
73+
@get:JsonPropertyDescription("The city longitude")
74+
val lon: Double = 0.0,
75+
76+
@get:JsonProperty(required = true, value = "unit")
77+
@get:JsonPropertyDescription("Temperature unit")
78+
val unit: Unit = Unit.C
79+
)
80+
81+
82+
/**
83+
* Weather Function response.
84+
*/
85+
data class KotlinResponse(val temp: Double,
86+
val feels_like: Double,
87+
val temp_min: Double,
88+
val temp_max: Double,
89+
val pressure: Int,
90+
val humidity: Int,
91+
val unit: Unit
92+
)

models/spring-ai-postgresml/.attach_pid127642

Whitespace-only changes.

pom.xml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@
123123
<module>spring-ai-spring-boot-starters/spring-ai-starter-watsonx-ai</module>
124124
<module>spring-ai-spring-boot-starters/spring-ai-starter-zhipuai</module>
125125
<module>spring-ai-spring-boot-starters/spring-ai-starter-moonshot</module>
126+
127+
<module>spring-ai-integration-tests</module>
126128
</modules>
127129

128130
<organization>

spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,8 @@ <I, O> Builder defaultFunction(String name, String description,
289289

290290
Builder defaultToolContext(Map<String, Object> toolContext);
291291

292+
Builder clone();
293+
292294
ChatClient build();
293295

294296
}

spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ public ChatClient build() {
7171
return new DefaultChatClient(this.defaultRequest);
7272
}
7373

74+
public Builder clone() {
75+
return this.defaultRequest.mutate();
76+
}
77+
7478
public Builder defaultAdvisors(Advisor... advisors) {
7579
this.defaultRequest.advisors(advisors);
7680
return this;

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java

Lines changed: 55 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
package org.springframework.ai.chat.client.advisor;
1818

19+
import java.util.ArrayList;
20+
import java.util.Arrays;
1921
import java.util.HashMap;
2022
import java.util.List;
2123
import java.util.Map;
@@ -35,15 +37,16 @@
3537
import org.springframework.ai.chat.prompt.PromptTemplate;
3638
import org.springframework.ai.document.Document;
3739
import org.springframework.ai.rag.Query;
40+
import org.springframework.ai.rag.analysis.query.transformation.QueryTransformer;
3841
import org.springframework.ai.rag.augmentation.ContextualQueryAugmentor;
3942
import org.springframework.ai.rag.augmentation.QueryAugmentor;
40-
import org.springframework.ai.rag.retrieval.source.DocumentRetriever;
43+
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;
4144
import org.springframework.lang.Nullable;
4245
import org.springframework.util.Assert;
4346
import org.springframework.util.StringUtils;
4447

4548
/**
46-
* This advisor implements common Retrieval Augmented Generation (RAG) flows using the
49+
* Advisor that implements common Retrieval Augmented Generation (RAG) flows using the
4750
* building blocks defined in the {@link org.springframework.ai.rag} package and following
4851
* the Modular RAG Architecture.
4952
* <p>
@@ -55,10 +58,12 @@
5558
* @see <a href="http://export.arxiv.org/abs/2407.21059">arXiv:2407.21059</a>
5659
* @see <a href="https://export.arxiv.org/abs/2312.10997">arXiv:2312.10997</a>
5760
*/
58-
public class RetrievalAugmentationAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {
61+
public final class RetrievalAugmentationAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {
5962

6063
public static final String DOCUMENT_CONTEXT = "rag_document_context";
6164

65+
private final List<QueryTransformer> queryTransformers;
66+
6267
private final DocumentRetriever documentRetriever;
6368

6469
private final QueryAugmentor queryAugmentor;
@@ -67,12 +72,15 @@ public class RetrievalAugmentationAdvisor implements CallAroundAdvisor, StreamAr
6772

6873
private final int order;
6974

70-
public RetrievalAugmentationAdvisor(DocumentRetriever documentRetriever, @Nullable QueryAugmentor queryAugmentor,
71-
@Nullable Boolean protectFromBlocking, @Nullable Integer order) {
75+
public RetrievalAugmentationAdvisor(List<QueryTransformer> queryTransformers, DocumentRetriever documentRetriever,
76+
@Nullable QueryAugmentor queryAugmentor, @Nullable Boolean protectFromBlocking, @Nullable Integer order) {
77+
Assert.notNull(queryTransformers, "queryTransformers cannot be null");
78+
Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements");
7279
Assert.notNull(documentRetriever, "documentRetriever cannot be null");
80+
this.queryTransformers = queryTransformers;
7381
this.documentRetriever = documentRetriever;
7482
this.queryAugmentor = queryAugmentor != null ? queryAugmentor : ContextualQueryAugmentor.builder().build();
75-
this.protectFromBlocking = protectFromBlocking != null ? protectFromBlocking : false;
83+
this.protectFromBlocking = protectFromBlocking != null ? protectFromBlocking : true;
7684
this.order = order != null ? order : 0;
7785
}
7886

@@ -119,30 +127,45 @@ private AdvisedRequest before(AdvisedRequest request) {
119127
Map<String, Object> context = new HashMap<>(request.adviseContext());
120128

121129
// 0. Create a query from the user text and parameters.
122-
Query query = new Query(new PromptTemplate(request.userText(), request.userParams()).render());
130+
Query originalQuery = new Query(new PromptTemplate(request.userText(), request.userParams()).render());
131+
132+
// 1. Transform original user query based on a chain of query transformers.
133+
Query transformedQuery = originalQuery;
134+
for (var queryTransformer : this.queryTransformers) {
135+
transformedQuery = queryTransformer.apply(transformedQuery);
136+
}
123137

124-
// 1. Retrieve similar documents for the original query.
125-
List<Document> documents = this.documentRetriever.retrieve(query);
138+
// 2. Retrieve similar documents for the original query.
139+
List<Document> documents = this.documentRetriever.retrieve(transformedQuery);
126140
context.put(DOCUMENT_CONTEXT, documents);
127141

128-
// 2. Augment user query with the contextual data.
129-
Query augmentedQuery = this.queryAugmentor.augment(query, documents);
142+
// 3. Augment user query with the document contextual data.
143+
Query augmentedQuery = this.queryAugmentor.augment(transformedQuery, documents);
130144

131145
return AdvisedRequest.from(request).withUserText(augmentedQuery.text()).withAdviseContext(context).build();
132146
}
133147

134148
private AdvisedResponse after(AdvisedResponse advisedResponse) {
135-
ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(advisedResponse.response());
149+
ChatResponse.Builder chatResponseBuilder;
150+
if (advisedResponse.response() == null) {
151+
chatResponseBuilder = ChatResponse.builder();
152+
}
153+
else {
154+
chatResponseBuilder = ChatResponse.builder().from(advisedResponse.response());
155+
}
136156
chatResponseBuilder.withMetadata(DOCUMENT_CONTEXT, advisedResponse.adviseContext().get(DOCUMENT_CONTEXT));
137157
return new AdvisedResponse(chatResponseBuilder.build(), advisedResponse.adviseContext());
138158
}
139159

140160
private Predicate<AdvisedResponse> onFinishReason() {
141-
return advisedResponse -> advisedResponse.response()
142-
.getResults()
143-
.stream()
144-
.anyMatch(result -> result != null && result.getMetadata() != null
145-
&& StringUtils.hasText(result.getMetadata().getFinishReason()));
161+
return advisedResponse -> {
162+
ChatResponse chatResponse = advisedResponse.response();
163+
return chatResponse != null && chatResponse.getResults() != null
164+
&& chatResponse.getResults()
165+
.stream()
166+
.anyMatch(result -> result != null && result.getMetadata() != null
167+
&& StringUtils.hasText(result.getMetadata().getFinishReason()));
168+
};
146169
}
147170

148171
@Override
@@ -157,6 +180,8 @@ public int getOrder() {
157180

158181
public static final class Builder {
159182

183+
private final List<QueryTransformer> queryTransformers = new ArrayList<>();
184+
160185
private DocumentRetriever documentRetriever;
161186

162187
private QueryAugmentor queryAugmentor;
@@ -168,6 +193,18 @@ public static final class Builder {
168193
private Builder() {
169194
}
170195

196+
public Builder queryTransformers(List<QueryTransformer> queryTransformers) {
197+
Assert.notNull(queryTransformers, "queryTransformers cannot be null");
198+
this.queryTransformers.addAll(queryTransformers);
199+
return this;
200+
}
201+
202+
public Builder queryTransformers(QueryTransformer... queryTransformers) {
203+
Assert.notNull(queryTransformers, "queryTransformers cannot be null");
204+
this.queryTransformers.addAll(Arrays.asList(queryTransformers));
205+
return this;
206+
}
207+
171208
public Builder documentRetriever(DocumentRetriever documentRetriever) {
172209
this.documentRetriever = documentRetriever;
173210
return this;
@@ -189,7 +226,7 @@ public Builder order(Integer order) {
189226
}
190227

191228
public RetrievalAugmentationAdvisor build() {
192-
return new RetrievalAugmentationAdvisor(this.documentRetriever, this.queryAugmentor,
229+
return new RetrievalAugmentationAdvisor(this.queryTransformers, this.documentRetriever, this.queryAugmentor,
193230
this.protectFromBlocking, this.order);
194231
}
195232

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
/**
18+
* RAG Module: Query Analysis.
19+
* <p>
20+
* This package encompasses all components involved in the pre-retrieval phase of a
21+
* retrieval augmented generation flow. Queries are transformed, expanded, or constructed
22+
* so to enhance the effectiveness and accuracy of the subsequent retrieval phase.
23+
*/
24+
@NonNullApi
25+
@NonNullFields
26+
package org.springframework.ai.rag.analysis;
27+
28+
import org.springframework.lang.NonNullApi;
29+
import org.springframework.lang.NonNullFields;

0 commit comments

Comments
 (0)