README.md (2 changes: 1 addition & 1 deletion)
@@ -84,7 +84,7 @@ You can find more details in the [Reference Documentation](https://docs.spring.i
Spring AI supports many AI models. For an overview see here. Specific models currently supported are
* OpenAI
* Azure OpenAI
-* Amazon Bedrock (Anthropic, Llama2, Cohere, Titan, Jurassic2)
+* Amazon Bedrock (Anthropic, Llama, Cohere, Titan, Jurassic2)
* HuggingFace
* Google VertexAI (PaLM2, Gemini)
* Mistral AI
models/spring-ai-bedrock/README.md (2 changes: 1 addition & 1 deletion)
@@ -4,7 +4,7 @@
- [Anthropic2 Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-anthropic.html)
- [Cohere Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-cohere.html)
- [Cohere Embedding Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/embeddings/bedrock-cohere-embedding.html)
-- [Llama2 Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-llama2.html)
+- [Llama Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-llama.html)
- [Titan Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-titan.html)
- [Titan Embedding Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/embeddings/bedrock-titan-embedding.html)
- [Jurassic2 Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-jurassic2.html)
@@ -25,8 +25,8 @@
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi;
import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi;
import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi;
-import org.springframework.ai.bedrock.llama2.BedrockLlama2ChatOptions;
-import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi;
+import org.springframework.ai.bedrock.llama.BedrockLlamaChatOptions;
+import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi;
import org.springframework.ai.bedrock.titan.BedrockTitanChatOptions;
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi;
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi;
@@ -63,9 +63,9 @@ public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
for (var tr : findJsonAnnotatedClassesInPackage(BedrockCohereEmbeddingOptions.class))
hints.reflection().registerType(tr, mcs);

-for (var tr : findJsonAnnotatedClassesInPackage(Llama2ChatBedrockApi.class))
+for (var tr : findJsonAnnotatedClassesInPackage(LlamaChatBedrockApi.class))
hints.reflection().registerType(tr, mcs);
-for (var tr : findJsonAnnotatedClassesInPackage(BedrockLlama2ChatOptions.class))
+for (var tr : findJsonAnnotatedClassesInPackage(BedrockLlamaChatOptions.class))
hints.reflection().registerType(tr, mcs);

for (var tr : findJsonAnnotatedClassesInPackage(TitanChatBedrockApi.class))
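The hunk above points the AOT reflection hints at the renamed Llama types so that JSON (de)serialization keeps working in a GraalVM native image. For context, a registrar like this is typically wired in with Spring Framework's `@ImportRuntimeHints`; the sketch below is a generic illustration only, since the actual registrar class name, its wiring, and the `findJsonAnnotatedClassesInPackage` helper are not shown in this diff.

```java
import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.RuntimeHintsRegistrar;
import org.springframework.aot.hint.TypeReference;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.ImportRuntimeHints;

import org.springframework.ai.bedrock.llama.BedrockLlamaChatOptions;
import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi;

// Generic illustration of the RuntimeHintsRegistrar pattern used in the hunk above.
// The real Spring AI registrar also covers the other Bedrock models and discovers
// the JSON-annotated classes per package instead of listing them by hand.
class LlamaHintsExample implements RuntimeHintsRegistrar {

	@Override
	public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
		var mcs = MemberCategory.values();
		// Register the renamed API and options types for reflective access.
		hints.reflection().registerType(TypeReference.of(LlamaChatBedrockApi.class), mcs);
		hints.reflection().registerType(TypeReference.of(BedrockLlamaChatOptions.class), mcs);
	}
}

// Wiring sketch: the registrar is imported from any @Configuration class.
@Configuration
@ImportRuntimeHints(LlamaHintsExample.class)
class LlamaNativeConfig {
}
```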
@@ -13,16 +13,16 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.springframework.ai.bedrock.llama2;
+package org.springframework.ai.bedrock.llama;

import java.util.List;

import reactor.core.publisher.Flux;

import org.springframework.ai.bedrock.MessageToPromptConverter;
-import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi;
-import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatRequest;
-import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatResponse;
+import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi;
+import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatRequest;
+import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatResponse;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.ChatResponse;
@@ -35,26 +35,27 @@
import org.springframework.util.Assert;

/**
-* Java {@link ChatClient} and {@link StreamingChatClient} for the Bedrock Llama2 chat
+* Java {@link ChatClient} and {@link StreamingChatClient} for the Bedrock Llama chat
* generative.
*
* @author Christian Tzolov
* @author Wei Jiang
* @since 0.8.0
*/
-public class BedrockLlama2ChatClient implements ChatClient, StreamingChatClient {
+public class BedrockLlamaChatClient implements ChatClient, StreamingChatClient {

-private final Llama2ChatBedrockApi chatApi;
+private final LlamaChatBedrockApi chatApi;

-private final BedrockLlama2ChatOptions defaultOptions;
+private final BedrockLlamaChatOptions defaultOptions;

-public BedrockLlama2ChatClient(Llama2ChatBedrockApi chatApi) {
+public BedrockLlamaChatClient(LlamaChatBedrockApi chatApi) {
this(chatApi,
-BedrockLlama2ChatOptions.builder().withTemperature(0.8f).withTopP(0.9f).withMaxGenLen(100).build());
+BedrockLlamaChatOptions.builder().withTemperature(0.8f).withTopP(0.9f).withMaxGenLen(100).build());
}

-public BedrockLlama2ChatClient(Llama2ChatBedrockApi chatApi, BedrockLlama2ChatOptions options) {
-Assert.notNull(chatApi, "Llama2ChatBedrockApi must not be null");
-Assert.notNull(options, "BedrockLlama2ChatOptions must not be null");
+public BedrockLlamaChatClient(LlamaChatBedrockApi chatApi, BedrockLlamaChatOptions options) {
+Assert.notNull(chatApi, "LlamaChatBedrockApi must not be null");
+Assert.notNull(options, "BedrockLlamaChatOptions must not be null");

this.chatApi = chatApi;
this.defaultOptions = options;
@@ -65,7 +66,7 @@ public ChatResponse call(Prompt prompt) {

var request = createRequest(prompt);

-Llama2ChatResponse response = this.chatApi.chatCompletion(request);
+LlamaChatResponse response = this.chatApi.chatCompletion(request);

return new ChatResponse(List.of(new Generation(response.generation()).withGenerationMetadata(
ChatGenerationMetadata.from(response.stopReason().name(), extractUsage(response)))));
@@ -76,7 +77,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {

var request = createRequest(prompt);

-Flux<Llama2ChatResponse> fluxResponse = this.chatApi.chatCompletionStream(request);
+Flux<LlamaChatResponse> fluxResponse = this.chatApi.chatCompletionStream(request);

return fluxResponse.map(response -> {
String stopReason = response.stopReason() != null ? response.stopReason().name() : null;
@@ -85,7 +86,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
});
}

-private Usage extractUsage(Llama2ChatResponse response) {
+private Usage extractUsage(LlamaChatResponse response) {
return new Usage() {

@Override
@@ -103,22 +104,22 @@ public Long getGenerationTokens() {
/**
* Accessible for testing.
*/
-Llama2ChatRequest createRequest(Prompt prompt) {
+LlamaChatRequest createRequest(Prompt prompt) {

final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions());

-Llama2ChatRequest request = Llama2ChatRequest.builder(promptValue).build();
+LlamaChatRequest request = LlamaChatRequest.builder(promptValue).build();

if (this.defaultOptions != null) {
-request = ModelOptionsUtils.merge(request, this.defaultOptions, Llama2ChatRequest.class);
+request = ModelOptionsUtils.merge(request, this.defaultOptions, LlamaChatRequest.class);
}

if (prompt.getOptions() != null) {
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
-BedrockLlama2ChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
-ChatOptions.class, BedrockLlama2ChatOptions.class);
+BedrockLlamaChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
+ChatOptions.class, BedrockLlamaChatOptions.class);

-request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, Llama2ChatRequest.class);
+request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, LlamaChatRequest.class);
}
else {
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
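For reference, a minimal usage sketch of the renamed client, using only types and signatures that appear in this diff. The `LlamaChatBedrockApi` instance is taken as a parameter because its constructors are not shown here, and the import location of `Prompt` is an assumption.

```java
import reactor.core.publisher.Flux;

import org.springframework.ai.bedrock.llama.BedrockLlamaChatClient;
import org.springframework.ai.bedrock.llama.BedrockLlamaChatOptions;
import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt; // assumed package for Prompt; not shown in this diff

public class LlamaChatClientUsageExample {

	// A LlamaChatBedrockApi instance is passed in because its constructors
	// are not part of this diff.
	ChatResponse callOnce(LlamaChatBedrockApi chatApi) {
		// Same values as the single-argument constructor's defaults above.
		BedrockLlamaChatOptions options = BedrockLlamaChatOptions.builder()
			.withTemperature(0.8f)
			.withTopP(0.9f)
			.withMaxGenLen(100)
			.build();

		BedrockLlamaChatClient client = new BedrockLlamaChatClient(chatApi, options);

		// Blocking call; the ChatResponse carries the generation and its metadata.
		return client.call(new Prompt("Tell me a joke about llamas."));
	}

	// Streaming variant returns a Flux of partial ChatResponse chunks.
	Flux<ChatResponse> streamOnce(LlamaChatBedrockApi chatApi) {
		return new BedrockLlamaChatClient(chatApi).stream(new Prompt("Tell me a joke about llamas."));
	}
}
```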
@@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.springframework.ai.bedrock.llama2;
+package org.springframework.ai.bedrock.llama;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
@@ -26,7 +26,7 @@
* @author Christian Tzolov
*/
@JsonInclude(Include.NON_NULL)
-public class BedrockLlama2ChatOptions implements ChatOptions {
+public class BedrockLlamaChatOptions implements ChatOptions {

/**
* The temperature value controls the randomness of the generated text. Use a lower
@@ -51,7 +51,7 @@ public static Builder builder() {

public static class Builder {

-private BedrockLlama2ChatOptions options = new BedrockLlama2ChatOptions();
+private BedrockLlamaChatOptions options = new BedrockLlamaChatOptions();

public Builder withTemperature(Float temperature) {
this.options.setTemperature(temperature);
@@ -68,7 +68,7 @@ public Builder withMaxGenLen(Integer maxGenLen) {
return this;
}

-public BedrockLlama2ChatOptions build() {
+public BedrockLlamaChatOptions build() {
return this.options;
}

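Because the renamed options class implements `ChatOptions`, it can also be attached to an individual `Prompt` and merged over the client defaults, which is exactly what the `createRequest(...)` logic earlier in this PR does. A sketch of that per-call override; the `Prompt(String, ChatOptions)` constructor and the `Prompt` import location are assumptions, since neither appears in this diff.

```java
import org.springframework.ai.bedrock.llama.BedrockLlamaChatClient;
import org.springframework.ai.bedrock.llama.BedrockLlamaChatOptions;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt; // assumed package; (String, ChatOptions) constructor is also an assumption

public class LlamaRuntimeOptionsExample {

	// Overrides the client-level defaults for a single call. Per createRequest(...),
	// the ChatOptions carried by the Prompt is copied into BedrockLlamaChatOptions
	// and merged over the defaults before the Bedrock request is built.
	ChatResponse callWithRuntimeOptions(BedrockLlamaChatClient client) {
		BedrockLlamaChatOptions runtimeOptions = BedrockLlamaChatOptions.builder()
			.withTemperature(0.4f)  // lower than the 0.8f client default
			.withMaxGenLen(256)
			.build();

		return client.call(new Prompt("Summarize the Llama2 to Llama rename in one sentence.", runtimeOptions));
	}
}
```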