Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ To check javadocs using the [javadoc:javadoc](https://maven.apache.org/plugins/m
./mvnw javadoc:javadoc -Pjavadoc
```

To build with checkstyles enabled.
Checkstyles are currently disabled, but you can enable them by doing the following:
```shell
./mvnw clean package -DskipTests -Ddisable.checks=false
```

## Project Links

* [Documentation](https://docs.spring.io/spring-ai/reference/)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ static class DocumentVisitor extends AbstractVisitor {

private Document.Builder currentDocumentBuilder;

public DocumentVisitor(MarkdownDocumentReaderConfig config) {
DocumentVisitor(MarkdownDocumentReaderConfig config) {
this.config = config;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public static Builder builder() {
return new Builder();
}

public static class Builder {
public static final class Builder {

private boolean horizontalRuleCreateDocument = false;

Expand Down
1 change: 0 additions & 1 deletion document-readers/pdf-reader/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
</scm>

<properties>
<disable.checks>false</disable.checks>
</properties>

<dependencies>
Expand Down
1 change: 0 additions & 1 deletion models/spring-ai-anthropic/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
</scm>

<properties>
<disable.checks>false</disable.checks>
</properties>

<dependencies>
Expand Down
1 change: 0 additions & 1 deletion models/spring-ai-azure-openai/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
</scm>

<properties>
<disable.checks>false</disable.checks>
</properties>

<dependencies>
Expand Down
18 changes: 17 additions & 1 deletion models/spring-ai-bedrock-converse/pom.xml
Original file line number Diff line number Diff line change
@@ -1,4 +1,20 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ Copyright 2023-2024 the original author or authors.
~
~ Licensed under the Apache License, Version 2.0 (the "License");
~ you may not use this file except in compliance with the License.
~ You may obtain a copy of the License at
~
~ https://www.apache.org/licenses/LICENSE-2.0
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS,
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
~ See the License for the specific language governing permissions and
~ limitations under the License.
-->

<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
Expand Down Expand Up @@ -81,4 +97,4 @@

</dependencies>

</project>
</project>
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
/*
* Copyright 2024 - 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.bedrock.converse;

import java.io.IOException;
Expand All @@ -26,44 +27,11 @@
import java.util.Map;
import java.util.Set;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.bedrock.converse.api.ConverseApiUtils;
import org.springframework.ai.bedrock.converse.api.URLValidator;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.MessageAggregator;
import org.springframework.ai.chat.observation.ChatModelObservationContext;
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.function.FunctionCallingOptionsBuilder;
import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StreamUtils;
import org.springframework.util.StringUtils;

import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;
Expand Down Expand Up @@ -97,6 +65,39 @@
import software.amazon.awssdk.services.bedrockruntime.model.ToolSpecification;
import software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlock;

import org.springframework.ai.bedrock.converse.api.ConverseApiUtils;
import org.springframework.ai.bedrock.converse.api.URLValidator;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.MessageAggregator;
import org.springframework.ai.chat.observation.ChatModelObservationContext;
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.function.FunctionCallingOptionsBuilder;
import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StreamUtils;
import org.springframework.util.StringUtils;

/**
* A {@link ChatModel} implementation that uses the Amazon Bedrock Converse API to
* interact with the <a href=
Expand Down Expand Up @@ -335,7 +336,7 @@ else if (prompt.getOptions() instanceof ChatOptions) {
.topP(updatedRuntimeOptions.getTopP() != null ? updatedRuntimeOptions.getTopP().floatValue() : null)
.build();
Document additionalModelRequestFields = ConverseApiUtils
.getChatOptionsAdditionalModelRequestFields(defaultOptions, prompt.getOptions());
.getChatOptionsAdditionalModelRequestFields(this.defaultOptions, prompt.getOptions());

return ConverseRequest.builder()
.modelId(updatedRuntimeOptions.getModel())
Expand Down Expand Up @@ -411,10 +412,8 @@ private ChatResponse toChatResponse(ConverseResponse response) {
List<Generation> generations = message.content()
.stream()
.filter(content -> content.type() != ContentBlock.Type.TOOL_USE)
.map(content -> {
return new Generation(new AssistantMessage(content.text(), Map.of()),
ChatGenerationMetadata.from(response.stopReasonAsString(), null));
})
.map(content -> new Generation(new AssistantMessage(content.text(), Map.of()),
ChatGenerationMetadata.from(response.stopReasonAsString(), null)))
.toList();

List<Generation> allGenerations = new ArrayList<>(generations);
Expand Down Expand Up @@ -508,7 +507,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
// @formatter:off
Flux<ChatResponse> chatResponses = ConverseApiUtils.toChatResponse(response);

Flux<ChatResponse> chatResponseFlux = chatResponses.switchMap(chatResponse -> {
Flux<ChatResponse> chatResponseFlux = chatResponses.switchMap(chatResponse -> {
if (!this.isProxyToolCalls(prompt, this.defaultOptions) && chatResponse != null
&& this.isToolCall(chatResponse, Set.of("tool_use"))) {
var toolCallConversation = this.handleToolCalls(prompt, chatResponse);
Expand Down Expand Up @@ -540,14 +539,14 @@ public Flux<ConverseStreamOutput> converseStream(ConverseStreamRequest converseS
Sinks.Many<ConverseStreamOutput> eventSink = Sinks.many().multicast().onBackpressureBuffer();

ConverseStreamResponseHandler.Visitor visitor = ConverseStreamResponseHandler.Visitor.builder()
.onDefault((output) -> {
.onDefault(output -> {
logger.debug("Received converse stream output:{}", output);
eventSink.tryEmitNext(output);
})
.build();

ConverseStreamResponseHandler responseHandler = ConverseStreamResponseHandler.builder()
.onEventStream(stream -> stream.subscribe((e) -> e.accept(visitor)))
.onEventStream(stream -> stream.subscribe(e -> e.accept(visitor)))
.onComplete(() -> {
EmitResult emitResult = eventSink.tryEmitComplete();

Expand All @@ -559,7 +558,7 @@ public Flux<ConverseStreamOutput> converseStream(ConverseStreamRequest converseS
eventSink.emitComplete(EmitFailureHandler.busyLooping(Duration.ofSeconds(3)));
logger.info("Completed streaming response.");
})
.onError((error) -> {
.onError(error -> {
logger.error("Error handling Bedrock converse stream response", error);
eventSink.tryEmitError(error);
})
Expand All @@ -571,11 +570,20 @@ public Flux<ConverseStreamOutput> converseStream(ConverseStreamRequest converseS

}

/**
* Use the provided convention for reporting observation data
* @param observationConvention The provided convention
*/
public void setObservationConvention(ChatModelObservationConvention observationConvention) {
Assert.notNull(observationConvention, "observationConvention cannot be null");
this.observationConvention = observationConvention;
}

public static Builder builder() {
return new Builder();
}

public static class Builder {
public static final class Builder {

private AwsCredentialsProvider credentialsProvider;

Expand Down Expand Up @@ -696,13 +704,4 @@ public BedrockProxyChatModel build() {

}

/**
* Use the provided convention for reporting observation data
* @param observationConvention The provided convention
*/
public void setObservationConvention(ChatModelObservationConvention observationConvention) {
Assert.notNull(observationConvention, "observationConvention cannot be null");
this.observationConvention = observationConvention;
}

}
Original file line number Diff line number Diff line change
@@ -1,25 +1,26 @@
/*
* Copyright 2024 - 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.bedrock.converse.api;

import software.amazon.awssdk.services.bedrockruntime.model.TokenUsage;

import org.springframework.ai.chat.metadata.Usage;
import org.springframework.util.Assert;

import software.amazon.awssdk.services.bedrockruntime.model.TokenUsage;

/**
* {@link Usage} implementation for Bedrock Converse API.
*
Expand All @@ -46,17 +47,17 @@ protected BedrockUsage(Long inputTokens, Long outputTokens) {

@Override
public Long getPromptTokens() {
return inputTokens;
return this.inputTokens;
}

@Override
public Long getGenerationTokens() {
return outputTokens;
return this.outputTokens;
}

@Override
public String toString() {
return "BedrockUsage [inputTokens=" + inputTokens + ", outputTokens=" + outputTokens + "]";
return "BedrockUsage [inputTokens=" + this.inputTokens + ", outputTokens=" + this.outputTokens + "]";
}

}
}
Loading