|
| 1 | +package io.quarkiverse.langchain4j.bam.deployment; |
| 2 | + |
| 3 | +import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.options; |
| 4 | +import static org.junit.jupiter.api.Assertions.assertEquals; |
| 5 | + |
| 6 | +import java.time.Duration; |
| 7 | +import java.util.List; |
| 8 | + |
| 9 | +import jakarta.inject.Inject; |
| 10 | + |
| 11 | +import org.jboss.shrinkwrap.api.ShrinkWrap; |
| 12 | +import org.jboss.shrinkwrap.api.spec.JavaArchive; |
| 13 | +import org.junit.jupiter.api.AfterAll; |
| 14 | +import org.junit.jupiter.api.BeforeAll; |
| 15 | +import org.junit.jupiter.api.Test; |
| 16 | +import org.junit.jupiter.api.extension.RegisterExtension; |
| 17 | + |
| 18 | +import com.fasterxml.jackson.databind.ObjectMapper; |
| 19 | +import com.github.tomakehurst.wiremock.WireMockServer; |
| 20 | + |
| 21 | +import dev.langchain4j.model.chat.ChatLanguageModel; |
| 22 | +import io.quarkiverse.langchain4j.bam.BamRestApi; |
| 23 | +import io.quarkiverse.langchain4j.bam.Message; |
| 24 | +import io.quarkiverse.langchain4j.bam.Parameters; |
| 25 | +import io.quarkiverse.langchain4j.bam.TextGenerationRequest; |
| 26 | +import io.quarkiverse.langchain4j.bam.runtime.config.Langchain4jBamConfig; |
| 27 | +import io.quarkus.test.QuarkusUnitTest; |
| 28 | + |
| 29 | +public class AllPropertiesTest { |
| 30 | + |
| 31 | + static WireMockServer wireMockServer; |
| 32 | + static ObjectMapper mapper; |
| 33 | + static WireMockUtil mockServers; |
| 34 | + |
| 35 | + @RegisterExtension |
| 36 | + static QuarkusUnitTest unitTest = new QuarkusUnitTest() |
| 37 | + .overrideRuntimeConfigKey("quarkus.langchain4j.bam.base-url", WireMockUtil.URL) |
| 38 | + .overrideRuntimeConfigKey("quarkus.langchain4j.bam.api-key", WireMockUtil.API_KEY) |
| 39 | + .overrideRuntimeConfigKey("quarkus.langchain4j.bam.timeout", "60s") |
| 40 | + .overrideRuntimeConfigKey("quarkus.langchain4j.bam.log-requests", "true") |
| 41 | + .overrideRuntimeConfigKey("quarkus.langchain4j.bam.log-responses", "true") |
| 42 | + .overrideRuntimeConfigKey("quarkus.langchain4j.bam.timeout", "60s") |
| 43 | + .overrideRuntimeConfigKey("quarkus.langchain4j.bam.version", "aaaa-mm-dd") |
| 44 | + .overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.model-id", "my_super_model") |
| 45 | + .overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.decoding-method", "greedy") |
| 46 | + .overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.include-stop-sequence", "true") |
| 47 | + .overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.max-new-tokens", "200") |
| 48 | + .overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.min-new-tokens", "10") |
| 49 | + .overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.random-seed", "2") |
| 50 | + .overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.stop-sequences", "\n,\n\n") |
| 51 | + .overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.temperature", "1.5") |
| 52 | + .overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.time-limit", "1500") |
| 53 | + .overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.top-k", "90") |
| 54 | + .overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.top-p", "0.5") |
| 55 | + .overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.typical-p", "0.5") |
| 56 | + .overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.repetition-penalty", "2.0") |
| 57 | + .overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.truncate-input-tokens", "0") |
| 58 | + .overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.beam-width", "2") |
| 59 | + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class)); |
| 60 | + |
| 61 | + @Inject |
| 62 | + Langchain4jBamConfig config; |
| 63 | + |
| 64 | + @Inject |
| 65 | + ChatLanguageModel model; |
| 66 | + |
| 67 | + @BeforeAll |
| 68 | + static void beforeAll() { |
| 69 | + wireMockServer = new WireMockServer(options().port(WireMockUtil.PORT)); |
| 70 | + wireMockServer.start(); |
| 71 | + mapper = BamRestApi.objectMapper(new ObjectMapper()); |
| 72 | + mockServers = new WireMockUtil(wireMockServer); |
| 73 | + } |
| 74 | + |
| 75 | + @AfterAll |
| 76 | + static void afterAll() { |
| 77 | + wireMockServer.stop(); |
| 78 | + } |
| 79 | + |
| 80 | + @Test |
| 81 | + void generate() throws Exception { |
| 82 | + |
| 83 | + assertEquals(WireMockUtil.URL, config.baseUrl().get().toString()); |
| 84 | + assertEquals(WireMockUtil.API_KEY, config.apiKey()); |
| 85 | + assertEquals(Duration.ofSeconds(60), config.timeout()); |
| 86 | + assertEquals(true, config.logRequests()); |
| 87 | + assertEquals(true, config.logResponses()); |
| 88 | + assertEquals("aaaa-mm-dd", config.version()); |
| 89 | + assertEquals("my_super_model", config.chatModel().modelId()); |
| 90 | + assertEquals("greedy", config.chatModel().decodingMethod()); |
| 91 | + assertEquals(true, config.chatModel().includeStopSequence().get()); |
| 92 | + assertEquals(200, config.chatModel().maxNewTokens()); |
| 93 | + assertEquals(10, config.chatModel().minNewTokens()); |
| 94 | + assertEquals(2, config.chatModel().randomSeed().get()); |
| 95 | + assertEquals(List.of("\n", "\n\n"), config.chatModel().stopSequences().get()); |
| 96 | + assertEquals(1.5, config.chatModel().temperature()); |
| 97 | + assertEquals(1500, config.chatModel().timeLimit().get()); |
| 98 | + assertEquals(90, config.chatModel().topK().get()); |
| 99 | + assertEquals(0.5, config.chatModel().topP().get()); |
| 100 | + assertEquals(0.5, config.chatModel().typicalP().get()); |
| 101 | + assertEquals(2.0, config.chatModel().repetitionPenalty().get()); |
| 102 | + assertEquals(0, config.chatModel().truncateInputTokens().get()); |
| 103 | + assertEquals(2, config.chatModel().beamWidth().get()); |
| 104 | + |
| 105 | + var modelId = config.chatModel().modelId(); |
| 106 | + |
| 107 | + var parameters = Parameters.builder() |
| 108 | + .minNewTokens(10) |
| 109 | + .maxNewTokens(200) |
| 110 | + .decodingMethod("greedy") |
| 111 | + .includeStopSequence(true) |
| 112 | + .randomSeed(2) |
| 113 | + .stopSequences(List.of("\n", "\n\n")) |
| 114 | + .temperature(1.5) |
| 115 | + .timeLimit(1500) |
| 116 | + .topK(90) |
| 117 | + .topP(0.5) |
| 118 | + .typicalP(0.5) |
| 119 | + .repetitionPenalty(2.0) |
| 120 | + .truncateInputTokens(0) |
| 121 | + .beamWidth(2) |
| 122 | + .build(); |
| 123 | + |
| 124 | + List<Message> messages = List.of( |
| 125 | + new Message("user", "Hello how are you?")); |
| 126 | + |
| 127 | + var body = new TextGenerationRequest(modelId, messages, parameters); |
| 128 | + |
| 129 | + mockServers.mockBuilder(200, config.version()) |
| 130 | + .body(mapper.writeValueAsString(body)) |
| 131 | + .response(""" |
| 132 | + { |
| 133 | + "id": "05a245ad-1da7-4b9d-9807-ae1733177c1d", |
| 134 | + "model_id": "meta-llama/llama-2-70b-chat", |
| 135 | + "created_at": "2023-09-01T09:28:29.378Z", |
| 136 | + "results": [ |
| 137 | + { |
| 138 | + "generated_token_count": 20, |
| 139 | + "input_token_count": 146, |
| 140 | + "stop_reason": "max_tokens", |
| 141 | + "seed": 40268626, |
| 142 | + "generated_text": "Hello! I'm doing well, thanks for asking. I'm here to assist you" |
| 143 | + } |
| 144 | + ], |
| 145 | + "conversation_id": "cd3a9bca-b88e-41e4-9d62-bab33098fe39" |
| 146 | + } |
| 147 | + """) |
| 148 | + .build(); |
| 149 | + |
| 150 | + assertEquals("Hello! I'm doing well, thanks for asking. I'm here to assist you", model.generate("Hello how are you?")); |
| 151 | + } |
| 152 | +} |
0 commit comments