Skip to content

Commit d2a0f7e

Browse files
authored
Merge pull request #1150 from andreadimaio/main
Add support for structured output in Ollama
2 parents e7b1bf4 + cead717 commit d2a0f7e

File tree

8 files changed

+382
-25
lines changed

8 files changed

+382
-25
lines changed
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
package io.quarkiverse.langchain4j.ollama.deployment;
2+
3+
import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
4+
import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson;
5+
import static com.github.tomakehurst.wiremock.client.WireMock.post;
6+
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
7+
import static org.junit.jupiter.api.Assertions.assertEquals;
8+
9+
import jakarta.inject.Inject;
10+
import jakarta.inject.Singleton;
11+
12+
import org.jboss.shrinkwrap.api.ShrinkWrap;
13+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
14+
import org.junit.jupiter.api.Test;
15+
import org.junit.jupiter.api.extension.RegisterExtension;
16+
17+
import dev.langchain4j.model.output.structured.Description;
18+
import dev.langchain4j.service.UserName;
19+
import io.quarkiverse.langchain4j.RegisterAiService;
20+
import io.quarkiverse.langchain4j.testing.internal.WiremockAware;
21+
import io.quarkus.test.QuarkusUnitTest;
22+
23+
public class OllamaJsonOutputTest extends WiremockAware {
24+
25+
@RegisterExtension
26+
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
27+
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class))
28+
.overrideConfigKey("quarkus.langchain4j.ollama.base-url", WiremockAware.wiremockUrlForConfig())
29+
.overrideConfigKey("quarkus.langchain4j.devservices.enabled", "false")
30+
.overrideConfigKey("quarkus.langchain4j.ollama.chat-model.format", "json");
31+
32+
@Description("A person")
33+
public record Person(
34+
@Description("The firstname") String firstname,
35+
@Description("The lastname") String lastname) {
36+
}
37+
38+
@Singleton
39+
@RegisterAiService
40+
interface AiService {
41+
Person extractPerson(@UserName String text);
42+
}
43+
44+
@Inject
45+
AiService aiService;
46+
47+
@Test
48+
void extract() {
49+
wiremock().register(
50+
post(urlEqualTo("/api/chat"))
51+
.withRequestBody(equalToJson(
52+
"""
53+
{
54+
"model": "llama3.2",
55+
"messages": [
56+
{
57+
"role": "user",
58+
"content": "Tell me something about Alan Wake\\nYou must answer strictly in the following JSON format: {\\n\\\"firstname\\\": (The firstname; type: string),\\n\\\"lastname\\\": (The lastname; type: string)\\n}"
59+
}
60+
],
61+
"stream": false,
62+
"options": {
63+
"temperature": 0.8,
64+
"top_k": 40,
65+
"top_p": 0.9
66+
},
67+
"tools": [],
68+
"format": "json"
69+
}"""))
70+
.willReturn(aResponse()
71+
.withHeader("Content-Type", "application/json")
72+
.withBody("""
73+
{
74+
"model": "llama3.2",
75+
"created_at": "2024-12-11T15:21:23.422542932Z",
76+
"message": {
77+
"role": "assistant",
78+
"content": "{\\\"firstname\\\":\\\"Alan\\\",\\\"lastname\\\":\\\"Wake\\\"}"
79+
},
80+
"done_reason": "stop",
81+
"done": true,
82+
"total_duration": 8125806496,
83+
"load_duration": 4223887064,
84+
"prompt_eval_count": 31,
85+
"prompt_eval_duration": 1331000000,
86+
"eval_count": 18,
87+
"eval_duration": 2569000000
88+
}""")));
89+
90+
var result = aiService.extractPerson("Tell me something about Alan Wake");
91+
assertEquals(new Person("Alan", "Wake"), result);
92+
}
93+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
package io.quarkiverse.langchain4j.ollama.deployment;
2+
3+
import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
4+
import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson;
5+
import static com.github.tomakehurst.wiremock.client.WireMock.post;
6+
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
7+
import static org.junit.jupiter.api.Assertions.assertEquals;
8+
9+
import jakarta.inject.Inject;
10+
import jakarta.inject.Singleton;
11+
12+
import org.jboss.shrinkwrap.api.ShrinkWrap;
13+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
14+
import org.junit.jupiter.api.Test;
15+
import org.junit.jupiter.api.extension.RegisterExtension;
16+
17+
import dev.langchain4j.model.output.structured.Description;
18+
import dev.langchain4j.service.UserName;
19+
import io.quarkiverse.langchain4j.RegisterAiService;
20+
import io.quarkiverse.langchain4j.testing.internal.WiremockAware;
21+
import io.quarkus.test.QuarkusUnitTest;
22+
23+
public class OllamaStructuredOutputTest extends WiremockAware {
24+
25+
@RegisterExtension
26+
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
27+
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class))
28+
.overrideConfigKey("quarkus.langchain4j.ollama.base-url", WiremockAware.wiremockUrlForConfig())
29+
.overrideConfigKey("quarkus.langchain4j.devservices.enabled", "false");
30+
31+
@Description("A person")
32+
public record Person(
33+
@Description("The firstname") String firstname,
34+
@Description("The lastname") String lastname) {
35+
}
36+
37+
@Singleton
38+
@RegisterAiService
39+
interface AiService {
40+
Person extractPerson(@UserName String text);
41+
}
42+
43+
@Inject
44+
AiService aiService;
45+
46+
@Test
47+
void extract() {
48+
wiremock().register(
49+
post(urlEqualTo("/api/chat"))
50+
.withRequestBody(equalToJson("""
51+
{
52+
"model": "llama3.2",
53+
"messages": [{"role": "user", "content": "Tell me something about Alan Wake"}],
54+
"stream": false,
55+
"options" : {
56+
"temperature" : 0.8,
57+
"top_k" : 40,
58+
"top_p" : 0.9
59+
},
60+
"format": {
61+
"type": "object",
62+
"description": "A person",
63+
"properties": {
64+
"firstname": {
65+
"description": "The firstname",
66+
"type": "string"
67+
},
68+
"lastname": {
69+
"description": "The lastname",
70+
"type": "string"
71+
}
72+
},
73+
"required": [
74+
"firstname",
75+
"lastname"
76+
]
77+
}
78+
}
79+
"""))
80+
.willReturn(aResponse()
81+
.withHeader("Content-Type", "application/json")
82+
.withBody("""
83+
{
84+
"model": "llama3.2",
85+
"created_at": "2024-12-11T15:21:23.422542932Z",
86+
"message": {
87+
"role": "assistant",
88+
"content": "{\\\"firstname\\\":\\\"Alan\\\",\\\"lastname\\\":\\\"Wake\\\"}"
89+
},
90+
"done_reason": "stop",
91+
"done": true,
92+
"total_duration": 8125806496,
93+
"load_duration": 4223887064,
94+
"prompt_eval_count": 31,
95+
"prompt_eval_duration": 1331000000,
96+
"eval_count": 18,
97+
"eval_duration": 2569000000
98+
}""")));
99+
100+
var result = aiService.extractPerson("Tell me something about Alan Wake");
101+
assertEquals(new Person("Alan", "Wake"), result);
102+
}
103+
}
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
package io.quarkiverse.langchain4j.ollama.deployment;
2+
3+
import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
4+
import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson;
5+
import static com.github.tomakehurst.wiremock.client.WireMock.post;
6+
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
7+
import static org.junit.jupiter.api.Assertions.assertEquals;
8+
9+
import jakarta.inject.Inject;
10+
import jakarta.inject.Singleton;
11+
12+
import org.jboss.shrinkwrap.api.ShrinkWrap;
13+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
14+
import org.junit.jupiter.api.Test;
15+
import org.junit.jupiter.api.extension.RegisterExtension;
16+
17+
import dev.langchain4j.service.UserName;
18+
import io.quarkiverse.langchain4j.RegisterAiService;
19+
import io.quarkiverse.langchain4j.testing.internal.WiremockAware;
20+
import io.quarkus.test.QuarkusUnitTest;
21+
22+
public class OllamaTextOutputTest extends WiremockAware {
23+
24+
@RegisterExtension
25+
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
26+
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class))
27+
.overrideConfigKey("quarkus.langchain4j.ollama.base-url", WiremockAware.wiremockUrlForConfig())
28+
.overrideConfigKey("quarkus.langchain4j.devservices.enabled", "false");
29+
30+
@Singleton
31+
@RegisterAiService
32+
interface AiService {
33+
String question(@UserName String text);
34+
}
35+
36+
@Inject
37+
AiService aiService;
38+
39+
@Test
40+
void extract() {
41+
wiremock().register(
42+
post(urlEqualTo("/api/chat"))
43+
.withRequestBody(equalToJson(
44+
"""
45+
{
46+
"model": "llama3.2",
47+
"messages": [
48+
{
49+
"role": "user",
50+
"content": "Tell me something about Alan Wake"
51+
}
52+
],
53+
"stream": false,
54+
"options": {
55+
"temperature": 0.8,
56+
"top_k": 40,
57+
"top_p": 0.9
58+
},
59+
"tools": []
60+
}"""))
61+
.willReturn(aResponse()
62+
.withHeader("Content-Type", "application/json")
63+
.withBody("""
64+
{
65+
"model": "llama3.2",
66+
"created_at": "2024-12-11T15:21:23.422542932Z",
67+
"message": {
68+
"role": "assistant",
69+
"content": "He is a writer!"
70+
},
71+
"done_reason": "stop",
72+
"done": true,
73+
"total_duration": 8125806496,
74+
"load_duration": 4223887064,
75+
"prompt_eval_count": 31,
76+
"prompt_eval_duration": 1331000000,
77+
"eval_count": 18,
78+
"eval_duration": 2569000000
79+
}""")));
80+
81+
var result = aiService.question("Tell me something about Alan Wake");
82+
assertEquals("He is a writer!", result);
83+
}
84+
}

model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/ChatRequest.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,14 @@
22

33
import java.util.List;
44

5-
public record ChatRequest(String model, List<Message> messages, List<Tool> tools, Options options, String format,
5+
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
6+
7+
public record ChatRequest(
8+
String model,
9+
List<Message> messages,
10+
List<Tool> tools,
11+
Options options,
12+
@JsonSerialize(using = FormatJsonSerializer.class) String format,
613
Boolean stream) {
714

815
public static Builder builder() {
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package io.quarkiverse.langchain4j.ollama;
2+
3+
import java.io.IOException;
4+
5+
import com.fasterxml.jackson.core.JsonGenerator;
6+
import com.fasterxml.jackson.databind.JsonSerializer;
7+
import com.fasterxml.jackson.databind.SerializerProvider;
8+
9+
public class FormatJsonSerializer extends JsonSerializer<String> {
10+
11+
@Override
12+
public void serialize(String value, JsonGenerator gen, SerializerProvider serializers) throws IOException {
13+
if (value == null)
14+
return;
15+
else if (value.startsWith("{") && value.endsWith("}"))
16+
gen.writeRawValue(value);
17+
else
18+
gen.writeString(value);
19+
}
20+
}

0 commit comments

Comments
 (0)