Skip to content

Commit 69979be

Browse files
committed
Make sure to avoid generating the schema if the quarkus.langchain4j.response-schema property is set to false
1 parent 37bd953 commit 69979be

File tree

3 files changed

+103
-15
lines changed

3 files changed

+103
-15
lines changed

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ private static Object doImplementGenerateImage(AiServiceMethodCreateInfo methodC
464464
audit.initialMessages(systemMessage, userMessage);
465465
}
466466

467-
//TODO: does it make sense to use the retrievalAugmentor here? What good would be for us telling the LLM to use this or that information to create an image?
467+
// TODO: does it make sense to use the retrievalAugmentor here? What good would be for us telling the LLM to use this or that information to create an image?
468468
AugmentationResult augmentationResult = null;
469469

470470
// TODO: we can only support input guardrails for now as it is tied to AiMessage
@@ -644,14 +644,16 @@ private static UserMessage prepareUserMessage(AiServiceContext context, AiServic
644644
.formatted(ResponseSchemaUtil.placeholder(), createInfo.getInterfaceName()));
645645
}
646646

647-
// No response schema placeholder found in the @SystemMessage and @UserMessage, concat it to the UserMessage.
648-
if (!createInfo.getResponseSchemaInfo().isInSystemMessage() && !hasResponseSchema && !supportsJsonSchema) {
649-
templateText = templateText.concat(ResponseSchemaUtil.placeholder());
647+
if (createInfo.getResponseSchemaInfo().enabled()) {
648+
// No response schema placeholder found in the @SystemMessage and @UserMessage, concat it to the UserMessage.
649+
if (!createInfo.getResponseSchemaInfo().isInSystemMessage() && !hasResponseSchema && !supportsJsonSchema) {
650+
templateText = templateText.concat(ResponseSchemaUtil.placeholder());
651+
}
652+
653+
templateVariables.put(ResponseSchemaUtil.templateParam(),
654+
createInfo.getResponseSchemaInfo().outputFormatInstructions());
650655
}
651656

652-
// we do not need to apply the instructions as they have already been added to the template text at build time
653-
templateVariables.put(ResponseSchemaUtil.templateParam(),
654-
createInfo.getResponseSchemaInfo().outputFormatInstructions());
655657
Prompt prompt = PromptTemplate.from(templateText).apply(templateVariables);
656658
return createUserMessage(userName, imageContent, prompt.text());
657659

@@ -667,7 +669,8 @@ private static UserMessage prepareUserMessage(AiServiceContext context, AiServic
667669

668670
String text = toString(argValue);
669671
return createUserMessage(userName, imageContent,
670-
text.concat(supportsJsonSchema ? "" : createInfo.getResponseSchemaInfo().outputFormatInstructions()));
672+
text.concat(supportsJsonSchema || !createInfo.getResponseSchemaInfo().enabled() ? ""
673+
: createInfo.getResponseSchemaInfo().outputFormatInstructions()));
671674
} else {
672675
throw new IllegalStateException("Unable to construct UserMessage for class '" + context.aiServiceClass.getName()
673676
+ "'. Please contact the maintainers");
Lines changed: 83 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
package io.quarkiverse.langchain4j.watsonx.deployment;
22

3+
import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
4+
import static com.github.tomakehurst.wiremock.client.WireMock.matchingJsonPath;
35
import static org.junit.jupiter.api.Assertions.assertEquals;
46
import static org.junit.jupiter.api.Assertions.assertThrows;
57

8+
import java.util.Date;
9+
610
import jakarta.inject.Inject;
711
import jakarta.inject.Singleton;
812

@@ -11,38 +15,112 @@
1115
import org.junit.jupiter.api.Test;
1216
import org.junit.jupiter.api.extension.RegisterExtension;
1317

18+
import dev.langchain4j.service.SystemMessage;
1419
import dev.langchain4j.service.UserMessage;
1520
import dev.langchain4j.service.V;
1621
import io.quarkiverse.langchain4j.RegisterAiService;
1722
import io.quarkus.test.QuarkusUnitTest;
1823

19-
public class ResponseSchemaOffTest {
24+
public class ResponseSchemaOffTest extends WireMockAbstract {
2025

2126
@RegisterExtension
2227
static QuarkusUnitTest unitTest = new QuarkusUnitTest()
2328
.overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.base-url", WireMockUtil.URL_WATSONX_SERVER)
2429
.overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.iam.base-url", WireMockUtil.URL_IAM_SERVER)
2530
.overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.api-key", WireMockUtil.API_KEY)
2631
.overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.project-id", WireMockUtil.PROJECT_ID)
32+
.overrideConfigKey("quarkus.langchain4j.watsonx.mode", "generation")
2733
.overrideConfigKey("quarkus.langchain4j.response-schema", "false")
2834
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class));
2935

30-
@RegisterAiService
36+
@Override
37+
void handlerBeforeEach() {
38+
mockServers.mockIAMBuilder(200)
39+
.grantType(langchain4jWatsonConfig.defaultConfig().iam().grantType())
40+
.response(WireMockUtil.BEARER_TOKEN, new Date())
41+
.build();
42+
}
43+
44+
@RegisterAiService(chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class)
3145
@Singleton
3246
interface OnMethodAIService {
33-
String poem(@UserMessage String message, @V("topic") String topic);
47+
String poem1(@UserMessage String message, @V("topic") String topic);
48+
49+
Poem poem2(@UserMessage String message);
50+
51+
@UserMessage("{message}")
52+
Poem poem3(String message);
53+
54+
@SystemMessage("SystemMessage")
55+
@UserMessage("{message}")
56+
Poem poem4(String message);
57+
58+
public record Poem(String text) {
59+
};
3460
}
3561

3662
@Inject
3763
OnMethodAIService onMethodAIService;
3864

65+
static String POEM_RESPONSE = """
66+
{
67+
"model_id": "mistralai/mistral-large",
68+
"created_at": "2024-01-21T17:06:14.052Z",
69+
"results": [
70+
{
71+
"generated_text": "{ \\\"text\\\": \\\"Poem\\\" }",
72+
"generated_token_count": 5,
73+
"input_token_count": 50,
74+
"stop_reason": "eos_token",
75+
"seed": 2123876088
76+
}
77+
]
78+
}
79+
""";
80+
3981
@Test
40-
void on_method_ai_service() throws Exception {
82+
void test_poem_1() throws Exception {
4183
var ex = assertThrows(RuntimeException.class,
42-
() -> onMethodAIService.poem("{response_schema} Generate a poem about {topic}", "dog"));
84+
() -> onMethodAIService.poem1("{response_schema} Generate a poem about {topic}", "dog"));
4385
assertEquals(
4486
"The {response_schema} placeholder cannot be used if the property quarkus.langchain4j.response-schema is set to false. Found in: io.quarkiverse.langchain4j.watsonx.deployment.ResponseSchemaOffTest$OnMethodAIService",
4587
ex.getMessage());
88+
89+
mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200)
90+
.body(matchingJsonPath("$.input", equalTo("Generate a poem about dog")))
91+
.response(WireMockUtil.RESPONSE_WATSONX_GENERATION_API)
92+
.build();
93+
94+
assertEquals("AI Response", onMethodAIService.poem1("Generate a poem about {topic}", "dog"));
4695
}
4796

97+
@Test
98+
void test_poem_2() {
99+
mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200)
100+
.body(matchingJsonPath("$.input", equalTo("Generate a poem about dog")))
101+
.response(POEM_RESPONSE)
102+
.build();
103+
104+
assertEquals("Poem", onMethodAIService.poem2("Generate a poem about dog").text);
105+
}
106+
107+
@Test
108+
void test_poem_3() {
109+
mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200)
110+
.body(matchingJsonPath("$.input", equalTo("Generate a poem about dog")))
111+
.response(POEM_RESPONSE)
112+
.build();
113+
114+
assertEquals("Poem", onMethodAIService.poem3("Generate a poem about dog").text);
115+
}
116+
117+
@Test
118+
void test_poem_4() {
119+
mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_GENERATION_API, 200)
120+
.body(matchingJsonPath("$.input", equalTo("SystemMessage\nGenerate a poem about dog")))
121+
.response(POEM_RESPONSE)
122+
.build();
123+
124+
assertEquals("Poem", onMethodAIService.poem4("Generate a poem about dog").text);
125+
}
48126
}

model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/WireMockUtil.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
import com.github.tomakehurst.wiremock.WireMockServer;
1717
import com.github.tomakehurst.wiremock.client.MappingBuilder;
18+
import com.github.tomakehurst.wiremock.matching.StringValuePattern;
19+
import com.github.tomakehurst.wiremock.stubbing.StubMapping;
1820

1921
import dev.langchain4j.data.message.AiMessage;
2022
import dev.langchain4j.model.StreamingResponseHandler;
@@ -282,6 +284,11 @@ public WatsonxBuilder body(String body) {
282284
return this;
283285
}
284286

287+
public WatsonxBuilder body(StringValuePattern stringValuePattern) {
288+
builder.withRequestBody(stringValuePattern);
289+
return this;
290+
}
291+
285292
public WatsonxBuilder token(String token) {
286293
this.token = token;
287294
return this;
@@ -297,8 +304,8 @@ public WatsonxBuilder response(String response) {
297304
return this;
298305
}
299306

300-
public void build() {
301-
watsonServer.stubFor(
307+
public StubMapping build() {
308+
return watsonServer.stubFor(
302309
builder
303310
.withHeader("Authorization", equalTo("Bearer %s".formatted(token)))
304311
.willReturn(aResponse()

0 commit comments

Comments
 (0)