11/*
2- * Copyright 2023-2023 the original author or authors.
2+ * Copyright 2023-2024 the original author or authors.
33 *
44 * Licensed under the Apache License, Version 2.0 (the "License");
55 * you may not use this file except in compliance with the License.
1818
1919import java .util .List ;
2020
21- import org .springframework .ai .chat .ChatClient ;
22- import org .springframework .ai .chat .metadata .ChatGenerationMetadata ;
2321import reactor .core .publisher .Flux ;
2422
2523import org .springframework .ai .bedrock .MessageToPromptConverter ;
2624import org .springframework .ai .bedrock .titan .api .TitanChatBedrockApi ;
2725import org .springframework .ai .bedrock .titan .api .TitanChatBedrockApi .TitanChatRequest ;
2826import org .springframework .ai .bedrock .titan .api .TitanChatBedrockApi .TitanChatResponse ;
2927import org .springframework .ai .bedrock .titan .api .TitanChatBedrockApi .TitanChatResponseChunk ;
28+ import org .springframework .ai .chat .ChatClient ;
29+ import org .springframework .ai .chat .ChatOptions ;
3030import org .springframework .ai .chat .ChatResponse ;
31- import org .springframework .ai .chat .StreamingChatClient ;
3231import org .springframework .ai .chat .Generation ;
32+ import org .springframework .ai .chat .StreamingChatClient ;
33+ import org .springframework .ai .chat .metadata .ChatGenerationMetadata ;
3334import org .springframework .ai .chat .metadata .Usage ;
3435import org .springframework .ai .chat .prompt .Prompt ;
36+ import org .springframework .ai .model .ModelOptionsUtils ;
37+ import org .springframework .util .Assert ;
3538
3639/**
3740 * @author Christian Tzolov
@@ -41,41 +44,22 @@ public class BedrockTitanChatClient implements ChatClient, StreamingChatClient {
4144
4245 private final TitanChatBedrockApi chatApi ;
4346
44- private Float temperature ;
45-
46- private Float topP ;
47-
48- private Integer maxTokenCount ;
49-
50- private List <String > stopSequences ;
47+ private final BedrockTitanChatOptions defaultOptions ;
5148
5249 public BedrockTitanChatClient (TitanChatBedrockApi chatApi ) {
53- this . chatApi = chatApi ;
50+ this ( chatApi , BedrockTitanChatOptions . builder (). withTemperature ( 0.8f ). build ()) ;
5451 }
5552
56- public BedrockTitanChatClient withTemperature (Float temperature ) {
57- this .temperature = temperature ;
58- return this ;
59- }
60-
61- public BedrockTitanChatClient withTopP (Float topP ) {
62- this .topP = topP ;
63- return this ;
64- }
65-
66- public BedrockTitanChatClient withMaxTokenCount (Integer maxTokens ) {
67- this .maxTokenCount = maxTokens ;
68- return this ;
69- }
70-
71- public BedrockTitanChatClient withStopSequences (List <String > stopSequences ) {
72- this .stopSequences = stopSequences ;
73- return this ;
53+ public BedrockTitanChatClient (TitanChatBedrockApi chatApi , BedrockTitanChatOptions defaultOptions ) {
54+ Assert .notNull (chatApi , "ChatApi must not be null" );
55+ Assert .notNull (defaultOptions , "DefaultOptions must not be null" );
56+ this .chatApi = chatApi ;
57+ this .defaultOptions = defaultOptions ;
7458 }
7559
7660 @ Override
7761 public ChatResponse call (Prompt prompt ) {
78- TitanChatResponse response = this .chatApi .chatCompletion (this .createRequest (prompt , false ));
62+ TitanChatResponse response = this .chatApi .chatCompletion (this .createRequest (prompt ));
7963 List <Generation > generations = response .results ().stream ().map (result -> {
8064 return new Generation (result .outputText ());
8165 }).toList ();
@@ -85,7 +69,7 @@ public ChatResponse call(Prompt prompt) {
8569
8670 @ Override
8771 public Flux <ChatResponse > stream (Prompt prompt ) {
88- return this .chatApi .chatCompletionStream (this .createRequest (prompt , true )).map (chunk -> {
72+ return this .chatApi .chatCompletionStream (this .createRequest (prompt )).map (chunk -> {
8973
9074 Generation generation = new Generation (chunk .outputText ());
9175
@@ -104,15 +88,48 @@ else if (chunk.inputTextTokenCount() != null && chunk.totalOutputTextTokenCount(
10488 });
10589 }
10690
107- private TitanChatRequest createRequest (Prompt prompt , boolean stream ) {
91+ /**
92+ * Test access.
93+ */
94+ TitanChatRequest createRequest (Prompt prompt ) {
10895 final String promptValue = MessageToPromptConverter .create ().toPrompt (prompt .getInstructions ());
10996
110- return TitanChatRequest .builder (promptValue )
111- .withTemperature (this .temperature )
112- .withTopP (this .topP )
113- .withMaxTokenCount (this .maxTokenCount )
114- .withStopSequences (this .stopSequences )
115- .build ();
97+ var requestBuilder = TitanChatRequest .builder (promptValue );
98+
99+ if (this .defaultOptions != null ) {
100+ requestBuilder = update (requestBuilder , this .defaultOptions );
101+ }
102+
103+ if (prompt .getOptions () != null ) {
104+ if (prompt .getOptions () instanceof ChatOptions runtimeOptions ) {
105+ BedrockTitanChatOptions updatedRuntimeOptions = ModelOptionsUtils .copyToTarget (runtimeOptions ,
106+ ChatOptions .class , BedrockTitanChatOptions .class );
107+
108+ requestBuilder = update (requestBuilder , updatedRuntimeOptions );
109+ }
110+ else {
111+ throw new IllegalArgumentException ("Prompt options are not of type ChatOptions: "
112+ + prompt .getOptions ().getClass ().getSimpleName ());
113+ }
114+ }
115+
116+ return requestBuilder .build ();
117+ }
118+
119+ private TitanChatRequest .Builder update (TitanChatRequest .Builder builder , BedrockTitanChatOptions options ) {
120+ if (options .getTemperature () != null ) {
121+ builder .withTemperature (options .getTemperature ());
122+ }
123+ if (options .getTopP () != null ) {
124+ builder .withTopP (options .getTopP ());
125+ }
126+ if (options .getMaxTokenCount () != null ) {
127+ builder .withMaxTokenCount (options .getMaxTokenCount ());
128+ }
129+ if (options .getStopSequences () != null ) {
130+ builder .withStopSequences (options .getStopSequences ());
131+ }
132+ return builder ;
116133 }
117134
118135 private Usage extractUsage (TitanChatResponseChunk response ) {
0 commit comments