66
77import java .time .Duration ;
88import java .util .ArrayList ;
9+ import java .util .Collections ;
910import java .util .List ;
11+ import java .util .Map ;
12+ import java .util .concurrent .ConcurrentHashMap ;
1013import java .util .function .Consumer ;
1114import java .util .stream .Collectors ;
1215
16+ import org .jboss .logging .Logger ;
17+
1318import dev .langchain4j .agent .tool .ToolExecutionRequest ;
1419import dev .langchain4j .agent .tool .ToolSpecification ;
1520import dev .langchain4j .data .message .AiMessage ;
1621import dev .langchain4j .data .message .ChatMessage ;
1722import dev .langchain4j .model .StreamingResponseHandler ;
1823import dev .langchain4j .model .chat .StreamingChatLanguageModel ;
24+ import dev .langchain4j .model .chat .listener .ChatModelErrorContext ;
25+ import dev .langchain4j .model .chat .listener .ChatModelListener ;
26+ import dev .langchain4j .model .chat .listener .ChatModelRequest ;
27+ import dev .langchain4j .model .chat .listener .ChatModelRequestContext ;
28+ import dev .langchain4j .model .chat .listener .ChatModelResponse ;
29+ import dev .langchain4j .model .chat .listener .ChatModelResponseContext ;
1930import dev .langchain4j .model .output .Response ;
2031import dev .langchain4j .model .output .TokenUsage ;
2132import io .smallrye .mutiny .Context ;
 * Used to enable the streaming feature on models accessed through Ollama.
2536 */
2637public class OllamaStreamingChatLanguageModel implements StreamingChatLanguageModel {
38+
39+ private static final Logger log = Logger .getLogger (OllamaStreamingChatLanguageModel .class );
40+
2741 private static final String TOOLS_CONTEXT = "TOOLS" ;
2842 private static final String TOKEN_USAGE_CONTEXT = "TOKEN_USAGE" ;
2943 private static final String RESPONSE_CONTEXT = "RESPONSE" ;
44+ private static final String MODEL_ID = "MODEL_ID" ;
3045 private final OllamaClient client ;
3146 private final String model ;
3247 private final String format ;
3348 private final Options options ;
49+ private final List <ChatModelListener > listeners ;
3450
/**
 * Builds the model from the given builder's configuration.
 *
 * @param builder source of all configuration values; must not be null
 */
private OllamaStreamingChatLanguageModel(OllamaStreamingChatLanguageModel.Builder builder) {
    // The client encapsulates all transport concerns (base URL, timeout, request/response logging, TLS).
    client = new OllamaClient(builder.baseUrl, builder.timeout, builder.logRequests, builder.logResponses,
            builder.configName, builder.tlsConfigurationName);
    model = builder.model;
    format = builder.format;
    options = builder.options;
    // Defensive null guard: the listeners list is iterated unconditionally in generate(),
    // so never store null even if Builder.listeners(null) was called.
    this.listeners = builder.listeners == null ? Collections.emptyList() : builder.listeners;
}
4259
4360 public static OllamaStreamingChatLanguageModel .Builder builder () {
@@ -60,13 +77,25 @@ public void generate(List<ChatMessage> messages, List<ToolSpecification> toolSpe
6077 .build ();
6178
6279 Context context = Context .empty ();
80+ context .put (MODEL_ID , "" );
6381 context .put (RESPONSE_CONTEXT , new ArrayList <ChatResponse >());
6482 context .put (TOOLS_CONTEXT , new ArrayList <ToolExecutionRequest >());
6583
84+ ChatModelRequest modelListenerRequest = createModelListenerRequest (request , messages , toolSpecifications );
85+ Map <Object , Object > attributes = new ConcurrentHashMap <>();
86+ ChatModelRequestContext requestContext = new ChatModelRequestContext (modelListenerRequest , attributes );
87+ listeners .forEach (listener -> {
88+ try {
89+ listener .onRequest (requestContext );
90+ } catch (Exception e ) {
91+ log .warn ("Exception while calling model listener" , e );
92+ }
93+ });
94+
6695 client .streamingChat (request )
6796 .subscribe ()
6897 .with (context ,
69- new Consumer <ChatResponse >() {
98+ new Consumer <>() {
7099 @ Override
71100 @ SuppressWarnings ("unchecked" )
72101 public void accept (ChatResponse response ) {
@@ -89,6 +118,9 @@ public void accept(ChatResponse response) {
89118 }
90119
91120 if (response .done ()) {
121+ if (response .model () != null ) {
122+ context .put (MODEL_ID , response .model ());
123+ }
92124 TokenUsage tokenUsage = new TokenUsage (
93125 response .evalCount (),
94126 response .promptEvalCount (),
@@ -101,9 +133,36 @@ public void accept(ChatResponse response) {
101133 }
102134 }
103135 },
104- new Consumer <Throwable >() {
136+ new Consumer <>() {
105137 @ Override
106138 public void accept (Throwable error ) {
139+ List <ChatResponse > chatResponses = context .get (RESPONSE_CONTEXT );
140+ String stringResponse = chatResponses .stream ()
141+ .map (ChatResponse ::message )
142+ .map (Message ::content )
143+ .collect (Collectors .joining ());
144+ AiMessage aiMessage = new AiMessage (stringResponse );
145+ Response <AiMessage > aiMessageResponse = Response .from (aiMessage );
146+
147+ ChatModelResponse modelListenerPartialResponse = createModelListenerResponse (
148+ null ,
149+ context .get (MODEL_ID ),
150+ aiMessageResponse );
151+
152+ ChatModelErrorContext errorContext = new ChatModelErrorContext (
153+ error ,
154+ modelListenerRequest ,
155+ modelListenerPartialResponse ,
156+ attributes );
157+
158+ listeners .forEach (listener -> {
159+ try {
160+ listener .onError (errorContext );
161+ } catch (Exception e ) {
162+ log .warn ("Exception while calling model listener" , e );
163+ }
164+ });
165+
107166 handler .onError (error );
108167 }
109168 },
@@ -115,22 +174,72 @@ public void run() {
115174 List <ChatResponse > chatResponses = context .get (RESPONSE_CONTEXT );
116175 List <ToolExecutionRequest > toolExecutionRequests = context .get (TOOLS_CONTEXT );
117176
118- if (toolExecutionRequests .size () > 0 ) {
177+ if (! toolExecutionRequests .isEmpty () ) {
119178 handler .onComplete (Response .from (AiMessage .from (toolExecutionRequests ), tokenUsage ));
120179 return ;
121180 }
122181
123- String response = chatResponses .stream ()
182+ String stringResponse = chatResponses .stream ()
124183 .map (ChatResponse ::message )
125184 .map (Message ::content )
126185 .collect (Collectors .joining ());
127186
128- AiMessage message = new AiMessage (response );
129- handler .onComplete (Response .from (message , tokenUsage ));
187+ AiMessage aiMessage = new AiMessage (stringResponse );
188+ Response <AiMessage > aiMessageResponse = Response .from (aiMessage , tokenUsage );
189+
190+ ChatModelResponse modelListenerResponse = createModelListenerResponse (
191+ null ,
192+ context .get (MODEL_ID ),
193+ aiMessageResponse );
194+ ChatModelResponseContext responseContext = new ChatModelResponseContext (
195+ modelListenerResponse ,
196+ modelListenerRequest ,
197+ attributes );
198+ listeners .forEach (listener -> {
199+ try {
200+ listener .onResponse (responseContext );
201+ } catch (Exception e ) {
202+ log .warn ("Exception while calling model listener" , e );
203+ }
204+ });
205+
206+ handler .onComplete (aiMessageResponse );
130207 }
131208 });
132209 }
133210
211+ private ChatModelRequest createModelListenerRequest (ChatRequest request ,
212+ List <ChatMessage > messages ,
213+ List <ToolSpecification > toolSpecifications ) {
214+ Options options = request .options ();
215+ var builder = ChatModelRequest .builder ()
216+ .model (request .model ())
217+ .messages (messages )
218+ .toolSpecifications (toolSpecifications );
219+ if (options != null ) {
220+ builder .temperature (options .temperature ())
221+ .topP (options .topP ())
222+ .maxTokens (options .numPredict ());
223+ }
224+ return builder .build ();
225+ }
226+
227+ private ChatModelResponse createModelListenerResponse (String responseId ,
228+ String responseModel ,
229+ Response <AiMessage > response ) {
230+ if (response == null ) {
231+ return null ;
232+ }
233+
234+ return ChatModelResponse .builder ()
235+ .id (responseId )
236+ .model (responseModel )
237+ .tokenUsage (response .tokenUsage ())
238+ .finishReason (response .finishReason ())
239+ .aiMessage (response .content ())
240+ .build ();
241+ }
242+
134243 @ Override
135244 public void generate (List <ChatMessage > messages , ToolSpecification toolSpecification ,
136245 StreamingResponseHandler <AiMessage > handler ) {
@@ -161,6 +270,7 @@ private Builder() {
161270 private boolean logRequests = false ;
162271 private boolean logResponses = false ;
163272 private String configName ;
273+ private List <ChatModelListener > listeners = Collections .emptyList ();
164274
165275 public Builder baseUrl (String val ) {
166276 baseUrl = val ;
@@ -207,6 +317,11 @@ public Builder configName(String configName) {
207317 return this ;
208318 }
209319
320+ public Builder listeners (List <ChatModelListener > listeners ) {
321+ this .listeners = listeners ;
322+ return this ;
323+ }
324+
210325 public OllamaStreamingChatLanguageModel build () {
211326 return new OllamaStreamingChatLanguageModel (this );
212327 }
0 commit comments