Skip to content

Commit 68fb63c

Browse files
authored
Merge pull request #1122 from andreadimaio/main
Enable tools support in streaming responses for Ollama
2 parents a2f33cb + cb6a4da commit 68fb63c

File tree

8 files changed

+417
-77
lines changed

8 files changed

+417
-77
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
package io.quarkiverse.langchain4j.ollama.deployment;

import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson;
import static com.github.tomakehurst.wiremock.client.WireMock.post;
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

import java.util.List;

import jakarta.inject.Inject;
import jakarta.inject.Singleton;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import com.github.tomakehurst.wiremock.stubbing.Scenario;

import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.service.MemoryId;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.testing.internal.WiremockAware;
import io.quarkus.test.QuarkusUnitTest;
import io.smallrye.mutiny.Multi;

/**
 * Smoke tests for streaming chat completions against a WireMock-simulated Ollama
 * {@code /api/chat} endpoint, covering three scenarios:
 * <ol>
 * <li>streaming without tools (plain token stream),</li>
 * <li>streaming with tools registered but not invoked by the model,</li>
 * <li>streaming where the model requests a tool call, the tool result is fed back,
 * and the final answer is streamed (verified via chat memory contents).</li>
 * </ol>
 */
public class OllamaStreamingChatLanguageModelSmokeTest extends WiremockAware {

    @RegisterExtension
    static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
            .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClasses(Calculator.class))
            .overrideConfigKey("quarkus.langchain4j.ollama.base-url", WiremockAware.wiremockUrlForConfig())
            .overrideConfigKey("quarkus.langchain4j.devservices.enabled", "false");

    /** AI service with the {@link Calculator} tool registered; memory keyed by id. */
    @Singleton
    @RegisterAiService(tools = Calculator.class)
    interface AIServiceWithTool {
        Multi<String> streaming(@MemoryId String memoryId, @dev.langchain4j.service.UserMessage String text);
    }

    /** AI service without any tools, used as the non-tool baseline. */
    @Singleton
    @RegisterAiService
    interface AIServiceWithoutTool {
        Multi<String> streaming(@dev.langchain4j.service.UserMessage String text);
    }

    /** Simple tool the model can invoke; its spec appears in the stubbed request bodies. */
    @Singleton
    static class Calculator {
        @Tool("Execute the sum of two numbers")
        public int sum(int firstNumber, int secondNumber) {
            return firstNumber + secondNumber;
        }
    }

    @Inject
    AIServiceWithTool aiServiceWithTool;

    @Inject
    AIServiceWithoutTool aiServiceWithoutTool;

    @Inject
    ChatMemoryStore memory;

    /**
     * Streaming without tools: the NDJSON response is surfaced as one emitted
     * token per chunk; the final "done" chunk carries no content.
     */
    @Test
    void test_1() {
        wiremock().register(
                post(urlEqualTo("/api/chat"))
                        .withRequestBody(equalToJson("""
                                {
                                  "model" : "llama3.2",
                                  "messages" : [ {
                                    "role" : "user",
                                    "content" : "Hello"
                                  }],
                                  "options" : {
                                    "temperature" : 0.8,
                                    "top_k" : 40,
                                    "top_p" : 0.9
                                  },
                                  "stream" : true
                                }
                                """))
                        .willReturn(aResponse()
                                .withHeader("Content-Type", "application/x-ndjson")
                                .withBody(
                                        """
                                                {"model":"llama3.2","created_at":"2024-11-30T09:03:42.312611426Z","message":{"role":"assistant","content":"Hello"},"done":false}
                                                {"model":"llama3.2","created_at":"2024-11-30T09:03:42.514215351Z","message":{"role":"assistant","content":"!"},"done":false}
                                                {"model":"llama3.2","created_at":"2024-11-30T09:03:44.109059873Z","message":{"role":"assistant","content":""},"done_reason":"stop","done":true,"total_duration":4821417857,"load_duration":2508844071,"prompt_eval_count":11,"prompt_eval_duration":514000000,"eval_count":10,"eval_duration":1797000000}""")));

        var result = aiServiceWithoutTool.streaming("Hello").collect().asList().await().indefinitely();
        assertEquals(List.of("Hello", "!"), result);
    }

    /**
     * Streaming with tools registered: the request must include the tool
     * specifications, but the model answers directly without calling any tool.
     */
    @Test
    void test_2() {
        wiremock().register(
                post(urlEqualTo("/api/chat"))
                        .withRequestBody(equalToJson("""
                                {
                                  "model" : "llama3.2",
                                  "messages" : [ {
                                    "role" : "user",
                                    "content" : "Hello"
                                  }],
                                  "tools" : [ {
                                    "type" : "function",
                                    "function" : {
                                      "name" : "sum",
                                      "description" : "Execute the sum of two numbers",
                                      "parameters" : {
                                        "type" : "object",
                                        "properties" : {
                                          "firstNumber" : {
                                            "type" : "integer"
                                          },
                                          "secondNumber" : {
                                            "type" : "integer"
                                          }
                                        },
                                        "required" : [ "firstNumber", "secondNumber" ]
                                      }
                                    }
                                  } ],
                                  "options" : {
                                    "temperature" : 0.8,
                                    "top_k" : 40,
                                    "top_p" : 0.9
                                  },
                                  "stream" : true
                                }
                                """))
                        .willReturn(aResponse()
                                .withHeader("Content-Type", "application/x-ndjson")
                                .withBody(
                                        """
                                                {"model":"llama3.2","created_at":"2024-11-30T09:03:42.312611426Z","message":{"role":"assistant","content":"Hello"},"done":false}
                                                {"model":"llama3.2","created_at":"2024-11-30T09:03:42.514215351Z","message":{"role":"assistant","content":"!"},"done":false}
                                                {"model":"llama3.2","created_at":"2024-11-30T09:03:44.109059873Z","message":{"role":"assistant","content":""},"done_reason":"stop","done":true,"total_duration":4821417857,"load_duration":2508844071,"prompt_eval_count":11,"prompt_eval_duration":514000000,"eval_count":10,"eval_duration":1797000000}""")));

        var result = aiServiceWithTool.streaming("1", "Hello").collect().asList().await().indefinitely();
        assertEquals(List.of("Hello", "!"), result);
    }

    /**
     * Full tool round-trip via a two-state WireMock scenario: the first request
     * yields a tool call ({@code sum(1, 1)}), the framework executes the tool and
     * sends a follow-up request containing the assistant tool-call message and the
     * tool result ("2"), and the second stub streams the final answer. The chat
     * memory for id "2" is then inspected to verify the four-message transcript:
     * user, assistant tool call, tool result, final assistant answer.
     */
    @Test
    void test_3() {
        wiremock()
                .register(
                        post(urlEqualTo("/api/chat"))
                                .inScenario("")
                                .whenScenarioStateIs(Scenario.STARTED)
                                .willSetStateTo("TOOL_CALL")
                                .withRequestBody(equalToJson("""
                                        {
                                          "model" : "llama3.2",
                                          "messages" : [ {
                                            "role" : "user",
                                            "content" : "1 + 1"
                                          }],
                                          "tools" : [ {
                                            "type" : "function",
                                            "function" : {
                                              "name" : "sum",
                                              "description" : "Execute the sum of two numbers",
                                              "parameters" : {
                                                "type" : "object",
                                                "properties" : {
                                                  "firstNumber" : {
                                                    "type" : "integer"
                                                  },
                                                  "secondNumber" : {
                                                    "type" : "integer"
                                                  }
                                                },
                                                "required" : [ "firstNumber", "secondNumber" ]
                                              }
                                            }
                                          } ],
                                          "options" : {
                                            "temperature" : 0.8,
                                            "top_k" : 40,
                                            "top_p" : 0.9
                                          },
                                          "stream" : true
                                        }
                                        """))
                                .willReturn(aResponse()
                                        .withHeader("Content-Type", "application/x-ndjson")
                                        .withBody(
                                                """
                                                        {"model":"llama3.1","created_at":"2024-11-30T16:36:02.833930413Z","message":{"role":"assistant","content":"","tool_calls":[{"function":{"name":"sum","arguments":{"firstNumber":1,"secondNumber":1}}}]},"done":false}
                                                        {"model":"llama3.1","created_at":"2024-11-30T16:36:04.368016152Z","message":{"role":"assistant","content":""},"done_reason":"stop","done":true,"total_duration":28825672145,"load_duration":29961281,"prompt_eval_count":169,"prompt_eval_duration":3906000000,"eval_count":22,"eval_duration":24887000000}""")));

        wiremock()
                .register(
                        post(urlEqualTo("/api/chat"))
                                .inScenario("")
                                .whenScenarioStateIs("TOOL_CALL")
                                .willSetStateTo("AI_RESPONSE")
                                .withRequestBody(equalToJson("""
                                        {
                                          "model" : "llama3.2",
                                          "messages" : [ {
                                            "role" : "user",
                                            "content" : "1 + 1"
                                          }, {
                                            "role" : "assistant",
                                            "tool_calls" : [ {
                                              "function" : {
                                                "name" : "sum",
                                                "arguments" : {
                                                  "firstNumber" : 1,
                                                  "secondNumber" : 1
                                                }
                                              }
                                            } ]
                                          }, {
                                            "role" : "tool",
                                            "content" : "2"
                                          } ],
                                          "tools" : [ {
                                            "type" : "function",
                                            "function" : {
                                              "name" : "sum",
                                              "description" : "Execute the sum of two numbers",
                                              "parameters" : {
                                                "type" : "object",
                                                "properties" : {
                                                  "firstNumber" : {
                                                    "type" : "integer"
                                                  },
                                                  "secondNumber" : {
                                                    "type" : "integer"
                                                  }
                                                },
                                                "required" : [ "firstNumber", "secondNumber" ]
                                              }
                                            }
                                          } ],
                                          "options" : {
                                            "temperature" : 0.8,
                                            "top_k" : 40,
                                            "top_p" : 0.9
                                          },
                                          "stream" : true
                                        }
                                        """))
                                .willReturn(aResponse()
                                        .withHeader("Content-Type", "application/x-ndjson")
                                        .withBody(
                                                """
                                                        {"model":"llama3.1","created_at":"2024-11-30T16:36:04.368016152Z","message":{"role":"assistant","content":"The result is 2"},"done_reason":"stop","done":true,"total_duration":28825672145,"load_duration":29961281,"prompt_eval_count":169,"prompt_eval_duration":3906000000,"eval_count":22,"eval_duration":24887000000}""")));

        var result = aiServiceWithTool.streaming("2", "1 + 1").collect().asList().await().indefinitely();
        assertEquals(List.of("The result is 2"), result);

        // Inspect the memory transcript for id "2": [user, tool-call, tool-result, answer].
        // Use the imported message types rather than fully-qualified names.
        var messages = memory.getMessages("2");
        assertEquals("1 + 1", ((UserMessage) messages.get(0)).singleText());
        assertEquals("The result is 2", ((AiMessage) messages.get(3)).text());

        if (messages.get(1) instanceof AiMessage aiMessage) {
            assertTrue(aiMessage.hasToolExecutionRequests());
            assertEquals("{\"firstNumber\":1,\"secondNumber\":1}", aiMessage.toolExecutionRequests().get(0).arguments());
        } else {
            fail("The second message is not of type AiMessage");
        }

        if (messages.get(2) instanceof ToolExecutionResultMessage toolResultMessage) {
            assertEquals(2, Integer.parseInt(toolResultMessage.text()));
        } else {
            fail("The third message is not of type ToolExecutionResultMessage");
        }
    }
}

model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaChatLanguageModel.java

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,13 @@
66
import static io.quarkiverse.langchain4j.ollama.MessageMapper.toTools;
77

88
import java.time.Duration;
9-
import java.util.ArrayList;
109
import java.util.Collections;
1110
import java.util.List;
1211
import java.util.Map;
1312
import java.util.concurrent.ConcurrentHashMap;
1413

1514
import org.jboss.logging.Logger;
1615

17-
import com.fasterxml.jackson.core.JsonProcessingException;
18-
1916
import dev.langchain4j.agent.tool.ToolExecutionRequest;
2017
import dev.langchain4j.agent.tool.ToolSpecification;
2118
import dev.langchain4j.data.message.AiMessage;
@@ -29,7 +26,6 @@
2926
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
3027
import dev.langchain4j.model.output.Response;
3128
import dev.langchain4j.model.output.TokenUsage;
32-
import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory;
3329

3430
public class OllamaChatLanguageModel implements ChatLanguageModel {
3531

@@ -137,25 +133,10 @@ private static Response<AiMessage> toResponse(ChatResponse response) {
137133
AiMessage.from(response.message().content()),
138134
new TokenUsage(response.promptEvalCount(), response.evalCount()));
139135
} else {
140-
try {
141-
List<ToolExecutionRequest> toolExecutionRequests = new ArrayList<>(toolCalls.size());
142-
for (ToolCall toolCall : toolCalls) {
143-
ToolCall.FunctionCall functionCall = toolCall.function();
144-
145-
// TODO: we need to update LangChain4j to make ToolExecutionRequest use a map instead of a String
146-
String argumentsStr = QuarkusJsonCodecFactory.ObjectMapperHolder.MAPPER
147-
.writeValueAsString(functionCall.arguments());
148-
toolExecutionRequests.add(ToolExecutionRequest.builder()
149-
.name(functionCall.name())
150-
.arguments(argumentsStr)
151-
.build());
152-
}
153-
154-
result = Response.from(aiMessage(toolExecutionRequests),
155-
new TokenUsage(response.promptEvalCount(), response.evalCount()));
156-
} catch (JsonProcessingException e) {
157-
throw new RuntimeException("Unable to parse tool call response", e);
158-
}
136+
List<ToolExecutionRequest> toolExecutionRequests = toolCalls.stream().map(ToolCall::toToolExecutionRequest)
137+
.toList();
138+
result = Response.from(aiMessage(toolExecutionRequests),
139+
new TokenUsage(response.promptEvalCount(), response.evalCount()));
159140
}
160141
return result;
161142
}

0 commit comments

Comments
 (0)