Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import static org.assertj.core.api.Assertions.assertThat;

/**
* @author Jemin Huh
* @author Mick Semb Wever
* @author Jihoon Kim
* @author Enrico Rampazzo
Expand Down Expand Up @@ -83,8 +84,10 @@ void addAndGet() {
memory.deleteByConversationId(sessionId);
assertThat(memory.findByConversationId(sessionId)).isEmpty();

AssistantMessage assistantMessage = new AssistantMessage("test answer", Map.of(),
List.of(new AssistantMessage.ToolCall("id", "type", "name", "arguments")));
AssistantMessage assistantMessage = AssistantMessage.builder()
.text("test answer")
.toolCalls(List.of(new AssistantMessage.ToolCall("id", "type", "name", "arguments")))
.build();

memory.saveAll(sessionId, List.of(userMessage, assistantMessage));
messages = memory.findByConversationId(sessionId);
Expand Down Expand Up @@ -112,10 +115,11 @@ void addAndGet() {
assertThat(((UserMessage) messages.get(0)).getMedia()).usingRecursiveFieldByFieldElementComparator()
.isEqualTo(media);
memory.deleteByConversationId(sessionId);
ToolResponseMessage toolResponseMessage = new ToolResponseMessage(
List.of(new ToolResponse("id", "name", "responseData"),
new ToolResponse("id2", "name2", "responseData2")),
Map.of("id", "id", "metadataKey", "metadata"));
ToolResponseMessage toolResponseMessage = ToolResponseMessage.builder()
.responses(List.of(new ToolResponse("id", "name", "responseData"),
new ToolResponse("id2", "name2", "responseData2")))
.metadata(Map.of("id", "id", "metadataKey", "metadata"))
.build();
memory.saveAll(sessionId, List.of(toolResponseMessage));
messages = memory.findByConversationId(sessionId);
assertThat(messages.size()).isEqualTo(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,14 @@ private Message getMessage(UdtValue udt) {
Map<String, Object> props = Map.of(CONVERSATION_TS, udt.getInstant(this.conf.messageUdtTimestampColumn));
switch (MessageType.valueOf(udt.getString(this.conf.messageUdtTypeColumn))) {
case ASSISTANT:
return new AssistantMessage(content, props);
return AssistantMessage.builder().text(content).metadata(props).build();
case USER:
return UserMessage.builder().text(content).metadata(props).build();
case SYSTEM:
return SystemMessage.builder().text(content).metadata(props).build();
case TOOL:
// todo – persist ToolResponse somehow
return new ToolResponseMessage(List.of(), props);
return ToolResponseMessage.builder().responses(List.of()).metadata(props).build();
default:
throw new IllegalStateException(
String.format("unknown message type %s", udt.getString(this.conf.messageUdtTypeColumn)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,25 +172,28 @@ public Neo4jChatMemoryRepositoryConfig getConfig() {

private Message buildToolMessage(org.neo4j.driver.Record record) {
Message message;
message = new ToolResponseMessage(record.get("toolResponses").asList(v -> {
message = ToolResponseMessage.builder().responses(record.get("toolResponses").asList(v -> {
Map<String, Object> trMap = v.asMap();
return new ToolResponseMessage.ToolResponse((String) trMap.get(ToolResponseAttributes.ID.getValue()),
(String) trMap.get(ToolResponseAttributes.NAME.getValue()),
(String) trMap.get(ToolResponseAttributes.RESPONSE_DATA.getValue()));
}), record.get("metadata").asMap());
})).metadata(record.get("metadata").asMap()).build();
return message;
}

private Message buildAssistantMessage(org.neo4j.driver.Record record, Map<String, Object> messageMap,
List<Media> mediaList) {
Message message;
message = new AssistantMessage(messageMap.get(MessageAttributes.TEXT_CONTENT.getValue()).toString(),
record.get("metadata").asMap(Map.of()), record.get("toolCalls").asList(v -> {
var toolCallMap = v.asMap();
return new AssistantMessage.ToolCall((String) toolCallMap.get("id"),
(String) toolCallMap.get("type"), (String) toolCallMap.get("name"),
(String) toolCallMap.get("arguments"));
}), mediaList);
message = AssistantMessage.builder()
.text(messageMap.get(MessageAttributes.TEXT_CONTENT.getValue()).toString())
.metadata(record.get("metadata").asMap(Map.of()))
.toolCalls(record.get("toolCalls").asList(v -> {
var toolCallMap = v.asMap();
return new AssistantMessage.ToolCall((String) toolCallMap.get("id"), (String) toolCallMap.get("type"),
(String) toolCallMap.get("name"), (String) toolCallMap.get("arguments"));
}))
.media(mediaList)
.build();
return message;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
/**
* Integration tests for {@link Neo4jChatMemoryRepository}.
*
* @author Jemin Huh
* @author Enrico Rampazzo
* @since 1.0.0
*/
Expand Down Expand Up @@ -263,9 +264,11 @@ void handleMediaContent() {
void handleAssistantMessageWithToolCalls() {
var conversationId = UUID.randomUUID().toString();

AssistantMessage assistantMessage = new AssistantMessage("Message with tool calls", Map.of(),
List.of(new AssistantMessage.ToolCall("id1", "type1", "name1", "arguments1"),
new AssistantMessage.ToolCall("id2", "type2", "name2", "arguments2")));
AssistantMessage assistantMessage = AssistantMessage.builder()
.text("Message with tool calls")
.toolCalls(List.of(new AssistantMessage.ToolCall("id1", "type1", "name1", "arguments1"),
new AssistantMessage.ToolCall("id2", "type2", "name2", "arguments2")))
.build();

this.chatMemoryRepository.saveAll(conversationId, List.<Message>of(assistantMessage));

Expand All @@ -282,9 +285,11 @@ void handleAssistantMessageWithToolCalls() {
void handleToolResponseMessage() {
var conversationId = UUID.randomUUID().toString();

ToolResponseMessage toolResponseMessage = new ToolResponseMessage(List
.of(new ToolResponse("id1", "name1", "responseData1"), new ToolResponse("id2", "name2", "responseData2")),
Map.of("metadataKey", "metadataValue"));
ToolResponseMessage toolResponseMessage = ToolResponseMessage.builder()
.responses(List.of(new ToolResponse("id1", "name1", "responseData1"),
new ToolResponse("id2", "name2", "responseData2")))
.metadata(Map.of("metadataKey", "metadataValue"))
.build();

this.chatMemoryRepository.saveAll(conversationId, List.<Message>of(toolResponseMessage));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
/**
* The {@link ChatModel} implementation for the Anthropic service.
*
* @author Jemin Huh
* @author Christian Tzolov
* @author luocongqiu
* @author Mariusz Bernacki
Expand Down Expand Up @@ -302,19 +303,21 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion, Usage
for (ContentBlock content : chatCompletion.content()) {
switch (content.type()) {
case TEXT, TEXT_DELTA:
generations.add(new Generation(new AssistantMessage(content.text(), Map.of()),
generations.add(new Generation(new AssistantMessage(content.text()),
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()));
break;
case THINKING, THINKING_DELTA:
Map<String, Object> thinkingProperties = new HashMap<>();
thinkingProperties.put("signature", content.signature());
generations.add(new Generation(new AssistantMessage(content.thinking(), thinkingProperties),
generations.add(new Generation(
AssistantMessage.builder().text(content.thinking()).metadata(thinkingProperties).build(),
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()));
break;
case REDACTED_THINKING:
Map<String, Object> redactedProperties = new HashMap<>();
redactedProperties.put("data", content.data());
generations.add(new Generation(new AssistantMessage(null, redactedProperties),
generations.add(new Generation(
AssistantMessage.builder().text((String) null).metadata(redactedProperties).build(),
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()));
break;
case TOOL_USE:
Expand All @@ -328,13 +331,13 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion, Usage
}

if (chatCompletion.stopReason() != null && generations.isEmpty()) {
Generation generation = new Generation(new AssistantMessage(null, Map.of()),
Generation generation = new Generation(new AssistantMessage((String) null),
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build());
generations.add(generation);
}

if (!CollectionUtils.isEmpty(toolCalls)) {
AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls);
AssistantMessage assistantMessage = AssistantMessage.builder().text("").toolCalls(toolCalls).build();
Generation toolCallGeneration = new Generation(assistantMessage,
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build());
generations.add(toolCallGeneration);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
* {@link ChatModel} implementation for {@literal Microsoft Azure AI} backed by
* {@link OpenAIClient}.
*
* @author Jemin Huh
* @author Mark Pollack
* @author Ueibin Kim
* @author John Blum
Expand Down Expand Up @@ -485,7 +486,7 @@ private Generation buildGeneration(ChatChoice choice, Map<String, Object> metada
}

var content = responseMessage == null ? "" : responseMessage.getContent();
var assistantMessage = new AssistantMessage(content, metadata, toolCalls);
var assistantMessage = AssistantMessage.builder().text(content).metadata(metadata).toolCalls(toolCalls).build();
var generationMetadata = generateChoiceMetadata(choice);

return new Generation(assistantMessage, generationMetadata);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.Set;

import io.micrometer.observation.Observation;
Expand Down Expand Up @@ -128,6 +127,7 @@
* <p>
* https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
*
* @author Jemin Huh
* @author Christian Tzolov
* @author Wei Jiang
* @author Alexandros Pappas
Expand Down Expand Up @@ -566,14 +566,14 @@ private ChatResponse toChatResponse(ConverseResponse response, ChatResponse perv
List<Generation> generations = message.content()
.stream()
.filter(content -> content.type() != ContentBlock.Type.TOOL_USE)
.map(content -> new Generation(new AssistantMessage(content.text(), Map.of()),
.map(content -> new Generation(new AssistantMessage(content.text()),
ChatGenerationMetadata.builder().finishReason(response.stopReasonAsString()).build()))
.toList();

List<Generation> allGenerations = new ArrayList<>(generations);

if (response.stopReasonAsString() != null && generations.isEmpty()) {
Generation generation = new Generation(new AssistantMessage(null, Map.of()),
Generation generation = new Generation(new AssistantMessage((String) null),
ChatGenerationMetadata.builder().finishReason(response.stopReasonAsString()).build());
allGenerations.add(generation);
}
Expand All @@ -597,7 +597,7 @@ private ChatResponse toChatResponse(ConverseResponse response, ChatResponse perv
.add(new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments));
}

AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls);
AssistantMessage assistantMessage = AssistantMessage.builder().text("").toolCalls(toolCalls).build();
Generation toolCallGeneration = new Generation(assistantMessage,
ChatGenerationMetadata.builder().finishReason(response.stopReasonAsString()).build());
allGenerations.add(toolCallGeneration);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
/**
* Amazon Bedrock Converse API utils.
*
* @author Jemin Huh
* @author Wei Jiang
* @author Christian Tzolov
* @author Alexandros Pappas
Expand Down Expand Up @@ -140,7 +141,7 @@ public static Flux<ChatResponse> toChatResponse(Flux<ConverseStreamOutput> respo
}
}

AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls);
AssistantMessage assistantMessage = AssistantMessage.builder().text("").toolCalls(toolCalls).build();
Generation toolCallGeneration = new Generation(assistantMessage,
ChatGenerationMetadata.builder().finishReason("tool_use").build());

Expand Down Expand Up @@ -175,8 +176,7 @@ else if (nextEvent instanceof ContentBlockStartEvent contentBlockStartEvent) {
else if (nextEvent instanceof ContentBlockDeltaEvent contentBlockDeltaEvent) {
if (contentBlockDeltaEvent.delta().type().equals(ContentBlockDelta.Type.TEXT)) {

var generation = new Generation(
new AssistantMessage(contentBlockDeltaEvent.delta().text(), Map.of()),
var generation = new Generation(new AssistantMessage(contentBlockDeltaEvent.delta().text()),
ChatGenerationMetadata.builder()
.finishReason(lastAggregation.metadataAggregation().stopReason())
.build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.content.Media;

/**
* Represents an assistant message generated by the DeepSeek model.
*
* @author Jemin Huh
* @author Soby Chacko
* @author Mark Pollack
* @since 1.0.0
*/
public class DeepSeekAssistantMessage extends AssistantMessage {

private Boolean prefix;
Expand All @@ -38,22 +46,14 @@ public DeepSeekAssistantMessage(String content, String reasoningContent) {
this.reasoningContent = reasoningContent;
}

public DeepSeekAssistantMessage(String content, Map<String, Object> properties) {
super(content, properties);
}

public DeepSeekAssistantMessage(String content, Map<String, Object> properties, List<ToolCall> toolCalls) {
super(content, properties, toolCalls);
}

public DeepSeekAssistantMessage(String content, String reasoningContent, Map<String, Object> properties,
public DeepSeekAssistantMessage(String content, String reasoningContent, Map<String, Object> metadata,
List<ToolCall> toolCalls) {
this(content, reasoningContent, properties, toolCalls, List.of());
this(content, reasoningContent, metadata, toolCalls, List.of());
}

public DeepSeekAssistantMessage(String content, String reasoningContent, Map<String, Object> properties,
public DeepSeekAssistantMessage(String content, String reasoningContent, Map<String, Object> metadata,
List<ToolCall> toolCalls, List<Media> media) {
super(content, properties, toolCalls, media);
super(content, metadata, toolCalls, media);
this.reasoningContent = reasoningContent;
}

Expand Down Expand Up @@ -102,9 +102,9 @@ public int hashCode() {

@Override
public String toString() {
return "AssistantMessage [messageType=" + this.messageType + ", toolCalls=" + super.getToolCalls()
+ ", textContent=" + this.textContent + ", reasoningContent=" + this.reasoningContent + ", prefix="
+ this.prefix + ", metadata=" + this.metadata + "]";
return "AssistantMessage [messageType=" + this.messageType + ", toolCalls=" + this.toolCalls + ", textContent="
+ this.textContent + ", reasoningContent=" + this.reasoningContent + ", prefix=" + this.prefix
+ ", metadata=" + this.metadata + "]";
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
* An implementation of {@link ChatModel} that interfaces with HuggingFace Inference
* Endpoints for text generation.
*
* @author Jemin Huh
* @author Mark Pollack
* @author Jihoon Kim
*/
Expand Down Expand Up @@ -104,7 +105,8 @@ public ChatResponse call(Prompt prompt) {
new TypeReference<Map<String, Object>>() {

});
Generation generation = new Generation(new AssistantMessage(generatedText, detailsMap));
Generation generation = new Generation(
AssistantMessage.builder().text(generatedText).metadata(detailsMap).build());
generations.add(generation);
}
return new ChatResponse(generations);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
* {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal MiniMax}
* backed by {@link MiniMaxApi}.
*
* @author Jemin Huh
* @author Geng Rong
* @author Alexandros Pappas
* @author Ilayaperumal Gopinathan
Expand Down Expand Up @@ -225,7 +226,11 @@ private static Generation buildGeneration(Choice choice, Map<String, Object> met
acc1.addAll(acc2);
return acc1;
});
var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls);
var assistantMessage = AssistantMessage.builder()
.text(choice.message().content())
.metadata(metadata)
.toolCalls(toolCalls)
.build();
String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : "");
var generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build();
return new Generation(assistantMessage, generationMetadata);
Expand Down Expand Up @@ -424,7 +429,11 @@ private Generation buildGeneration(ChatCompletionMessage message, ChatCompletion
toolCall.function().name(), toolCall.function().arguments()))
.toList();

var assistantMessage = new AssistantMessage(message.content(), metadata, toolCalls);
var assistantMessage = AssistantMessage.builder()
.text(message.content())
.metadata(metadata)
.toolCalls(toolCalls)
.build();
String finishReason = (completionFinishReason != null ? completionFinishReason.name() : "");
var generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build();
return new Generation(assistantMessage, generationMetadata);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,11 @@ private Generation buildGeneration(Choice choice, Map<String, Object> metadata)
toolCall.function().name(), toolCall.function().arguments()))
.toList();

var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls);
var assistantMessage = AssistantMessage.builder()
.text(choice.message().content())
.metadata(metadata)
.toolCalls(toolCalls)
.build();
String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : "");
var generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build();
return new Generation(assistantMessage, generationMetadata);
Expand Down
Loading