Skip to content

Commit 78aec0b

Browse files
andreadimaiogeoand
authored andcommitted
Add tests
1 parent b91f56d commit 78aec0b

File tree

6 files changed

+713
-0
lines changed

6 files changed

+713
-0
lines changed
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
package com.ibm.generativeai.bam.deployment;
2+
3+
import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
4+
import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
5+
import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson;
6+
import static com.github.tomakehurst.wiremock.client.WireMock.post;
7+
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
8+
import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.options;
9+
import static org.junit.jupiter.api.Assertions.assertEquals;
10+
11+
import java.util.List;
12+
13+
import jakarta.inject.Inject;
14+
import jakarta.inject.Singleton;
15+
import jakarta.ws.rs.core.MediaType;
16+
17+
import org.jboss.shrinkwrap.api.ShrinkWrap;
18+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
19+
import org.junit.jupiter.api.AfterAll;
20+
import org.junit.jupiter.api.BeforeAll;
21+
import org.junit.jupiter.api.Test;
22+
import org.junit.jupiter.api.extension.RegisterExtension;
23+
24+
import com.fasterxml.jackson.databind.ObjectMapper;
25+
import com.github.tomakehurst.wiremock.WireMockServer;
26+
27+
import dev.langchain4j.service.SystemMessage;
28+
import dev.langchain4j.service.UserMessage;
29+
import io.quarkiverse.langchain4j.RegisterAiService;
30+
import io.quarkiverse.langchain4j.bam.BamRestApi;
31+
import io.quarkiverse.langchain4j.bam.Message;
32+
import io.quarkiverse.langchain4j.bam.Parameters;
33+
import io.quarkiverse.langchain4j.bam.TextGenerationRequest;
34+
import io.quarkiverse.langchain4j.bam.runtime.config.Langchain4jBamConfig;
35+
import io.quarkus.test.QuarkusUnitTest;
36+
37+
public class AiServiceTest {
38+
39+
static WireMockServer wireMockServer;
40+
static ObjectMapper mapper;
41+
42+
@RegisterExtension
43+
static QuarkusUnitTest unitTest = new QuarkusUnitTest()
44+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.base-url", Util.URL)
45+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.api-key", Util.API_KEY)
46+
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(Util.class));
47+
48+
@BeforeAll
49+
static void beforeAll() {
50+
wireMockServer = new WireMockServer(options().port(Util.PORT));
51+
wireMockServer.start();
52+
mapper = BamRestApi.objectMapper(new ObjectMapper());
53+
}
54+
55+
@AfterAll
56+
static void afterAll() {
57+
wireMockServer.stop();
58+
}
59+
60+
@RegisterAiService
61+
@Singleton
62+
interface NewAIService {
63+
64+
@SystemMessage("This is a systemMessage")
65+
@UserMessage("This is a userMessage {text}")
66+
String chat(String text);
67+
}
68+
69+
@Inject
70+
NewAIService service;
71+
72+
@Inject
73+
Langchain4jBamConfig config;
74+
75+
@Test
76+
void chat() throws Exception {
77+
78+
var modelId = config.chatModel().modelId();
79+
80+
var parameters = Parameters.builder()
81+
.decodingMethod(config.chatModel().decodingMethod())
82+
.temperature(config.chatModel().temperature())
83+
.minNewTokens(config.chatModel().minNewTokens())
84+
.maxNewTokens(config.chatModel().maxNewTokens())
85+
.build();
86+
87+
List<Message> messages = List.of(
88+
new Message("system", "This is a systemMessage"),
89+
new Message("user", "This is a userMessage Hello"));
90+
91+
var body = new TextGenerationRequest(modelId, messages, parameters);
92+
93+
wireMockServer.stubFor(
94+
post(urlEqualTo(Util.URL_CHAT_API.formatted(config.version())))
95+
.withHeader("Authorization", equalTo("Bearer %s".formatted(Util.API_KEY)))
96+
.withRequestBody(equalToJson(mapper.writeValueAsString(body)))
97+
.willReturn(
98+
aResponse()
99+
.withHeader("Content-Type", MediaType.APPLICATION_JSON)
100+
.withStatus(200)
101+
.withBody("""
102+
{
103+
"results": [
104+
{
105+
"generated_token_count": 20,
106+
"input_token_count": 146,
107+
"stop_reason": "max_tokens",
108+
"seed": 40268626,
109+
"generated_text": "AI Response"
110+
}
111+
]
112+
}
113+
""")));
114+
115+
assertEquals("AI Response", service.chat("Hello"));
116+
}
117+
}
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
package com.ibm.generativeai.bam.deployment;
2+
3+
import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
4+
import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
5+
import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson;
6+
import static com.github.tomakehurst.wiremock.client.WireMock.post;
7+
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
8+
import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.options;
9+
import static org.junit.jupiter.api.Assertions.assertEquals;
10+
11+
import java.time.Duration;
12+
import java.util.List;
13+
14+
import jakarta.inject.Inject;
15+
import jakarta.ws.rs.core.MediaType;
16+
17+
import org.jboss.shrinkwrap.api.ShrinkWrap;
18+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
19+
import org.junit.jupiter.api.AfterAll;
20+
import org.junit.jupiter.api.BeforeAll;
21+
import org.junit.jupiter.api.Test;
22+
import org.junit.jupiter.api.extension.RegisterExtension;
23+
24+
import com.fasterxml.jackson.databind.ObjectMapper;
25+
import com.github.tomakehurst.wiremock.WireMockServer;
26+
27+
import io.quarkiverse.langchain4j.bam.BamChatModel;
28+
import io.quarkiverse.langchain4j.bam.BamRestApi;
29+
import io.quarkiverse.langchain4j.bam.Message;
30+
import io.quarkiverse.langchain4j.bam.Parameters;
31+
import io.quarkiverse.langchain4j.bam.TextGenerationRequest;
32+
import io.quarkiverse.langchain4j.bam.runtime.BamRecorder;
33+
import io.quarkiverse.langchain4j.bam.runtime.config.Langchain4jBamConfig;
34+
import io.quarkus.test.QuarkusUnitTest;
35+
36+
public class AllPropertiesTest {
37+
38+
static WireMockServer wireMockServer;
39+
static ObjectMapper mapper;
40+
41+
@RegisterExtension
42+
static QuarkusUnitTest unitTest = new QuarkusUnitTest()
43+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.base-url", Util.URL)
44+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.api-key", Util.API_KEY)
45+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.timeout", "60s")
46+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.log-requests", "true")
47+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.log-responses", "true")
48+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.timeout", "60s")
49+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.version", "aaaa-mm-dd")
50+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.model-id", "my_super_model")
51+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.decoding-method", "greedy")
52+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.include-stop-sequence", "true")
53+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.max-new-tokens", "200")
54+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.min-new-tokens", "10")
55+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.random-seed", "2")
56+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.stop-sequences", "\n,\n\n")
57+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.temperature", "1.5")
58+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.time-limit", "1500")
59+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.top-k", "90")
60+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.top-p", "0.5")
61+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.typical-p", "0.5")
62+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.repetition-penalty", "2.0")
63+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.truncate-input-tokens", "0")
64+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.beam-width", "2")
65+
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(Util.class));
66+
67+
@Inject
68+
Langchain4jBamConfig config;
69+
70+
@BeforeAll
71+
static void beforeAll() {
72+
wireMockServer = new WireMockServer(options().port(Util.PORT));
73+
wireMockServer.start();
74+
mapper = BamRestApi.objectMapper(new ObjectMapper());
75+
}
76+
77+
@AfterAll
78+
static void afterAll() {
79+
wireMockServer.stop();
80+
}
81+
82+
@Test
83+
void generate() throws Exception {
84+
85+
assertEquals(Util.URL, config.baseUrl().get().toString());
86+
assertEquals(Util.API_KEY, config.apiKey());
87+
assertEquals(Duration.ofSeconds(60), config.timeout());
88+
assertEquals(true, config.logRequests());
89+
assertEquals(true, config.logResponses());
90+
assertEquals("aaaa-mm-dd", config.version());
91+
assertEquals("my_super_model", config.chatModel().modelId());
92+
assertEquals("greedy", config.chatModel().decodingMethod());
93+
assertEquals(true, config.chatModel().includeStopSequence().get());
94+
assertEquals(200, config.chatModel().maxNewTokens());
95+
assertEquals(10, config.chatModel().minNewTokens());
96+
assertEquals(2, config.chatModel().randomSeed().get());
97+
assertEquals(List.of("\n", "\n\n"), config.chatModel().stopSequences().get());
98+
assertEquals(1.5, config.chatModel().temperature());
99+
assertEquals(1500, config.chatModel().timeLimit().get());
100+
assertEquals(90, config.chatModel().topK().get());
101+
assertEquals(0.5, config.chatModel().topP().get());
102+
assertEquals(0.5, config.chatModel().typicalP().get());
103+
assertEquals(2.0, config.chatModel().repetitionPenalty().get());
104+
assertEquals(0, config.chatModel().truncateInputTokens().get());
105+
assertEquals(2, config.chatModel().beamWidth().get());
106+
107+
var modelId = config.chatModel().modelId();
108+
109+
var parameters = Parameters.builder()
110+
.minNewTokens(10)
111+
.maxNewTokens(200)
112+
.decodingMethod("greedy")
113+
.includeStopSequence(true)
114+
.randomSeed(2)
115+
.stopSequences(List.of("\n", "\n\n"))
116+
.temperature(1.5)
117+
.timeLimit(1500)
118+
.topK(90)
119+
.topP(0.5)
120+
.typicalP(0.5)
121+
.repetitionPenalty(2.0)
122+
.truncateInputTokens(0)
123+
.beamWidth(2)
124+
.build();
125+
126+
List<Message> messages = List.of(
127+
new Message("user", "Hello how are you?"));
128+
129+
var body = new TextGenerationRequest(modelId, messages, parameters);
130+
131+
wireMockServer.stubFor(
132+
post(urlEqualTo(Util.URL_CHAT_API.formatted(config.version())))
133+
.withHeader("Authorization", equalTo("Bearer %s".formatted(Util.API_KEY)))
134+
.withRequestBody(equalToJson(mapper.writeValueAsString(body)))
135+
.willReturn(
136+
aResponse()
137+
.withHeader("Content-Type", MediaType.APPLICATION_JSON)
138+
.withBody(
139+
"""
140+
{
141+
"id": "05a245ad-1da7-4b9d-9807-ae1733177c1d",
142+
"model_id": "meta-llama/llama-2-70b-chat",
143+
"created_at": "2023-09-01T09:28:29.378Z",
144+
"results": [
145+
{
146+
"generated_token_count": 20,
147+
"input_token_count": 146,
148+
"stop_reason": "max_tokens",
149+
"seed": 40268626,
150+
"generated_text": "Hello! I'm doing well, thanks for asking. I'm here to assist you"
151+
}
152+
],
153+
"conversation_id": "cd3a9bca-b88e-41e4-9d62-bab33098fe39"
154+
}
155+
""")));
156+
157+
BamRecorder recorder = new BamRecorder();
158+
BamChatModel model = (BamChatModel) recorder.chatModel(config).get();
159+
assertEquals("Hello! I'm doing well, thanks for asking. I'm here to assist you", model.generate("Hello how are you?"));
160+
}
161+
}

0 commit comments

Comments
 (0)