Skip to content

Commit 4751df3

Browse files
committed
Add structured output to OpenAI
1 parent fdcace4 commit 4751df3

File tree

3 files changed

+90
-0
lines changed

3 files changed

+90
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
package org.acme.examples.aiservices;
2+
3+
import static java.time.Month.JULY;
4+
import static org.assertj.core.api.Assertions.assertThat;
5+
import static org.assertj.core.api.InstanceOfAssertFactories.map;
6+
7+
import java.io.IOException;
8+
import java.time.LocalDate;
9+
import java.util.Map;
10+
11+
import jakarta.enterprise.context.ApplicationScoped;
12+
import jakarta.inject.Inject;
13+
14+
import org.jboss.shrinkwrap.api.ShrinkWrap;
15+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
16+
import org.junit.jupiter.api.Test;
17+
import org.junit.jupiter.api.extension.RegisterExtension;
18+
19+
import dev.langchain4j.service.UserMessage;
20+
import io.quarkiverse.langchain4j.RegisterAiService;
21+
import io.quarkiverse.langchain4j.openai.testing.internal.OpenAiBaseTest;
22+
import io.quarkiverse.langchain4j.testing.internal.WiremockAware;
23+
import io.quarkus.test.QuarkusUnitTest;
24+
25+
public class StructuredOutputResponseTest extends OpenAiBaseTest {
26+
27+
@RegisterExtension
28+
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
29+
.setArchiveProducer(
30+
() -> ShrinkWrap.create(JavaArchive.class))
31+
.overrideRuntimeConfigKey("quarkus.langchain4j.openai.api-key", "whatever")
32+
.overrideRuntimeConfigKey("quarkus.langchain4j.openai.chat-model.response-format", "json_schema")
33+
.overrideRuntimeConfigKey("quarkus.langchain4j.openai.chat-model.strict-json-schema", "true")
34+
.overrideRuntimeConfigKey("quarkus.langchain4j.openai.base-url",
35+
WiremockAware.wiremockUrlForConfig("/v1"));
36+
37+
record Person(String firstName, String lastName, LocalDate birthDate) {
38+
}
39+
40+
@RegisterAiService
41+
@ApplicationScoped
42+
interface PersonExtractor {
43+
44+
@UserMessage("Extract information about a person from {{it}}")
45+
Person extractPersonFrom(String text);
46+
}
47+
48+
@Inject
49+
PersonExtractor personExtractor;
50+
51+
@Test
52+
public void testPojo() throws IOException {
53+
setChatCompletionMessageContent(
54+
// this is supposed to be a string inside a json string hence all the escaping...
55+
"{\\n\\\"firstName\\\": \\\"John\\\",\\n\\\"lastName\\\": \\\"Doe\\\",\\n\\\"birthDate\\\": \\\"1968-07-04\\\"\\n}");
56+
57+
String text = "In 1968, amidst the fading echoes of Independence Day, "
58+
+ "a child named John arrived under the calm evening sky. "
59+
+ "This newborn, bearing the surname Doe, marked the start of a new journey.";
60+
61+
Person result = personExtractor.extractPersonFrom(text);
62+
63+
assertThat(result.firstName).isEqualTo("John");
64+
assertThat(result.lastName).isEqualTo("Doe");
65+
assertThat(result.birthDate).isEqualTo(LocalDate.of(1968, JULY, 4));
66+
67+
Map<String, Object> requestAsMap = getRequestAsMap();
68+
assertSingleRequestMessage(requestAsMap,
69+
"Extract information about a person from In 1968, amidst the fading echoes of Independence Day, " +
70+
"a child named John arrived under the calm evening sky. This newborn, bearing the surname Doe, " +
71+
"marked the start of a new journey.");
72+
assertThat(requestAsMap).hasEntrySatisfying("response_format", (v) -> {
73+
assertThat(v).asInstanceOf(map(String.class, Object.class)).satisfies(responseFormatMap -> {
74+
assertThat(responseFormatMap).containsEntry("type", "json_schema");
75+
assertThat(responseFormatMap).extracting("json_schema").satisfies(js -> {
76+
assertThat(js).asInstanceOf(map(String.class, Object.class)).satisfies(jsonSchemaMap -> {
77+
assertThat(jsonSchemaMap).containsEntry("name", "Person").containsKey("schema");
78+
});
79+
});
80+
});
81+
});
82+
83+
}
84+
}

model-providers/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/OpenAiRecorder.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ public Function<SyntheticCreationalContext<ChatLanguageModel>, ChatLanguageModel
8585
.presencePenalty(chatModelConfig.presencePenalty())
8686
.frequencyPenalty(chatModelConfig.frequencyPenalty())
8787
.responseFormat(chatModelConfig.responseFormat().orElse(null))
88+
.strictJsonSchema(chatModelConfig.strictJsonSchema().orElse(null))
8889
.stop(chatModelConfig.stop().orElse(null));
8990

9091
openAiConfig.organizationId().ifPresent(builder::organizationId);

model-providers/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/ChatModelConfig.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,11 @@ public interface ChatModelConfig {
7676
*/
7777
Optional<String> responseFormat();
7878

79+
/**
80+
* Whether responses follow JSON Schema for Structured Outputs
81+
*/
82+
Optional<Boolean> strictJsonSchema();
83+
7984
/**
8085
* The list of stop words to use.
8186
*

0 commit comments

Comments
 (0)