Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
<module>vector-stores/spring-ai-chroma</module>
<module>vector-stores/spring-ai-azure</module>
<module>vector-stores/spring-ai-weaviate</module>
<module>spring-ai-vertex-ai</module>

</modules>

Expand Down Expand Up @@ -81,7 +82,7 @@
<maven.compiler.target>17</maven.compiler.target>

<!-- production dependencies -->
<spring-boot.version>3.1.3</spring-boot.version>
<spring-boot.version>3.2.0</spring-boot.version>
<stringtemplate.version>4.0.2</stringtemplate.version>
<open-ai-client.version>0.16.0</open-ai-client.version>
<azure-open-ai-client.version>1.0.0-beta.3</azure-open-ai-client.version>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ void azureOpenAiMetadataCapturedDuringGeneration() {
Generation generation = response.getGeneration();

assertThat(generation).isNotNull()
.extracting(Generation::getText)
.extracting(Generation::getContent)
.isEqualTo("No! You will actually land with a resounding thud. This is the way!");

assertPromptMetadata(response);
Expand Down
6 changes: 6 additions & 0 deletions spring-ai-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@
</dependency>


<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-webflux</artifactId>
<version>6.1.1</version>
</dependency>

<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-messaging</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public interface AiClient {

default String generate(String message) {
Prompt prompt = new Prompt(new UserMessage(message));
return generate(prompt).getGeneration().getText();
return generate(prompt).getGeneration().getContent();
}

AiResponse generate(Prompt prompt);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,42 @@
*/
package org.springframework.ai.client;

import java.util.Collections;
import java.util.List;
import java.util.Map;

import org.springframework.ai.metadata.GenerationMetadata;
import org.springframework.ai.metadata.PromptMetadata;
import org.springframework.lang.Nullable;

/**
* The chat completion (e.g. generation) response returned by an AI provider.
*/
public class AiResponse {

private final GenerationMetadata metadata;

/**
* List of generated messages returned by the AI provider.
*/
private final List<Generation> generations;

private PromptMetadata promptMetadata;

/**
* Construct a new {@link AiResponse} instance without metadata.
* @param generations the {@link List} of {@link Generation} returned by the AI
* provider.
*/
public AiResponse(List<Generation> generations) {
this(generations, GenerationMetadata.NULL);
}

/**
* Construct a new {@link AiResponse} instance.
* @param generations the {@link List} of {@link Generation} returned by the AI
* provider.
* @param metadata {@link GenerationMetadata} containing information about the use of
* the AI provider's API.
*/
public AiResponse(List<Generation> generations, GenerationMetadata metadata) {
this.metadata = metadata;
this.generations = List.copyOf(generations);
Expand All @@ -51,23 +67,22 @@ public List<Generation> getGenerations() {
return this.generations;
}

/**
* @return Returns the first {@link Generation} in the generations list.
*/
public Generation getGeneration() {
return this.generations.get(0);
}

/**
* Returns {@link GenerationMetadata} containing information about the use of the AI
* provider's API.
* @return {@link GenerationMetadata} containing information about the use of the AI
* provider's API.
* @return Returns {@link GenerationMetadata} containing information about the use of
* the AI provider's API.
*/
public GenerationMetadata getGenerationMetadata() {
return this.metadata;
}

/**
* Returns {@link PromptMetadata} containing information on prompt processing by the
* AI.
* @return {@link PromptMetadata} containing information on prompt processing by the
* AI.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,32 +20,27 @@
import java.util.Map;

import org.springframework.ai.metadata.ChoiceMetadata;
import org.springframework.ai.prompt.messages.AbstractMessage;
import org.springframework.ai.prompt.messages.MessageType;
import org.springframework.lang.Nullable;

public class Generation {

// Just text for now
private final String text;

private Map<String, Object> info;
/**
* Represents a response returned by the AI.
*/
public class Generation extends AbstractMessage {

private ChoiceMetadata choiceMetadata;

public Generation(String text) {
this(text, Collections.emptyMap());
}

public Generation(String text, Map<String, Object> info) {
this.text = text;
this.info = Map.copyOf(info);
}

public String getText() {
return this.text;
public Generation(String content, Map<String, Object> properties) {
super(MessageType.ASSISTANT, content, properties);
}

public Map<String, Object> getInfo() {
return this.info;
public Generation(String content, Map<String, Object> properties, MessageType type) {
super(type, content, properties);
}

public ChoiceMetadata getChoiceMetadata() {
Expand All @@ -60,7 +55,7 @@ public Generation withChoiceMetadata(@Nullable ChoiceMetadata choiceMetadata) {

@Override
public String toString() {
return "Generation{" + "text='" + text + '\'' + ", info=" + info + '}';
return "Generation{" + "text='" + content + '\'' + ", info=" + properties + '}';
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,40 @@ public String getMessageTypeValue() {
return this.messageType.getValue();
}

@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + ((content == null) ? 0 : content.hashCode());
result = prime * result + ((properties == null) ? 0 : properties.hashCode());
result = prime * result + ((messageType == null) ? 0 : messageType.hashCode());
return result;
}

@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
AbstractMessage other = (AbstractMessage) obj;
if (content == null) {
if (other.content != null)
return false;
}
else if (!content.equals(other.content))
return false;
if (properties == null) {
if (other.properties != null)
return false;
}
else if (!properties.equals(other.properties))
return false;
if (messageType != other.messageType)
return false;
return true;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public List<Document> apply(List<Document> documents) {

var template = new PromptTemplate(String.format(KEYWORDS_TEMPLATE, keywordCount));
Prompt prompt = template.create(Map.of(CONTEXT_STR_PLACEHOLDER, document.getContent()));
String keywords = this.aiClient.generate(prompt).getGeneration().getText();
String keywords = this.aiClient.generate(prompt).getGeneration().getContent();
document.getMetadata().putAll(Map.of(EXCERPT_KEYWORDS_METADATA_KEY, keywords));
}
return documents;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ public List<Document> apply(List<Document> documents) {

Prompt prompt = new PromptTemplate(this.summaryTemplate)
.create(Map.of(CONTEXT_STR_PLACEHOLDER, documentContext));
documentSummaries.add(this.aiClient.generate(prompt).getGeneration().getText());
documentSummaries.add(this.aiClient.generate(prompt).getGeneration().getContent());
}

for (int i = 0; i < documentSummaries.size(); i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ void generateWithStringCallsGenerateWithPromptAndReturnsResponseCorrectly() {
verify(mockClient, times(1)).generate(eq(userMessage));
verify(mockClient, times(1)).generate(isA(Prompt.class));
verify(response, times(1)).getGeneration();
verify(generation, times(1)).getText();
verify(generation, times(1)).getContent();
verifyNoMoreInteractions(mockClient, generation, response);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ void helloWorldCompletion() {
""";
Prompt prompt = new Prompt(mistral7bInstruct);
AiResponse aiResponse = huggingfaceAiClient.generate(prompt);
assertThat(aiResponse.getGeneration().getText()).isNotEmpty();
assertThat(aiResponse.getGeneration().getContent()).isNotEmpty();
String expectedResponse = """
```json
{
Expand All @@ -56,9 +56,9 @@ void helloWorldCompletion() {
"address": "#1 Samuel St."
}
```""";
assertThat(aiResponse.getGeneration().getText()).isEqualTo(expectedResponse);
assertThat(aiResponse.getGeneration().getInfo()).containsKey("generated_tokens");
assertThat(aiResponse.getGeneration().getInfo()).containsEntry("generated_tokens", 39);
assertThat(aiResponse.getGeneration().getContent()).isEqualTo(expectedResponse);
assertThat(aiResponse.getGeneration().getProperties()).containsKey("generated_tokens");
assertThat(aiResponse.getGeneration().getProperties()).containsEntry("generated_tokens", 39);

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public void smokeTest() {
Assertions.assertNotNull(aiResponse);
Assertions.assertFalse(CollectionUtils.isEmpty(aiResponse.getGenerations()));
Assertions.assertNotNull(aiResponse.getGeneration());
Assertions.assertNotNull(aiResponse.getGeneration().getText());
Assertions.assertNotNull(aiResponse.getGeneration().getContent());
}

private static OllamaClient getOllamaClient() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ void outputParser() {
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = this.openAiClient.generate(prompt).getGeneration();

List<String> list = outputParser.parse(generation.getText());
List<String> list = outputParser.parse(generation.getContent());
assertThat(list).hasSize(5);

}
Expand All @@ -80,7 +80,7 @@ void mapOutputParser() {
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = openAiClient.generate(prompt).getGeneration();

Map<String, Object> result = outputParser.parse(generation.getText());
Map<String, Object> result = outputParser.parse(generation.getContent());
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));

}
Expand All @@ -99,7 +99,7 @@ void beanOutputParser() {
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = openAiClient.generate(prompt).getGeneration();

ActorsFilms actorsFilms = outputParser.parse(generation.getText());
ActorsFilms actorsFilms = outputParser.parse(generation.getContent());
}

record ActorsFilmsRecord(String actor, List<String> movies) {
Expand All @@ -119,7 +119,7 @@ void beanOutputParserRecords() {
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = openAiClient.generate(prompt).getGeneration();

ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getText());
ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getContent());
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
assertThat(actorsFilms.movies()).hasSize(5);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public abstract class AbstractIT {

protected void evaluateQuestionAndAnswer(String question, AiResponse response, boolean factBased) {
assertThat(response).isNotNull();
String answer = response.getGeneration().getText();
String answer = response.getGeneration().getContent();
logger.info("Question: " + question);
logger.info("Answer:" + answer);
PromptTemplate userPromptTemplate = new PromptTemplate(userEvaluatorResource,
Expand All @@ -53,12 +53,12 @@ protected void evaluateQuestionAndAnswer(String question, AiResponse response, b
}
Message userMessage = userPromptTemplate.createMessage();
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
String yesOrNo = openAiClient.generate(prompt).getGeneration().getText();
String yesOrNo = openAiClient.generate(prompt).getGeneration().getContent();
logger.info("Is Answer related to question: " + yesOrNo);
if (yesOrNo.equalsIgnoreCase("no")) {
SystemMessage notRelatedSystemMessage = new SystemMessage(qaEvaluatorNotRelatedResource);
prompt = new Prompt(List.of(userMessage, notRelatedSystemMessage));
String reasonForFailure = openAiClient.generate(prompt).getGeneration().getText();
String reasonForFailure = openAiClient.generate(prompt).getGeneration().getContent();
fail(reasonForFailure);
}
else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public class BasicEvaluationTest {

protected void evaluateQuestionAndAnswer(String question, AiResponse response, boolean factBased) {
assertThat(response).isNotNull();
String answer = response.getGeneration().getText();
String answer = response.getGeneration().getContent();
logger.info("Question: " + question);
logger.info("Answer:" + answer);
PromptTemplate userPromptTemplate = new PromptTemplate(userEvaluatorResource,
Expand All @@ -69,12 +69,12 @@ protected void evaluateQuestionAndAnswer(String question, AiResponse response, b
}
Message userMessage = userPromptTemplate.createMessage();
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
String yesOrNo = openAiClient.generate(prompt).getGeneration().getText();
String yesOrNo = openAiClient.generate(prompt).getGeneration().getContent();
logger.info("Is Answer related to question: " + yesOrNo);
if (yesOrNo.equalsIgnoreCase("no")) {
SystemMessage notRelatedSystemMessage = new SystemMessage(qaEvaluatorNotRelatedResource);
prompt = new Prompt(List.of(userMessage, notRelatedSystemMessage));
String reasonForFailure = openAiClient.generate(prompt).getGeneration().getText();
String reasonForFailure = openAiClient.generate(prompt).getGeneration().getContent();
fail(reasonForFailure);
}
else {
Expand Down
Loading