diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/config/AiCommonConfig.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/config/AiCommonConfig.java index c3d30221efe0..71a62e15636f 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/config/AiCommonConfig.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/config/AiCommonConfig.java @@ -50,7 +50,7 @@ public class AiCommonConfig { /** * temperature. */ - private Double temperature = 0.8; + private Double temperature; /** * max tokens. diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapter.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapter.java new file mode 100644 index 000000000000..73244c838f82 --- /dev/null +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapter.java @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.shenyu.plugin.ai.common.protocol; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import org.apache.shenyu.common.exception.ShenyuException; +import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; + +import java.util.Objects; + +/** + * Adapts between OpenAI Chat Completions wire format and internal representations. + * + *

Note: deserialization into Spring AI's {@link ChatCompletionRequest} preserves fields + * modeled by that class (messages, model, temperature, maxTokens, stream, tools, etc.). + * Provider-specific extension fields not modeled by {@code ChatCompletionRequest} are dropped + * during deserialization. For the OpenAI-compatible providers this plugin currently supports, + * all required fields are covered. + */ +public final class OpenAiProtocolAdapter { + + private static final Logger LOG = LoggerFactory.getLogger(OpenAiProtocolAdapter.class); + + private static final com.fasterxml.jackson.databind.ObjectMapper MAPPER = + new com.fasterxml.jackson.databind.ObjectMapper(); + + /** + * OpenAI API field name constants. + */ + private static final String FIELD_MODEL = "model"; + + private static final String FIELD_TEMPERATURE = "temperature"; + + private static final String FIELD_STREAM = "stream"; + + private static final String FIELD_MAX_COMPLETION_TOKENS = "max_completion_tokens"; + + private static final String FIELD_MAX_TOKENS = "max_tokens"; + + private OpenAiProtocolAdapter() { + } + + /** + * Resolve the stream flag: client request body takes priority, + * falls back to the provided default value. + * + * @param requestBody the raw JSON request body + * @param fallbackStream the default stream value from admin config + * @return true if streaming, false otherwise + */ + public static boolean resolveStream(final String requestBody, final Boolean fallbackStream) { + if (Objects.isNull(requestBody) || requestBody.isEmpty()) { + return Boolean.TRUE.equals(fallbackStream); + } + final JsonNode root = parseStrict(requestBody); + if (Objects.isNull(root)) { + return Boolean.TRUE.equals(fallbackStream); + } + if (root.hasNonNull(FIELD_STREAM)) { + return root.get(FIELD_STREAM).asBoolean(); + } + return Boolean.TRUE.equals(fallbackStream); + } + + /** + * Parse raw request body directly into ChatCompletionRequest, preserving fields + * modeled by Spring AI (including reasoning_content in assistant messages). + * + *

Spring AI's createRequest() loses reasoning_content, refusal, and annotations + * when reconstructing ChatCompletionMessage from AssistantMessage. + * This method avoids that loss by deserializing the raw JSON directly. + * + *

For max_tokens and max_completion_tokens: client values are preserved as-is; + * only when the client omits both fields does the fallbackConfig value populate both, + * ensuring compatibility with providers that require either field. + * + * @param requestBody the raw JSON request body in OpenAI Chat Completions format + * @param stream whether this is a streaming request (sets the stream field) + * @return a ChatCompletionRequest with modeled fields preserved from the original request + */ + public static ChatCompletionRequest toChatCompletionRequest(final String requestBody, final boolean stream) { + return toChatCompletionRequest(requestBody, stream, null); + } + + /** + * Parse raw request body directly into ChatCompletionRequest, preserving modeled fields. + * When fallbackConfig is provided, its non-null fields (model, temperature, maxTokens) + * override the client request values, matching the original ChatModel-based fallback behavior + * where the fallback ChatModel's config takes precedence. + * + * @param requestBody the raw JSON request body in OpenAI Chat Completions format + * @param stream whether this is a streaming request + * @param fallbackConfig the fallback config whose non-null fields override client values + * @return a ChatCompletionRequest with modeled fields preserved + */ + public static ChatCompletionRequest toChatCompletionRequest(final String requestBody, + final boolean stream, final AiCommonConfig fallbackConfig) { + if (Objects.isNull(requestBody) || requestBody.isEmpty()) { + throw new ShenyuException("Request body must not be empty"); + } + final JsonNode root = parseStrict(requestBody); + if (Objects.isNull(root) || !root.isObject()) { + throw new ShenyuException("Invalid request body: expected a JSON object"); + } + final ObjectNode mutableRoot = (ObjectNode) root; + + if (Objects.nonNull(fallbackConfig)) { + if (Objects.nonNull(fallbackConfig.getModel()) && !fallbackConfig.getModel().isEmpty()) { + mutableRoot.put(FIELD_MODEL, fallbackConfig.getModel()); + } + if (Objects.nonNull(fallbackConfig.getTemperature())) { + mutableRoot.put(FIELD_TEMPERATURE, fallbackConfig.getTemperature()); + } + if (Objects.nonNull(fallbackConfig.getMaxTokens())) { + mutableRoot.put(FIELD_MAX_TOKENS, fallbackConfig.getMaxTokens()); + mutableRoot.put(FIELD_MAX_COMPLETION_TOKENS, fallbackConfig.getMaxTokens()); + } + } + + mutableRoot.put(FIELD_STREAM, stream); + + try { + return MAPPER.treeToValue(mutableRoot, ChatCompletionRequest.class); + } catch (Exception e) { + LOG.error("[AiProxy] Failed to parse request body into ChatCompletionRequest", e); + throw new ShenyuException("Failed to parse request body into ChatCompletionRequest", e); + } + } + + private static JsonNode parseStrict(final String json) { + try { + return MAPPER.readTree(json); + } catch (JsonProcessingException e) { + return null; + } + } +} diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/FallbackStrategy.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/FallbackStrategy.java deleted file mode 100644 index 08c68c4f46ec..000000000000 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/FallbackStrategy.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.shenyu.plugin.ai.common.strategy; - -import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.model.ChatResponse; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -/** - * Fallback strategy. - */ -public interface FallbackStrategy { - - /** - * Execute the fallback strategy. - * - * @param fallbackClient the pre-configured and cached chat client for fallback - * @param requestBody the original request body as a string - * @param originalError the original error that triggered the fallback - * @return a Mono containing the fallback ChatResponse - */ - Mono fallback(ChatClient fallbackClient, String requestBody, Throwable originalError); - - /** - * Execute the fallback strategy for stream. - * - * @param fallbackClient the pre-configured and cached chat client for fallback - * @param requestBody the original request body as a string - * @param originalError the original error that triggered the fallback - * @return a Flux containing the fallback ChatResponse - */ - Flux fallbackStream(ChatClient fallbackClient, String requestBody, Throwable originalError); -} \ No newline at end of file diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/SimpleModelFallbackStrategy.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/SimpleModelFallbackStrategy.java deleted file mode 100644 index 94b3653e654d..000000000000 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/SimpleModelFallbackStrategy.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.shenyu.plugin.ai.common.strategy; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.model.ChatResponse; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.scheduler.Schedulers; - -/** - * A fallback strategy that simply executes a call with a pre-configured client. - */ -public final class SimpleModelFallbackStrategy implements FallbackStrategy { - - /** - * The constant INSTANCE. - */ - public static final FallbackStrategy INSTANCE = new SimpleModelFallbackStrategy(); - - private static final Logger LOG = LoggerFactory.getLogger(SimpleModelFallbackStrategy.class); - - private SimpleModelFallbackStrategy() { - } - - @Override - public Mono fallback(final ChatClient fallbackClient, final String requestBody, final Throwable originalError) { - LOG.warn("Executing simple model fallback strategy due to error: {}", originalError.getMessage()); - - return Mono.fromCallable(() -> fallbackClient.prompt() - .user(requestBody) - .call() - .chatResponse()) - .subscribeOn(Schedulers.boundedElastic()) - .doOnSuccess(response -> LOG.info("Fallback call successful.")) - .doOnError(fallbackError -> LOG.error("Fallback call also failed.", fallbackError)); - } - - @Override - public Flux fallbackStream(final ChatClient fallbackClient, final String requestBody, final Throwable originalError) { - LOG.warn("Executing simple model fallback stream strategy due to error: {}", originalError.getMessage()); - - return Flux.defer(() -> fallbackClient.prompt() - .user(requestBody) - .stream() - .chatResponse()) - .subscribeOn(Schedulers.boundedElastic()) - .doOnComplete(() -> LOG.info("Fallback stream completed successfully.")) - .doOnError(fallbackError -> LOG.error("Fallback stream also failed.", fallbackError)); - } -} \ No newline at end of file diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/test/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapterTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/test/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapterTest.java new file mode 100644 index 000000000000..e843d95f1d0b --- /dev/null +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/test/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapterTest.java @@ -0,0 +1,280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.shenyu.plugin.ai.common.protocol; + +import org.apache.shenyu.common.exception.ShenyuException; +import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig; +import org.junit.jupiter.api.Test; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public final class OpenAiProtocolAdapterTest { + + private static final String BASE_BODY = "{\"messages\":[{\"role\":\"user\",\"content\":\"hello\"}]}"; + + @Test + void testNullRequestBody() { + assertThrows(ShenyuException.class, + () -> OpenAiProtocolAdapter.toChatCompletionRequest(null, false)); + } + + @Test + void testEmptyRequestBody() { + assertThrows(ShenyuException.class, + () -> OpenAiProtocolAdapter.toChatCompletionRequest("", false)); + } + + @Test + void testStreamFlagSet() { + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, true); + assertNotNull(req); + assertTrue(req.stream()); + } + + @Test + void testResolveStreamClientTrueOverridesFallbackFalse() { + final String body = "{\"model\":\"m\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"stream\":true}"; + assertTrue(OpenAiProtocolAdapter.resolveStream(body, false)); + } + + @Test + void testResolveStreamClientFalseOverridesFallbackTrue() { + final String body = "{\"model\":\"m\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"stream\":false}"; + assertFalse(OpenAiProtocolAdapter.resolveStream(body, true)); + } + + @Test + void testResolveStreamFallbackWhenClientMissing() { + assertTrue(OpenAiProtocolAdapter.resolveStream(BASE_BODY, true)); + assertFalse(OpenAiProtocolAdapter.resolveStream(BASE_BODY, false)); + } + + @Test + void testResolveStreamFallbackWhenClientMissingAndConfigNull() { + assertFalse(OpenAiProtocolAdapter.resolveStream(BASE_BODY, null)); + } + + @Test + void testResolveStreamFallbackWhenBodyEmpty() { + assertTrue(OpenAiProtocolAdapter.resolveStream("", true)); + assertFalse(OpenAiProtocolAdapter.resolveStream("", false)); + } + + @Test + void testStreamFlagUnset() { + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false); + assertNotNull(req); + assertFalse(req.stream()); + } + + @Test + void testMaxCompletionTokensPreservedAsIs() { + final String body = "{\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"max_completion_tokens\":100}"; + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false); + assertNotNull(req); + assertEquals(100, req.maxCompletionTokens()); + } + + @Test + void testFallbackModelOverridesClientModel() { + final String body = "{\"model\":\"client-model\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}]}"; + final AiCommonConfig config = new AiCommonConfig(); + config.setModel("fallback-model"); + + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false, config); + assertEquals("fallback-model", req.model()); + } + + @Test + void testFallbackModelWhenClientMissing() { + final AiCommonConfig config = new AiCommonConfig(); + config.setModel("admin-model"); + + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false, config); + assertEquals("admin-model", req.model()); + } + + @Test + void testNoFallbackModelWhenConfigIsNull() { + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false, null); + assertNotNull(req); + } + + @Test + void testNoFallbackWhenConfigModelIsNull() { + final AiCommonConfig config = new AiCommonConfig(); + config.setModel(null); + + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false, config); + assertNotNull(req); + } + + @Test + void testNoFallbackWhenConfigModelIsEmpty() { + final AiCommonConfig config = new AiCommonConfig(); + config.setModel(""); + + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false, config); + assertNotNull(req); + } + + @Test + void testFallbackTemperatureOverridesClient() { + final String body = "{\"model\":\"m\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"temperature\":0.5}"; + final AiCommonConfig config = new AiCommonConfig(); + config.setTemperature(0.9); + + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false, config); + assertEquals(0.9, req.temperature(), 0.001); + } + + @Test + void testFallbackTemperatureWhenClientMissing() { + final AiCommonConfig config = new AiCommonConfig(); + config.setTemperature(0.7); + + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false, config); + assertEquals(0.7, req.temperature(), 0.001); + } + + @Test + void testNoFallbackWhenConfigTemperatureIsNull() { + final AiCommonConfig config = new AiCommonConfig(); + config.setTemperature(null); + + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false, config); + assertNotNull(req); + } + + @Test + void testFallbackMaxTokensOverridesClient() { + final String body = "{\"model\":\"m\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"max_tokens\":200}"; + final AiCommonConfig config = new AiCommonConfig(); + config.setMaxTokens(500); + + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false, config); + assertEquals(500, req.maxTokens()); + } + + @Test + void testFallbackMaxTokensWhenClientMissing() { + final AiCommonConfig config = new AiCommonConfig(); + config.setMaxTokens(500); + + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false, config); + assertEquals(500, req.maxTokens()); + } + + @Test + void testNoFallbackMaxTokensWhenConfigNull() { + final AiCommonConfig config = new AiCommonConfig(); + config.setMaxTokens(null); + + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false, config); + assertNotNull(req); + } + + @Test + void testAllFallbackFieldsAppliedWhenClientMissingAll() { + final AiCommonConfig config = new AiCommonConfig(); + config.setModel("admin-model"); + config.setTemperature(0.3); + config.setMaxTokens(1024); + + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, true, config); + assertEquals("admin-model", req.model()); + assertEquals(0.3, req.temperature(), 0.001); + assertEquals(1024, req.maxTokens()); + assertTrue(req.stream()); + } + + @Test + void testAllFallbackFieldsOverrideClient() { + final String body = "{\"model\":\"client-model\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}]," + + "\"temperature\":0.1,\"max_tokens\":50}"; + final AiCommonConfig config = new AiCommonConfig(); + config.setModel("fallback-model"); + config.setTemperature(0.9); + config.setMaxTokens(9999); + + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false, config); + assertEquals("fallback-model", req.model()); + assertEquals(0.9, req.temperature(), 0.001); + assertEquals(9999, req.maxTokens()); + } + + @Test + void testPartialFallbackOverrideWithPartialClient() { + final String body = "{\"model\":\"client-model\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"temperature\":0.1}"; + final AiCommonConfig config = new AiCommonConfig(); + config.setModel("fallback-model"); + config.setTemperature(0.9); + config.setMaxTokens(2048); + + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false, config); + assertEquals("fallback-model", req.model()); + assertEquals(0.9, req.temperature(), 0.001); + assertEquals(2048, req.maxTokens()); + } + + @Test + void testMalformedJsonThrows() { + assertThrows(ShenyuException.class, + () -> OpenAiProtocolAdapter.toChatCompletionRequest("not valid json{{{", false)); + } + + @Test + void testMaxCompletionTokensNonNumberThrows() { + final String body = "{\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"max_completion_tokens\":\"abc\"}"; + assertThrows(ShenyuException.class, + () -> OpenAiProtocolAdapter.toChatCompletionRequest(body, false)); + } + + @Test + void testResolveStreamMalformedJsonFallsBack() { + assertFalse(OpenAiProtocolAdapter.resolveStream("not json{{{", false)); + assertTrue(OpenAiProtocolAdapter.resolveStream("not json{{{", true)); + } + + @Test + void testAnnotationsRoundTrip() throws Exception { + final com.fasterxml.jackson.databind.ObjectMapper mapper = new com.fasterxml.jackson.databind.ObjectMapper(); + final String body = "{\"messages\":[" + + "{\"role\":\"user\",\"content\":\"hello\"}," + + "{\"role\":\"assistant\",\"content\":\"see link\"," + + "\"annotations\":[{\"type\":\"url_citation\"," + + "\"url_citation\":{\"url\":\"https://example.com\",\"title\":\"Example\",\"start_index\":0,\"end_index\":8}}]}" + + "]}"; + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false); + assertNotNull(req); + + final String serialized = mapper.writeValueAsString(req); + final com.fasterxml.jackson.databind.JsonNode root = mapper.readTree(serialized); + final com.fasterxml.jackson.databind.JsonNode assistantMsg = root.get("messages").get(1); + assertTrue(assistantMsg.has("annotations")); + final com.fasterxml.jackson.databind.JsonNode annotations = assistantMsg.get("annotations"); + assertEquals(1, annotations.size()); + assertEquals("url_citation", annotations.get(0).get("type").asText()); + assertEquals("https://example.com", annotations.get(0).get("url_citation").get("url").asText()); + } +} diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java index 687c51233f95..916431437acd 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java @@ -21,25 +21,26 @@ import org.apache.shenyu.common.dto.RuleData; import org.apache.shenyu.common.dto.SelectorData; import org.apache.shenyu.common.dto.convert.rule.AiProxyHandle; -import org.apache.shenyu.common.enums.AiModelProviderEnum; import org.apache.shenyu.common.enums.PluginEnum; import org.apache.shenyu.common.utils.JsonUtils; import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig; -import org.apache.shenyu.plugin.ai.common.spring.ai.registry.AiModelFactoryRegistry; +import org.apache.shenyu.plugin.ai.common.protocol.OpenAiProtocolAdapter; import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.AiProxyApiKeyCache; -import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.ChatClientCache; +import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.OpenAiApiCache; import org.apache.shenyu.plugin.ai.proxy.enhanced.handler.AiProxyPluginHandler; import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyConfigService; import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyExecutorService; +import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyExecutorService.FallbackContext; +import org.apache.shenyu.plugin.ai.proxy.enhanced.service.UpstreamErrorLogger; import org.apache.shenyu.plugin.api.ShenyuPluginChain; import org.apache.shenyu.plugin.api.utils.WebFluxResultUtils; import org.apache.shenyu.plugin.base.AbstractShenyuPlugin; import org.apache.shenyu.plugin.base.utils.CacheKeyUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferUtils; import org.springframework.http.HttpHeaders; @@ -65,26 +66,18 @@ public class AiProxyPlugin extends AbstractShenyuPlugin { */ private static final long MAX_REQUEST_BODY_SIZE_BYTES = 5 * 1024 * 1024L; - private final AiModelFactoryRegistry aiModelFactoryRegistry; - private final AiProxyConfigService aiProxyConfigService; private final AiProxyExecutorService aiProxyExecutorService; - private final ChatClientCache chatClientCache; - private final AiProxyPluginHandler aiProxyPluginHandler; public AiProxyPlugin( - final AiModelFactoryRegistry aiModelFactoryRegistry, final AiProxyConfigService aiProxyConfigService, final AiProxyExecutorService aiProxyExecutorService, - final ChatClientCache chatClientCache, final AiProxyPluginHandler aiProxyPluginHandler) { - this.aiModelFactoryRegistry = aiModelFactoryRegistry; this.aiProxyConfigService = aiProxyConfigService; this.aiProxyExecutorService = aiProxyExecutorService; - this.chatClientCache = chatClientCache; this.aiProxyPluginHandler = aiProxyPluginHandler; } @@ -148,7 +141,8 @@ protected Mono doExecute( } } - if (Boolean.TRUE.equals(primaryConfig.getStream())) { + final boolean stream = OpenAiProtocolAdapter.resolveStream(requestBody, primaryConfig.getStream()); + if (stream) { return handleStreamRequest(exchange, selector, requestBody, primaryConfig, selectorHandle); } return handleNonStreamRequest(exchange, selector, requestBody, primaryConfig, selectorHandle); @@ -161,23 +155,26 @@ private Mono handleStreamRequest( final String requestBody, final AiCommonConfig primaryConfig, final AiProxyHandle selectorHandle) { - final ChatClient mainClient = createMainChatClient(selector.getId(), primaryConfig); - final String prompt = aiProxyConfigService.extractPrompt(requestBody); - final Optional fallbackClient = resolveFallbackClient(primaryConfig, selectorHandle, - selector.getId(), requestBody); + final OpenAiApi mainApi = getCachedOpenAiApi(selector.getId(), "main", primaryConfig); + final ChatCompletionRequest request = OpenAiProtocolAdapter.toChatCompletionRequest(requestBody, true, primaryConfig); + final Optional fallbackCtx = resolveFallbackContext( + selector.getId(), primaryConfig, selectorHandle, requestBody); final ServerHttpResponse response = exchange.getResponse(); response.getHeaders().setContentType(MediaType.TEXT_EVENT_STREAM); - final Flux chatResponseFlux = aiProxyExecutorService.executeStream(mainClient, fallbackClient, - prompt); + final Flux chunkFlux = aiProxyExecutorService.executeDirectStream( + mainApi, fallbackCtx, request, requestBody, true); - final Flux sseFlux = chatResponseFlux.map( - chatResponse -> { - final String json = JsonUtils.toJson(chatResponse); + final Flux sseFlux = chunkFlux.map( + chunk -> { + final String json = JsonUtils.toJson(chunk); final String sseData = "data: " + json + "\n\n"; return response.bufferFactory() .wrap(sseData.getBytes(StandardCharsets.UTF_8)); - }); + }).concatWith(Mono.fromSupplier(() -> + response.bufferFactory() + .wrap("data: [DONE]\n\n".getBytes(StandardCharsets.UTF_8)))) + .doOnError(e -> logUpstreamError(e, "stream")); return response.writeWith(sseFlux); } @@ -188,24 +185,26 @@ private Mono handleNonStreamRequest( final String requestBody, final AiCommonConfig primaryConfig, final AiProxyHandle selectorHandle) { - final ChatClient mainClient = createMainChatClient(selector.getId(), primaryConfig); - final String prompt = aiProxyConfigService.extractPrompt(requestBody); - final Optional fallbackClient = resolveFallbackClient(primaryConfig, selectorHandle, - selector.getId(), requestBody); + final OpenAiApi mainApi = getCachedOpenAiApi(selector.getId(), "main", primaryConfig); + final ChatCompletionRequest request = OpenAiProtocolAdapter.toChatCompletionRequest(requestBody, false, primaryConfig); + final Optional fallbackCtx = resolveFallbackContext( + selector.getId(), primaryConfig, selectorHandle, requestBody); return aiProxyExecutorService - .execute(mainClient, fallbackClient, prompt) + .executeDirectCall(mainApi, fallbackCtx, request, requestBody) .flatMap( - response -> { - byte[] jsonBytes = JsonUtils.toJson(response).getBytes(StandardCharsets.UTF_8); + responseEntity -> { + final String responseJson = JsonUtils.toJson(responseEntity.getBody()); + byte[] jsonBytes = responseJson.getBytes(StandardCharsets.UTF_8); return WebFluxResultUtils.result(exchange, jsonBytes); - }); + }) + .doOnError(e -> logUpstreamError(e, "non-stream")); } - private Optional resolveFallbackClient( + private Optional resolveFallbackContext( + final String selectorId, final AiCommonConfig primaryConfig, final AiProxyHandle selectorHandle, - final String selectorId, final String requestBody) { return aiProxyConfigService .resolveDynamicFallbackConfig(primaryConfig, requestBody) @@ -214,86 +213,47 @@ private Optional resolveFallbackClient( if (LOG.isDebugEnabled()) { LOG.debug("[AiProxy] dynamic fallback config: {}", cfg); } - return createDynamicFallbackClient(cfg); + return new FallbackContext(createOpenAiApi(cfg), cfg); }) - .or( - () -> aiProxyConfigService - .resolveAdminFallbackConfig(primaryConfig, selectorHandle) - .map(adminFallbackConfig -> { - LOG.info("[AiProxy] use admin fallback"); - if (LOG.isDebugEnabled()) { - LOG.debug("[AiProxy] admin fallback config: {}", adminFallbackConfig); - } - return createAdminFallbackClient(selectorId, adminFallbackConfig); - })); + .or(() -> aiProxyConfigService + .resolveAdminFallbackConfig(primaryConfig, selectorHandle) + .map(adminFallbackConfig -> { + LOG.info("[AiProxy] use admin fallback"); + if (LOG.isDebugEnabled()) { + LOG.debug("[AiProxy] admin fallback config: {}", adminFallbackConfig); + } + return new FallbackContext( + getCachedOpenAiApi(selectorId, "adminFallback", adminFallbackConfig), + adminFallbackConfig); + })); + } + + private OpenAiApi getCachedOpenAiApi(final String selectorId, final String type, final AiCommonConfig config) { + final String cacheKey = selectorId + "|" + type + "_" + generateConfigCacheKey(config); + return OpenAiApiCache.getInstance().computeIfAbsent(cacheKey, () -> createOpenAiApi(config)); } - /** - * Generate cache key based on config fields excluding apiKey. - * This ensures cache consistency even when apiKey is updated at runtime. - * - * @param config the config - * @return cache key hash - */ private int generateConfigCacheKey(final AiCommonConfig config) { return Objects.hash( - config.getProvider(), config.getBaseUrl(), + config.getApiKey(), config.getModel(), config.getTemperature(), - config.getMaxTokens(), - config.getStream() - // Explicitly exclude apiKey to avoid cache misses when apiKey changes + config.getMaxTokens() ); } - private ChatClient createMainChatClient(final String selectorId, final AiCommonConfig config) { - final int configHash = generateConfigCacheKey(config); - final String cacheKey = selectorId + "|main_" + configHash; - return chatClientCache.computeIfAbsent( - cacheKey, - () -> { - LOG.info("Creating and caching main model for selector: {}, key: {}", selectorId, cacheKey); - return createChatModel(config); - }); - } - - private ChatClient createAdminFallbackClient( - final String selectorId, final AiCommonConfig fallbackConfig) { - final int configHash = generateConfigCacheKey(fallbackConfig); - final String fallbackCacheKey = selectorId + "|adminFallback_" + configHash; - return chatClientCache.computeIfAbsent( - fallbackCacheKey, - () -> { - LOG.info( - "Creating and caching admin fallback model for selector: {}, key: {}", - selectorId, fallbackCacheKey); - return createChatModel(fallbackConfig); - }); - } - - private ChatClient createDynamicFallbackClient(final AiCommonConfig fallbackConfig) { - LOG.info("Creating non-cached dynamic fallback model."); - return ChatClient.builder(createChatModel(fallbackConfig)).build(); - } - - private ChatModel createChatModel(final AiCommonConfig config) { - if (LOG.isDebugEnabled()) { - LOG.debug("Creating chat model with config: {}", config); - } - final AiModelProviderEnum provider = AiModelProviderEnum.getByName(config.getProvider()); - if (Objects.isNull(provider)) { - throw new IllegalArgumentException( - "Invalid AI model provider in config: " + config.getProvider()); + private OpenAiApi createOpenAiApi(final AiCommonConfig config) { + if (Objects.isNull(config.getBaseUrl()) || config.getBaseUrl().isEmpty()) { + throw new IllegalArgumentException("baseUrl must not be empty"); } - final var factory = aiModelFactoryRegistry.getFactory(provider); - if (Objects.isNull(factory)) { - throw new IllegalArgumentException( - "AI model factory not found for provider: " + provider.getName()); + if (Objects.isNull(config.getApiKey()) || config.getApiKey().isEmpty()) { + throw new IllegalArgumentException("apiKey must not be empty"); } - return Objects.requireNonNull( - factory.createAiModel(config), - "The AI model created by the factory must not be null"); + return OpenAiApi.builder() + .baseUrl(config.getBaseUrl()) + .apiKey(config.getApiKey()) + .build(); } @Override @@ -305,4 +265,8 @@ public int getOrder() { public String named() { return PluginEnum.AI_PROXY.getName(); } + + private void logUpstreamError(final Throwable e, final String mode) { + UpstreamErrorLogger.logUpstreamError(LOG, e, mode); + } } diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/ChatClientCache.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/ChatClientCache.java deleted file mode 100644 index 473ace98fa91..000000000000 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/ChatClientCache.java +++ /dev/null @@ -1,143 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.shenyu.plugin.ai.proxy.enhanced.cache; - -import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.model.ChatModel; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.Map; -import java.util.Objects; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Supplier; - -/** - * This is ChatClient cache. - */ -public final class ChatClientCache { - - private static final Logger LOG = LoggerFactory.getLogger(ChatClientCache.class); - - private static final int MAX_CACHE_SIZE = getCacheSize(); - - private final Map chatClientMap = new ConcurrentHashMap<>(); - - private final AtomicBoolean evictionInProgress = new AtomicBoolean(false); - - /** - * Instantiates a new Chat client cache. - */ - public ChatClientCache() { - } - - private static int getCacheSize() { - String value = System.getProperty("shenyu.plugin.ai.proxy.enhanced.cache.maxSize", - System.getenv("SHENYU_PLUGIN_AI_PROXY_ENHANCED_CACHE_MAXSIZE")); - if (Objects.nonNull(value)) { - try { - return Integer.parseInt(value); - } catch (NumberFormatException e) { - LoggerFactory.getLogger(ChatClientCache.class) - .warn("[ChatClientCache] Invalid cache size '{}', using default 500.", value); - } - } - return 500; - } - - /** - * Gets client or compute if absent. - * - * @param key the key - * @param chatModelSupplier the chat model supplier - * @return the chat client - */ - public ChatClient computeIfAbsent(final String key, final Supplier chatModelSupplier) { - // Check size before computing, but use synchronized block to prevent race conditions - final int currentSize = chatClientMap.size(); - if (currentSize > MAX_CACHE_SIZE) { - // Use atomic flag to ensure only one thread performs eviction - if (evictionInProgress.compareAndSet(false, true)) { - try { - synchronized (chatClientMap) { - // Double-check after acquiring lock - if (chatClientMap.size() > MAX_CACHE_SIZE) { - evictOldestEntries(); - } - } - } finally { - evictionInProgress.set(false); - } - } - } - return chatClientMap.computeIfAbsent(key, k -> ChatClient.builder(chatModelSupplier.get()).build()); - } - - /** - * Evict oldest entries when cache size exceeds limit. - * Removes approximately 25% of entries to avoid thundering herd problem. - */ - private void evictOldestEntries() { - final int currentSize = chatClientMap.size(); - if (currentSize <= MAX_CACHE_SIZE) { - return; - } - - // Evict 25% of entries, but at least 10 entries - final int evictCount = Math.max(10, currentSize / 4); - LOG.warn("[ChatClientCache] Cache size {} exceeded limit {}, evicting {} oldest entries", - currentSize, MAX_CACHE_SIZE, evictCount); - - // Since ConcurrentHashMap doesn't maintain insertion order, - // we evict entries based on iteration order (which is somewhat arbitrary but better than clearing all) - int removed = 0; - for (final String key : chatClientMap.keySet()) { - if (removed >= evictCount) { - break; - } - chatClientMap.remove(key); - removed++; - } - - LOG.info("[ChatClientCache] Evicted {} entries, cache size now: {}", removed, chatClientMap.size()); - } - - /** - * Removes all cached clients associated with a selector ID (by prefix matching - * "selectorId|"). - * - * @param selectorId the selector id - */ - public void remove(final String selectorId) { - if (java.util.Objects.isNull(selectorId)) { - return; - } - final String prefix = selectorId + "|"; - chatClientMap.keySet().removeIf(k -> k.equals(selectorId) || k.startsWith(prefix)); - LOG.info("[ChatClientCache] invalidate selectorId={} (by prefix)", selectorId); - } - - /** - * Clear all cached clients. - */ - public void clearAll() { - chatClientMap.clear(); - LOG.info("[ChatClientCache] cleared all cached clients"); - } -} \ No newline at end of file diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCache.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCache.java new file mode 100644 index 000000000000..d75aab480f32 --- /dev/null +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCache.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.shenyu.plugin.ai.proxy.enhanced.cache; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.openai.api.OpenAiApi; + +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; + +/** + * This is OpenAiApi cache. + * Caches OpenAiApi instances by config key to avoid recreating on every request. + */ +public final class OpenAiApiCache { + + private static final Logger LOG = LoggerFactory.getLogger(OpenAiApiCache.class); + + private static final OpenAiApiCache INSTANCE = new OpenAiApiCache(); + + private static final int MAX_CACHE_SIZE = getCacheSize(); + + private final Map openAiApiMap = new ConcurrentHashMap<>(); + + private final AtomicBoolean evictionInProgress = new AtomicBoolean(false); + + /** + * Instantiates a new OpenAiApi cache. + */ + public OpenAiApiCache() { + } + + /** + * Gets instance. + * + * @return singleton + */ + public static OpenAiApiCache getInstance() { + return INSTANCE; + } + + private static int getCacheSize() { + String value = System.getProperty("shenyu.plugin.ai.proxy.enhanced.cache.maxSize", + System.getenv("SHENYU_PLUGIN_AI_PROXY_ENHANCED_CACHE_MAXSIZE")); + if (Objects.nonNull(value)) { + try { + return Integer.parseInt(value); + } catch (NumberFormatException e) { + LOG.warn("[OpenAiApiCache] Invalid cache size '{}', using default 500.", value); + } + } + return 500; + } + + /** + * Gets OpenAiApi or compute if absent. + * + * @param key the cache key (typically selectorId|configHash) + * @param openAiApiSupplier the supplier to create OpenAiApi if absent + * @return the cached or newly created OpenAiApi + */ + public OpenAiApi computeIfAbsent(final String key, final Supplier openAiApiSupplier) { + final int currentSize = openAiApiMap.size(); + if (currentSize > MAX_CACHE_SIZE) { + if (evictionInProgress.compareAndSet(false, true)) { + try { + synchronized (openAiApiMap) { + if (openAiApiMap.size() > MAX_CACHE_SIZE) { + evictEntries(); + } + } + } finally { + evictionInProgress.set(false); + } + } + } + return openAiApiMap.computeIfAbsent(key, k -> openAiApiSupplier.get()); + } + + /** + * Evict arbitrary entries when cache size exceeds limit. + * Note: ConcurrentHashMap iteration order is undefined, so evicted entries are arbitrary, + * not guaranteed to be the oldest. Removes approximately 25% of entries to avoid + * thundering herd problem. + */ + private void evictEntries() { + final int currentSize = openAiApiMap.size(); + if (currentSize <= MAX_CACHE_SIZE) { + return; + } + + final int evictCount = Math.max(10, currentSize / 4); + LOG.warn("[OpenAiApiCache] Cache size {} exceeded limit {}, evicting {} arbitrary entries", + currentSize, MAX_CACHE_SIZE, evictCount); + + int removed = 0; + for (final String key : openAiApiMap.keySet()) { + if (removed >= evictCount) { + break; + } + openAiApiMap.remove(key); + removed++; + } + + LOG.info("[OpenAiApiCache] Evicted {} entries, cache size now: {}", removed, openAiApiMap.size()); + } + + /** + * Removes all cached OpenAiApi instances associated with a selector ID + * (by prefix matching "selectorId|"). + * + * @param selectorId the selector id + */ + public void remove(final String selectorId) { + if (Objects.isNull(selectorId)) { + return; + } + final String prefix = selectorId + "|"; + openAiApiMap.keySet().removeIf(k -> k.equals(selectorId) || k.startsWith(prefix)); + LOG.info("[OpenAiApiCache] invalidate selectorId={} (by prefix)", selectorId); + } + + /** + * Clear all cached OpenAiApi instances. + */ + public void clearAll() { + openAiApiMap.clear(); + LOG.info("[OpenAiApiCache] cleared all cached OpenAiApi instances"); + } + + /** + * Gets the current cache size (for testing / monitoring). + * + * @return the number of cached entries + */ + public int size() { + return openAiApiMap.size(); + } +} diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java index e4e3f470ed7c..40dfc31fa77f 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java @@ -23,7 +23,8 @@ import org.apache.shenyu.common.dto.convert.rule.AiProxyHandle; import org.apache.shenyu.common.enums.PluginEnum; import org.apache.shenyu.common.utils.GsonUtils; -import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.ChatClientCache; +import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.AiProxyApiKeyCache; +import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.OpenAiApiCache; import org.apache.shenyu.plugin.base.cache.CommonHandleCache; import org.apache.shenyu.plugin.base.handler.PluginDataHandler; import org.apache.shenyu.plugin.base.utils.CacheKeyUtils; @@ -37,10 +38,7 @@ public class AiProxyPluginHandler implements PluginDataHandler { private final CommonHandleCache selectorCachedHandle = new CommonHandleCache<>(); - private final ChatClientCache chatClientCache; - - public AiProxyPluginHandler(final ChatClientCache chatClientCache) { - this.chatClientCache = chatClientCache; + public AiProxyPluginHandler() { } @Override @@ -53,10 +51,13 @@ public void handlerPlugin(final PluginData pluginData) { @Override public void handlerSelector(final SelectorData selectorData) { // Invalidate the cache first when the selector is updated. - chatClientCache.remove(selectorData.getId()); + OpenAiApiCache.getInstance().remove(selectorData.getId()); // Do NOT remove AiProxyApiKeyCache here. Admin will push updated AI_PROXY_API_KEY events // with refreshed realApiKey after selector changes. Removing here introduces a window of misses. - if (Objects.isNull(selectorData.getHandle())) { + if (Objects.isNull(selectorData.getHandle()) || selectorData.getHandle().isEmpty()) { + // Clear stale cached handle when selector handle is removed/cleared + selectorCachedHandle + .removeHandle(CacheKeyUtils.INST.getKey(selectorData.getId(), Constants.DEFAULT_RULE)); return; } AiProxyHandle aiProxyHandle = GsonUtils.getInstance().fromJson(selectorData.getHandle(), AiProxyHandle.class); @@ -68,7 +69,8 @@ public void handlerSelector(final SelectorData selectorData) { @Override public void removeSelector(final SelectorData selectorData) { // Invalidate the cache when the selector is removed. - chatClientCache.remove(selectorData.getId()); + OpenAiApiCache.getInstance().remove(selectorData.getId()); + AiProxyApiKeyCache.getInstance().removeBySelectorId(selectorData.getId()); selectorCachedHandle .removeHandle(CacheKeyUtils.INST.getKey(selectorData.getId(), Constants.DEFAULT_RULE)); } diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java index 71ab0c4c49d1..534f81f0c7f7 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java @@ -17,18 +17,24 @@ package org.apache.shenyu.plugin.ai.proxy.enhanced.service; -import org.apache.shenyu.plugin.ai.common.strategy.SimpleModelFallbackStrategy; +import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig; +import org.apache.shenyu.plugin.ai.common.protocol.OpenAiProtocolAdapter; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; import org.springframework.ai.retry.NonTransientAiException; +import org.springframework.http.ResponseEntity; +import org.springframework.web.reactive.function.client.WebClientResponseException; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; import reactor.util.retry.Retry; import java.time.Duration; +import java.util.Objects; import java.util.Optional; /** @@ -39,78 +45,111 @@ public class AiProxyExecutorService { private static final Logger LOG = LoggerFactory.getLogger(AiProxyExecutorService.class); /** - * Execute the AI call with retry and fallback. + * Execute a streaming AI call directly via {@link OpenAiApi}, bypassing Spring AI's + * {@code createRequest()} which loses fields like {@code reasoning_content}. * - * @param mainClient the main chat client - * @param fallbackClientOpt the optional fallback chat client - * @param requestBody the request body - * @return a Mono containing the ChatResponse + * @param mainApi the main OpenAiApi + * @param fallbackCtxOpt the optional fallback context (api + config) + * @param request the ChatCompletionRequest with all fields preserved + * @param requestBody the original request body for rebuilding fallback request + * @param stream whether this is a streaming request + * @return a Flux of ChatCompletionChunk */ - public Mono execute(final ChatClient mainClient, final Optional fallbackClientOpt, final String requestBody) { - final Mono mainCall = doChatCall(mainClient, requestBody); + public Flux executeDirectStream(final OpenAiApi mainApi, + final Optional fallbackCtxOpt, final ChatCompletionRequest request, + final String requestBody, final boolean stream) { + return mainApi.chatCompletionStream(request) + .doOnError(e -> UpstreamErrorLogger.logUpstreamError(LOG, e, "direct stream")) + .retryWhen(Retry.max(1) + .filter(AiProxyExecutorService::isRetryable) + .onRetryExhaustedThrow((retryBackoffSpec, retrySignal) -> { + LOG.warn("Direct stream retry exhausted. Triggering fallback.", + retrySignal.failure()); + return new NonTransientAiException( + "Direct stream failed after 1 retry. Triggering fallback.", + retrySignal.failure()); + })) + .onErrorResume(e -> handleDirectFallbackStream(e, fallbackCtxOpt, requestBody, stream)); + } - return mainCall + /** + * Execute a non-streaming AI call directly via {@link OpenAiApi}. + * + * @param mainApi the main OpenAiApi + * @param fallbackCtxOpt the optional fallback context (api + config) + * @param request the ChatCompletionRequest with all fields preserved + * @param requestBody the original request body for rebuilding fallback request + * @return a Mono of ResponseEntity containing ChatCompletion + */ + public Mono> executeDirectCall(final OpenAiApi mainApi, + final Optional fallbackCtxOpt, final ChatCompletionRequest request, + final String requestBody) { + return Mono.fromCallable(() -> mainApi.chatCompletionEntity(request)) + .subscribeOn(Schedulers.boundedElastic()) + .doOnError(e -> UpstreamErrorLogger.logUpstreamError(LOG, e, "direct call")) .retryWhen(Retry.backoff(3, Duration.ofSeconds(1)) - .filter(throwable -> !(throwable instanceof NonTransientAiException)) + .filter(AiProxyExecutorService::isRetryable) .onRetryExhaustedThrow((retryBackoffSpec, retrySignal) -> { - LOG.warn("Retries exhausted for AI call after {} attempts.", + LOG.warn("Direct call retries exhausted after {} attempts. Triggering fallback.", retrySignal.totalRetries(), retrySignal.failure()); - return new NonTransientAiException("Retries exhausted. Triggering fallback.", + return new NonTransientAiException("Direct call retries exhausted. Triggering fallback.", retrySignal.failure()); })) - .onErrorResume(NonTransientAiException.class, - throwable -> handleFallback(throwable, fallbackClientOpt, requestBody)); - } - - protected Mono doChatCall(final ChatClient client, final String requestBody) { - return Mono.fromCallable(() -> client.prompt().user(requestBody).call().chatResponse()) - .subscribeOn(Schedulers.boundedElastic()); + .onErrorResume(e -> handleDirectFallbackCall(e, fallbackCtxOpt, requestBody)); } - private Mono handleFallback(final Throwable throwable, final Optional fallbackClientOpt, final String requestBody) { - LOG.warn("AI main call failed or retries exhausted, attempting to fallback...", throwable); + private Flux handleDirectFallbackStream(final Throwable throwable, + final Optional fallbackCtxOpt, final String requestBody, final boolean stream) { + LOG.warn("Main direct stream failed, attempting fallback...", throwable); - if (fallbackClientOpt.isEmpty()) { - return Mono.error(throwable); + if (fallbackCtxOpt.isEmpty()) { + return Flux.error(throwable); } - return SimpleModelFallbackStrategy.INSTANCE.fallback(fallbackClientOpt.get(), requestBody, throwable); + final FallbackContext ctx = fallbackCtxOpt.get(); + LOG.info("Using fallback OpenAiApi for direct stream"); + final ChatCompletionRequest fallbackRequest = OpenAiProtocolAdapter.toChatCompletionRequest( + requestBody, stream, ctx.config()); + return ctx.api().chatCompletionStream(fallbackRequest); } - /** - * Execute the AI call with retry and fallback. - * - * @param mainClient the main chat client - * @param fallbackClientOpt the optional fallback chat client - * @param requestBody the request body - * @return a Flux containing the ChatResponse - */ - public Flux executeStream(final ChatClient mainClient, final Optional fallbackClientOpt, final String requestBody) { - final Flux mainStream = doChatStream(mainClient, requestBody); + private Mono> handleDirectFallbackCall(final Throwable throwable, + final Optional fallbackCtxOpt, final String requestBody) { + LOG.warn("Main direct call failed, attempting fallback...", throwable); - return mainStream - .retryWhen(Retry.max(1) - .onRetryExhaustedThrow((retryBackoffSpec, retrySignal) -> { - LOG.warn("Retrying stream once failed. Attempts: {}. Triggering fallback.", - retrySignal.totalRetries(), retrySignal.failure()); - return new NonTransientAiException("Stream failed after 1 retry. Triggering fallback.", retrySignal.failure()); - })) - .onErrorResume(NonTransientAiException.class, - throwable -> handleFallbackStream(throwable, fallbackClientOpt, requestBody)); - } + if (fallbackCtxOpt.isEmpty()) { + return Mono.error(throwable); + } - protected Flux doChatStream(final ChatClient client, final String requestBody) { - return Flux.defer(() -> client.prompt().user(requestBody).stream().chatResponse()) + final FallbackContext ctx = fallbackCtxOpt.get(); + LOG.info("Using fallback OpenAiApi for direct call"); + final ChatCompletionRequest fallbackRequest = OpenAiProtocolAdapter.toChatCompletionRequest( + requestBody, false, ctx.config()); + return Mono.fromCallable(() -> ctx.api().chatCompletionEntity(fallbackRequest)) .subscribeOn(Schedulers.boundedElastic()); } - private Flux handleFallbackStream(final Throwable throwable, final Optional fallbackClientOpt, final String requestBody) { - LOG.warn("AI main stream failed or retries exhausted, attempting to fallback...", throwable); - - if (fallbackClientOpt.isEmpty()) { - return Flux.error(throwable); + /** + * Determine if the error is retryable. + * Retries transient network errors and retryable HTTP statuses (429, 5xx). + * Non-retryable: NonTransientAiException, client errors (400/401/403/404). + */ + private static boolean isRetryable(final Throwable throwable) { + if (throwable instanceof NonTransientAiException) { + return false; } + final WebClientResponseException webClientEx = UpstreamErrorLogger.findWebClientResponseException(throwable); + if (Objects.nonNull(webClientEx)) { + final int status = webClientEx.getStatusCode().value(); + return status == 429 || status >= 500; + } + return true; + } - return SimpleModelFallbackStrategy.INSTANCE.fallbackStream(fallbackClientOpt.get(), requestBody, throwable); + /** + * Container for fallback context: the OpenAiApi instance and the fallback config + * used to rebuild the request with fallback-specific model/temperature/maxTokens. + */ + public record FallbackContext(OpenAiApi api, AiCommonConfig config) { } -} \ No newline at end of file +} diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLogger.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLogger.java new file mode 100644 index 000000000000..06a08e0c378a --- /dev/null +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLogger.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.shenyu.plugin.ai.proxy.enhanced.service; + +import org.slf4j.Logger; +import org.springframework.web.reactive.function.client.WebClientResponseException; + +import java.util.Objects; + +/** + * Shared utility for logging upstream AI service errors with WebClientResponseException details. + */ +public final class UpstreamErrorLogger { + + private static final int MAX_BODY_LOG_LENGTH = 512; + + private UpstreamErrorLogger() { + } + + public static void logUpstreamError(final Logger log, final Throwable e, final String mode) { + if (Objects.isNull(e)) { + return; + } + final WebClientResponseException webClientEx = findWebClientResponseException(e); + if (Objects.nonNull(webClientEx)) { + log.error("[AiProxy] {} failed, status={}, upstreamBody={}", + mode, webClientEx.getStatusCode(), truncateBody(webClientEx.getResponseBodyAsString()), e); + } else { + log.error("[AiProxy] {} failed", mode, e); + } + } + + private static String truncateBody(final String body) { + if (Objects.isNull(body)) { + return "null"; + } + if (body.length() <= MAX_BODY_LOG_LENGTH) { + return body; + } + return body.substring(0, MAX_BODY_LOG_LENGTH) + "...(truncated, total " + body.length() + " chars)"; + } + + /** + * Find the {@link WebClientResponseException} in the exception chain. + * + * @param e the root exception + * @return the WebClientResponseException if found, null otherwise + */ + public static WebClientResponseException findWebClientResponseException(final Throwable e) { + Throwable current = e; + while (Objects.nonNull(current)) { + if (current instanceof WebClientResponseException ex) { + return ex; + } + current = current.getCause(); + } + return null; + } +} diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriber.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriber.java index 4f63f5b53bb8..bde74f48668e 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriber.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriber.java @@ -19,7 +19,7 @@ import org.apache.shenyu.common.dto.ProxyApiKeyData; import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.AiProxyApiKeyCache; -import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.ChatClientCache; +import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.OpenAiApiCache; import org.apache.shenyu.sync.data.api.AiProxyApiKeyDataSubscriber; import java.util.Objects; @@ -29,12 +29,6 @@ */ public final class CommonAiProxyApiKeyDataSubscriber implements AiProxyApiKeyDataSubscriber { - private final ChatClientCache chatClientCache; - - public CommonAiProxyApiKeyDataSubscriber(final ChatClientCache chatClientCache) { - this.chatClientCache = chatClientCache; - } - @Override public void onSubscribe(final ProxyApiKeyData data) { if (Objects.isNull(data) || Objects.isNull(data.getProxyApiKey())) { @@ -54,6 +48,6 @@ public void unSubscribe(final ProxyApiKeyData data) { @Override public void refresh() { AiProxyApiKeyCache.getInstance().refresh(); - chatClientCache.clearAll(); + OpenAiApiCache.getInstance().clearAll(); } } diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java index 30e84abc0809..e61f43e1689d 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java @@ -23,10 +23,7 @@ import org.apache.shenyu.common.dto.convert.rule.AiProxyHandle; import org.apache.shenyu.common.enums.AiModelProviderEnum; import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig; -import org.apache.shenyu.plugin.ai.common.spring.ai.AiModelFactory; -import org.apache.shenyu.plugin.ai.common.spring.ai.registry.AiModelFactoryRegistry; import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.AiProxyApiKeyCache; -import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.ChatClientCache; import org.apache.shenyu.plugin.ai.proxy.enhanced.handler.AiProxyPluginHandler; import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyConfigService; import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyExecutorService; @@ -43,21 +40,22 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoSettings; import org.mockito.quality.Strictness; -import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; import org.springframework.context.ApplicationContext; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.web.server.ServerWebExchange; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; import java.util.Optional; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -75,27 +73,12 @@ public class AiProxyPluginTest { private static final String REQUEST_BODY = "{\"messages\":[{\"role\":\"user\",\"content\":\"Hello\"}]}"; - @Mock - private AiModelFactoryRegistry registry; - @Mock private AiProxyConfigService configService; @Mock private AiProxyExecutorService executorService; - @Mock - private ChatClientCache chatClientCache; - - @Mock - private AiModelFactory modelFactory; - - @Mock - private ChatModel chatModel; - - @Mock - private ChatClient chatClient; - private AiProxyPluginHandler aiProxyPluginHandler; private AiProxyPlugin plugin; @@ -110,8 +93,8 @@ public class AiProxyPluginTest { @BeforeEach public void setUp() { - aiProxyPluginHandler = new AiProxyPluginHandler(chatClientCache); - plugin = new AiProxyPlugin(registry, configService, executorService, chatClientCache, aiProxyPluginHandler); + aiProxyPluginHandler = new AiProxyPluginHandler(); + plugin = new AiProxyPlugin(configService, executorService, aiProxyPluginHandler); selector = new SelectorData(); selector.setId(SELECTOR_ID); @@ -129,14 +112,6 @@ public void setUp() { when(applicationContext.getBean(ShenyuResult.class)).thenReturn(resultMock); SpringBeanUtils.getInstance().setApplicationContext(applicationContext); - // Common mock behavior for model creation and cache - when(registry.getFactory(any(AiModelProviderEnum.class))).thenReturn(modelFactory); - when(modelFactory.createAiModel(any(AiCommonConfig.class))).thenReturn(chatModel); - when(chatClientCache.computeIfAbsent(anyString(), any())).thenAnswer(invocation -> { - // Always return the mock ChatClient to avoid any UnsupportedOperationException - return chatClient; - }); - // mock static apiKeyCacheMockedStatic = mockStatic(AiProxyApiKeyCache.class); } @@ -150,13 +125,14 @@ public void tearDown() { private void setupSuccessMocks(final AiProxyHandle handle, final AiCommonConfig primaryConfig, final Optional fallbackConfig) { aiProxyPluginHandler.getSelectorCachedHandle().cachedHandle(CacheKeyUtils.INST.getKey(SELECTOR_ID, Constants.DEFAULT_RULE), handle); - final ChatResponse chatResponse = mock(ChatResponse.class); + final ChatCompletion chatCompletion = mock(ChatCompletion.class); + final ResponseEntity responseEntity = ResponseEntity.ok(chatCompletion); when(configService.resolvePrimaryConfig(handle)).thenReturn(primaryConfig); when(configService.resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY)).thenReturn(fallbackConfig); when(configService.resolveAdminFallbackConfig(primaryConfig, handle)).thenReturn(fallbackConfig); - when(configService.extractPrompt(anyString())).thenAnswer(invocation -> invocation.getArgument(0)); - when(executorService.execute(any(), any(), any())).thenReturn(Mono.just(chatResponse)); + when(executorService.executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class))).thenReturn(Mono.just(responseEntity)); + when(executorService.executeDirectStream(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class), any(Boolean.class))).thenReturn(Flux.empty()); } @Test @@ -164,6 +140,8 @@ public void testExecuteSuccess() { final AiProxyHandle handle = new AiProxyHandle(); final AiCommonConfig primaryConfig = new AiCommonConfig(); primaryConfig.setProvider(AiModelProviderEnum.OPEN_AI.getName()); + primaryConfig.setBaseUrl("https://api.openai.com"); + primaryConfig.setApiKey("test-key"); setupSuccessMocks(handle, primaryConfig, Optional.empty()); StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule)) @@ -173,7 +151,7 @@ public void testExecuteSuccess() { verify(configService).resolvePrimaryConfig(handle); verify(configService).resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY); verify(configService).resolveAdminFallbackConfig(primaryConfig, handle); - verify(executorService).execute(any(), any(), any()); + verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class)); } @Test @@ -181,8 +159,12 @@ public void testExecuteWithDynamicFallback() { final AiProxyHandle handle = new AiProxyHandle(); final AiCommonConfig primaryConfig = new AiCommonConfig(); primaryConfig.setProvider(AiModelProviderEnum.OPEN_AI.getName()); + primaryConfig.setBaseUrl("https://api.openai.com"); + primaryConfig.setApiKey("test-key"); final AiCommonConfig fallbackConfig = new AiCommonConfig(); fallbackConfig.setProvider(AiModelProviderEnum.DEEP_SEEK.getName()); + fallbackConfig.setBaseUrl("https://api.deepseek.com"); + fallbackConfig.setApiKey("fallback-key"); setupSuccessMocks(handle, primaryConfig, Optional.of(fallbackConfig)); when(configService.resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY)).thenReturn(Optional.of(fallbackConfig)); @@ -190,7 +172,7 @@ public void testExecuteWithDynamicFallback() { StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule)) .verifyComplete(); - verify(executorService).execute(any(ChatClient.class), any(Optional.class), any()); + verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class)); } @Test @@ -199,6 +181,7 @@ public void testExecuteWithValidProxyApiKey() { handle.setProxyEnabled("true"); final AiCommonConfig primaryConfig = new AiCommonConfig(); primaryConfig.setProvider(AiModelProviderEnum.OPEN_AI.getName()); + primaryConfig.setBaseUrl("https://api.openai.com"); primaryConfig.setApiKey("original-key"); // setup request with proxy key @@ -225,6 +208,8 @@ public void testExecuteWithInvalidProxyApiKey() { handle.setProxyEnabled("true"); final AiCommonConfig primaryConfig = new AiCommonConfig(); primaryConfig.setProvider(AiModelProviderEnum.OPEN_AI.getName()); + primaryConfig.setBaseUrl("https://api.openai.com"); + primaryConfig.setApiKey("test-key"); // setup request with proxy key exchange = MockServerWebExchange.from(MockServerHttpRequest.post("/test").header(Constants.X_API_KEY, "proxy-key-invalid").body(REQUEST_BODY)); @@ -252,8 +237,12 @@ public void testExecuteWithAdminFallback() { final AiProxyHandle handle = new AiProxyHandle(); final AiCommonConfig primaryConfig = new AiCommonConfig(); primaryConfig.setProvider(AiModelProviderEnum.OPEN_AI.getName()); + primaryConfig.setBaseUrl("https://api.openai.com"); + primaryConfig.setApiKey("test-key"); final AiCommonConfig fallbackConfig = new AiCommonConfig(); fallbackConfig.setProvider(AiModelProviderEnum.DEEP_SEEK.getName()); + fallbackConfig.setBaseUrl("https://api.deepseek.com"); + fallbackConfig.setApiKey("fallback-key"); setupSuccessMocks(handle, primaryConfig, Optional.empty()); when(configService.resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY)).thenReturn(Optional.empty()); @@ -262,61 +251,57 @@ public void testExecuteWithAdminFallback() { StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule)) .verifyComplete(); - verify(executorService).execute(any(ChatClient.class), any(Optional.class), any()); + verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class)); } @Test - public void testCacheIsUsedForAdminFallbackClient() { + public void testAdminFallbackExecution() { final AiProxyHandle handle = new AiProxyHandle(); final AiCommonConfig primaryConfig = new AiCommonConfig(); primaryConfig.setProvider(AiModelProviderEnum.OPEN_AI.getName()); + primaryConfig.setBaseUrl("https://api.openai.com"); + primaryConfig.setApiKey("test-key"); final AiCommonConfig fallbackConfig = new AiCommonConfig(); fallbackConfig.setProvider(AiModelProviderEnum.DEEP_SEEK.getName()); + fallbackConfig.setBaseUrl("https://api.deepseek.com"); + fallbackConfig.setApiKey("fallback-key"); // Cache the handle for the test aiProxyPluginHandler.getSelectorCachedHandle().cachedHandle(CacheKeyUtils.INST.getKey(SELECTOR_ID, Constants.DEFAULT_RULE), handle); - final ChatResponse chatResponse = mock(ChatResponse.class); + final ChatCompletion chatCompletion = mock(ChatCompletion.class); + final ResponseEntity responseEntity = ResponseEntity.ok(chatCompletion); // Setup all necessary mocks when(configService.resolvePrimaryConfig(handle)).thenReturn(primaryConfig); when(configService.resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY)).thenReturn(Optional.empty()); when(configService.resolveAdminFallbackConfig(primaryConfig, handle)).thenReturn(Optional.of(fallbackConfig)); - when(configService.extractPrompt(anyString())).thenAnswer(invocation -> invocation.getArgument(0)); - when(executorService.execute(any(), any(), any())).thenReturn(Mono.just(chatResponse)); + when(executorService.executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class))).thenReturn(Mono.just(responseEntity)); - // Execute the test - focus on successful execution rather than cache verification + // Execute the test - verify admin fallback path is exercised StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule)) .verifyComplete(); - + // Verify that the configuration methods were called correctly verify(configService).resolvePrimaryConfig(handle); verify(configService).resolveAdminFallbackConfig(primaryConfig, handle); - verify(executorService).execute(any(), any(), any()); + verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class)); } @Test - public void testCreateChatModelThrowsException() { + public void testCreateOpenAiApiThrowsException() { final AiProxyHandle handle = new AiProxyHandle(); final AiCommonConfig primaryConfig = new AiCommonConfig(); - primaryConfig.setProvider("InvalidProvider"); - - // Cache the handle for the test + primaryConfig.setBaseUrl(null); + aiProxyPluginHandler.getSelectorCachedHandle().cachedHandle(CacheKeyUtils.INST.getKey(SELECTOR_ID, Constants.DEFAULT_RULE), handle); - - // Mock config service to return the invalid config + when(configService.resolvePrimaryConfig(handle)).thenReturn(primaryConfig); when(configService.resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY)).thenReturn(Optional.empty()); when(configService.resolveAdminFallbackConfig(primaryConfig, handle)).thenReturn(Optional.empty()); - when(configService.extractPrompt(anyString())).thenAnswer(invocation -> invocation.getArgument(0)); - - // Mock registry to return null factory for invalid provider - this should cause IllegalArgumentException - when(registry.getFactory(any())).thenReturn(null); - - // Mock executorService to return a proper Mono to avoid NullPointerException - when(executorService.execute(any(), any(), any())).thenReturn(Mono.error(new IllegalArgumentException("AI model factory not found"))); StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule)) - .expectError(IllegalArgumentException.class) + .expectErrorMatches(e -> e instanceof IllegalArgumentException + && e.getMessage().contains("baseUrl must not be empty")) .verify(); } @@ -325,10 +310,12 @@ public void testExecutorServiceError() { final AiProxyHandle handle = new AiProxyHandle(); final AiCommonConfig primaryConfig = new AiCommonConfig(); primaryConfig.setProvider(AiModelProviderEnum.OPEN_AI.getName()); + primaryConfig.setBaseUrl("https://api.openai.com"); + primaryConfig.setApiKey("test-key"); final RuntimeException exception = new RuntimeException("AI execution failed"); setupSuccessMocks(handle, primaryConfig, Optional.empty()); - when(executorService.execute(any(), any(), any())).thenReturn(Mono.error(exception)); + when(executorService.executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class))).thenReturn(Mono.error(exception)); StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule)) .expectErrorMatches(exception::equals) @@ -337,6 +324,6 @@ public void testExecutorServiceError() { verify(configService).resolvePrimaryConfig(handle); verify(configService).resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY); verify(configService).resolveAdminFallbackConfig(primaryConfig, handle); - verify(executorService).execute(any(), any(), any()); + verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class)); } } \ No newline at end of file diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCacheTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCacheTest.java new file mode 100644 index 000000000000..11bed2897ddb --- /dev/null +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCacheTest.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.shenyu.plugin.ai.proxy.enhanced.cache; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.ai.openai.api.OpenAiApi; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; + +class OpenAiApiCacheTest { + + private OpenAiApiCache cache; + + @BeforeEach + void setUp() { + cache = new OpenAiApiCache(); + cache.clearAll(); + } + + @Test + void testComputeIfAbsentReusesExisting() { + final OpenAiApi api = createTestApi("https://api.openai.com", "key1"); + final OpenAiApi cached = cache.computeIfAbsent("key1", () -> api); + final OpenAiApi reused = cache.computeIfAbsent("key1", () -> createTestApi("https://api.openai.com", "key1-other")); + + assertSame(cached, reused); + assertEquals(1, cache.size()); + } + + @Test + void testComputeIfAbsentCreatesNewForDifferentKey() { + final OpenAiApi api1 = cache.computeIfAbsent("key1", () -> createTestApi("https://api.openai.com", "key1")); + final OpenAiApi api2 = cache.computeIfAbsent("key2", () -> createTestApi("https://api.deepseek.com", "key2")); + + // Different keys produce different instances + assertEquals(2, cache.size()); + } + + @Test + void testRemoveBySelectorId() { + cache.computeIfAbsent("sel1|main_123", () -> createTestApi("https://a.com", "k1")); + cache.computeIfAbsent("sel1|adminFallback_456", () -> createTestApi("https://b.com", "k2")); + cache.computeIfAbsent("sel2|main_789", () -> createTestApi("https://c.com", "k3")); + + assertEquals(3, cache.size()); + + cache.remove("sel1"); + + assertEquals(1, cache.size()); + } + + @Test + void testRemoveWithNullSelectorIdDoesNotThrow() { + cache.computeIfAbsent("key1", () -> createTestApi("https://a.com", "k1")); + cache.remove(null); + assertEquals(1, cache.size()); + } + + @Test + void testClearAll() { + cache.computeIfAbsent("key1", () -> createTestApi("https://a.com", "k1")); + cache.computeIfAbsent("key2", () -> createTestApi("https://b.com", "k2")); + + assertEquals(2, cache.size()); + + cache.clearAll(); + + assertEquals(0, cache.size()); + } + + private OpenAiApi createTestApi(final String baseUrl, final String apiKey) { + return OpenAiApi.builder().baseUrl(baseUrl).apiKey(apiKey).build(); + } +} diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandlerTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandlerTest.java new file mode 100644 index 000000000000..5aa2a7a5edf2 --- /dev/null +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandlerTest.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.shenyu.plugin.ai.proxy.enhanced.handler; + +import org.apache.shenyu.common.constant.Constants; +import org.apache.shenyu.common.dto.SelectorData; +import org.apache.shenyu.common.dto.convert.rule.AiProxyHandle; +import org.apache.shenyu.common.enums.PluginEnum; +import org.apache.shenyu.plugin.base.utils.CacheKeyUtils; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; + +class AiProxyPluginHandlerTest { + + private AiProxyPluginHandler handler; + + @BeforeEach + void setUp() { + handler = new AiProxyPluginHandler(); + } + + @Test + void testPluginNamed() { + assertEquals(PluginEnum.AI_PROXY.getName(), handler.pluginNamed()); + } + + @Test + void testHandlerSelectorCachesHandle() { + SelectorData selector = new SelectorData(); + selector.setId("sel-1"); + selector.setHandle("{\"provider\":\"open_ai\",\"baseUrl\":\"https://api.openai.com\",\"apiKey\":\"sk-test\",\"model\":\"gpt-4\"}"); + + handler.handlerSelector(selector); + + String key = CacheKeyUtils.INST.getKey("sel-1", Constants.DEFAULT_RULE); + AiProxyHandle cached = handler.getSelectorCachedHandle().obtainHandle(key); + assertNotNull(cached); + assertEquals("open_ai", cached.getProvider()); + assertEquals("https://api.openai.com", cached.getBaseUrl()); + assertEquals("sk-test", cached.getApiKey()); + assertEquals("gpt-4", cached.getModel()); + } + + @Test + void testHandlerSelectorWithNullHandle() { + SelectorData selector = new SelectorData(); + selector.setId("sel-2"); + selector.setHandle(null); + + handler.handlerSelector(selector); + + String key = CacheKeyUtils.INST.getKey("sel-2", Constants.DEFAULT_RULE); + assertNull(handler.getSelectorCachedHandle().obtainHandle(key)); + } + + @Test + void testHandlerSelectorWithEmptyHandle() { + SelectorData selector = new SelectorData(); + selector.setId("sel-3"); + selector.setHandle(""); + + handler.handlerSelector(selector); + + String key = CacheKeyUtils.INST.getKey("sel-3", Constants.DEFAULT_RULE); + assertNull(handler.getSelectorCachedHandle().obtainHandle(key)); + } + + @Test + void testHandlerSelectorWithFallbackNormalize() { + SelectorData selector = new SelectorData(); + selector.setId("sel-4"); + selector.setHandle("{\"provider\":\"open_ai\",\"fallbackEnabled\":\"true\",\"fallbackModel\":\"fallback-model\"}"); + + handler.handlerSelector(selector); + + String key = CacheKeyUtils.INST.getKey("sel-4", Constants.DEFAULT_RULE); + AiProxyHandle cached = handler.getSelectorCachedHandle().obtainHandle(key); + assertNotNull(cached); + assertNotNull(cached.getFallbackConfig()); + assertEquals("fallback-model", cached.getFallbackConfig().getModel()); + } + + @Test + void testRemoveSelector() { + SelectorData selector = new SelectorData(); + selector.setId("sel-5"); + selector.setHandle("{\"provider\":\"open_ai\",\"model\":\"gpt-4\"}"); + + handler.handlerSelector(selector); + + String key = CacheKeyUtils.INST.getKey("sel-5", Constants.DEFAULT_RULE); + assertNotNull(handler.getSelectorCachedHandle().obtainHandle(key)); + + handler.removeSelector(selector); + assertNull(handler.getSelectorCachedHandle().obtainHandle(key)); + } + + @Test + void testHandlerPluginDoesNothing() { + handler.handlerPlugin(null); + } +} diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorServiceTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorServiceTest.java index 7bd2bbbdac0d..38597d0dbedf 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorServiceTest.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorServiceTest.java @@ -17,24 +17,24 @@ package org.apache.shenyu.plugin.ai.proxy.enhanced.service; +import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; -import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; import org.springframework.ai.retry.NonTransientAiException; -import org.springframework.web.client.RestClientException; +import org.springframework.http.ResponseEntity; +import reactor.core.publisher.Flux; import reactor.test.StepVerifier; import java.util.Optional; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -42,102 +42,106 @@ @ExtendWith(MockitoExtension.class) public class AiProxyExecutorServiceTest { - @Mock - private ChatModel mainChatModel; - - @Mock - private ChatModel fallbackChatModel; - - private ChatClient mainClient; - - private ChatClient fallbackClient; + private static final String REQUEST_BODY = "{\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}]}"; private AiProxyExecutorService executorService; @BeforeEach void setUp() { - executorService = spy(new AiProxyExecutorService()); - mainClient = ChatClient.create(mainChatModel); - fallbackClient = ChatClient.create(fallbackChatModel); + executorService = new AiProxyExecutorService(); } @Test - void testExecuteSuccessOnFirstAttempt() { - final ChatResponse successResponse = mock(ChatResponse.class); - when(mainChatModel.call(any(Prompt.class))).thenReturn(successResponse); - - StepVerifier.create(executorService.execute(mainClient, Optional.empty(), "request")) - .expectNext(successResponse) + void testExecuteDirectStreamSuccess() { + final OpenAiApi mainApi = mock(OpenAiApi.class); + final ChatCompletionChunk chunk = mock(ChatCompletionChunk.class); + final ChatCompletionRequest request = mock(ChatCompletionRequest.class); + when(mainApi.chatCompletionStream(request)).thenReturn(Flux.just(chunk)); + + StepVerifier.create(executorService.executeDirectStream(mainApi, Optional.empty(), request, REQUEST_BODY, true)) + .expectNext(chunk) .verifyComplete(); - verify(mainChatModel, times(1)).call(any(Prompt.class)); + verify(mainApi, times(1)).chatCompletionStream(request); } @Test - void testExecuteRetryOnRestClientExceptionAndSucceed() { - final ChatResponse successResponse = mock(ChatResponse.class); - when(mainChatModel.call(any(Prompt.class))) - .thenThrow(new RestClientException("transient error")) - .thenReturn(successResponse); - - StepVerifier.create(executorService.execute(mainClient, Optional.empty(), "request")) - .expectNext(successResponse) - .verifyComplete(); + void testExecuteDirectStreamErrorWithNoFallback() { + final OpenAiApi mainApi = mock(OpenAiApi.class); + final ChatCompletionRequest request = mock(ChatCompletionRequest.class); + when(mainApi.chatCompletionStream(request)).thenAnswer(inv -> Flux.error(new RuntimeException("upstream error"))); - verify(mainChatModel, times(2)).call(any(Prompt.class)); + StepVerifier.create(executorService.executeDirectStream(mainApi, Optional.empty(), request, REQUEST_BODY, true)) + .expectError(NonTransientAiException.class) + .verify(); } @Test - void testExecuteRetryExhaustedThenFallback() { - final ChatResponse fallbackResponse = mock(ChatResponse.class); - when(mainChatModel.call(any(Prompt.class))).thenThrow(new RestClientException("transient error")); - when(fallbackChatModel.call(any(Prompt.class))).thenReturn(fallbackResponse); + void testExecuteDirectStreamFallbackSuccess() { + final OpenAiApi mainApi = mock(OpenAiApi.class); + final OpenAiApi fallbackApi = mock(OpenAiApi.class); + final ChatCompletionRequest request = mock(ChatCompletionRequest.class); + final ChatCompletionChunk fallbackChunk = mock(ChatCompletionChunk.class); - StepVerifier.create(executorService.execute(mainClient, Optional.of(fallbackClient), "request")) - .expectNext(fallbackResponse) + when(mainApi.chatCompletionStream(request)).thenAnswer(inv -> Flux.error(new RuntimeException("upstream error"))); + when(fallbackApi.chatCompletionStream(any(ChatCompletionRequest.class))).thenReturn(Flux.just(fallbackChunk)); + + final AiCommonConfig fallbackConfig = new AiCommonConfig(); + fallbackConfig.setModel("fallback-model"); + final AiProxyExecutorService.FallbackContext ctx = new AiProxyExecutorService.FallbackContext(fallbackApi, fallbackConfig); + + StepVerifier.create(executorService.executeDirectStream(mainApi, Optional.of(ctx), request, REQUEST_BODY, true)) + .expectNext(fallbackChunk) .verifyComplete(); - verify(mainChatModel, times(4)).call(any(Prompt.class)); - verify(fallbackChatModel, times(1)).call(any(Prompt.class)); + verify(fallbackApi, times(1)).chatCompletionStream(any(ChatCompletionRequest.class)); } @Test - void testExecuteNonTransientExceptionTriggersFallbackDirectly() { - final ChatResponse fallbackResponse = mock(ChatResponse.class); - when(mainChatModel.call(any(Prompt.class))).thenThrow(new NonTransientAiException("non-transient error")); - when(fallbackChatModel.call(any(Prompt.class))).thenReturn(fallbackResponse); - - StepVerifier.create(executorService.execute(mainClient, Optional.of(fallbackClient), "request")) - .expectNext(fallbackResponse) + void testExecuteDirectCallSuccess() { + final OpenAiApi mainApi = mock(OpenAiApi.class); + final ChatCompletionRequest request = mock(ChatCompletionRequest.class); + final ChatCompletion completion = mock(ChatCompletion.class); + final ResponseEntity responseEntity = ResponseEntity.ok(completion); + when(mainApi.chatCompletionEntity(request)).thenReturn(responseEntity); + + StepVerifier.create(executorService.executeDirectCall(mainApi, Optional.empty(), request, REQUEST_BODY)) + .expectNext(responseEntity) .verifyComplete(); - verify(mainChatModel, times(1)).call(any(Prompt.class)); - verify(fallbackChatModel, times(1)).call(any(Prompt.class)); + verify(mainApi, times(1)).chatCompletionEntity(request); } @Test - void testExecuteFallbackFails() { - final RestClientException fallbackException = new RestClientException("fallback failed"); - when(mainChatModel.call(any(Prompt.class))).thenThrow(new NonTransientAiException("non-transient error")); - when(fallbackChatModel.call(any(Prompt.class))).thenThrow(fallbackException); + void testExecuteDirectCallErrorWithNoFallback() { + final OpenAiApi mainApi = mock(OpenAiApi.class); + final ChatCompletionRequest request = mock(ChatCompletionRequest.class); + when(mainApi.chatCompletionEntity(request)).thenThrow(new RuntimeException("upstream error")); - StepVerifier.create(executorService.execute(mainClient, Optional.of(fallbackClient), "request")) - .expectErrorMatches(e -> e == fallbackException) + StepVerifier.create(executorService.executeDirectCall(mainApi, Optional.empty(), request, REQUEST_BODY)) + .expectError(NonTransientAiException.class) .verify(); - - verify(mainChatModel, times(1)).call(any(Prompt.class)); - verify(fallbackChatModel, times(1)).call(any(Prompt.class)); } @Test - void testExecuteNoFallbackProvided() { - final NonTransientAiException mainException = new NonTransientAiException("non-transient error"); - when(mainChatModel.call(any(Prompt.class))).thenThrow(mainException); + void testExecuteDirectCallFallbackSuccess() { + final OpenAiApi mainApi = mock(OpenAiApi.class); + final OpenAiApi fallbackApi = mock(OpenAiApi.class); + final ChatCompletionRequest request = mock(ChatCompletionRequest.class); + final ChatCompletion fallbackCompletion = mock(ChatCompletion.class); + final ResponseEntity fallbackResponse = ResponseEntity.ok(fallbackCompletion); - StepVerifier.create(executorService.execute(mainClient, Optional.empty(), "request")) - .expectErrorMatches(e -> e == mainException) - .verify(); + when(mainApi.chatCompletionEntity(request)).thenThrow(new RuntimeException("upstream error")); + when(fallbackApi.chatCompletionEntity(any(ChatCompletionRequest.class))).thenReturn(fallbackResponse); + + final AiCommonConfig fallbackConfig = new AiCommonConfig(); + fallbackConfig.setModel("fallback-model"); + final AiProxyExecutorService.FallbackContext ctx = new AiProxyExecutorService.FallbackContext(fallbackApi, fallbackConfig); + + StepVerifier.create(executorService.executeDirectCall(mainApi, Optional.of(ctx), request, REQUEST_BODY)) + .expectNext(fallbackResponse) + .verifyComplete(); - verify(mainChatModel, times(1)).call(any(Prompt.class)); + verify(fallbackApi, times(1)).chatCompletionEntity(any(ChatCompletionRequest.class)); } } diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLoggerTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLoggerTest.java new file mode 100644 index 000000000000..85193687d1df --- /dev/null +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLoggerTest.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.shenyu.plugin.ai.proxy.enhanced.service; + +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.springframework.http.HttpStatus; +import org.springframework.web.reactive.function.client.WebClientResponseException; + +import java.nio.charset.StandardCharsets; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +class UpstreamErrorLoggerTest { + + @Test + void testLogWithWebClientResponseException() { + Logger log = mock(Logger.class); + WebClientResponseException ex = WebClientResponseException.create( + 429, "Too Many Requests", null, "rate limited".getBytes(StandardCharsets.UTF_8), StandardCharsets.UTF_8); + + UpstreamErrorLogger.logUpstreamError(log, ex, "stream"); + + verify(log).error("[AiProxy] {} failed, status={}, upstreamBody={}", + "stream", HttpStatus.TOO_MANY_REQUESTS, "rate limited", ex); + } + + @Test + void testLogWithWrappedWebClientResponseException() { + Logger log = mock(Logger.class); + WebClientResponseException inner = WebClientResponseException.create( + 500, "Internal Server Error", null, "server error".getBytes(StandardCharsets.UTF_8), StandardCharsets.UTF_8); + RuntimeException wrapped = new RuntimeException("call failed", inner); + + UpstreamErrorLogger.logUpstreamError(log, wrapped, "non-stream"); + + verify(log).error("[AiProxy] {} failed, status={}, upstreamBody={}", + "non-stream", HttpStatus.INTERNAL_SERVER_ERROR, "server error", wrapped); + } + + @Test + void testLogWithGenericException() { + Logger log = mock(Logger.class); + RuntimeException ex = new RuntimeException("generic error"); + + UpstreamErrorLogger.logUpstreamError(log, ex, "stream"); + + verify(log).error("[AiProxy] {} failed", "stream", ex); + } + + @Test + void testLogWithNullExceptionDoesNotThrow() { + Logger log = mock(Logger.class); + assertDoesNotThrow(() -> UpstreamErrorLogger.logUpstreamError(log, null, "stream")); + } + + @Test + void testLogWithLongBodyIsTruncated() { + Logger log = mock(Logger.class); + final String longBody = "x".repeat(1024); + WebClientResponseException ex = WebClientResponseException.create( + 500, "Internal Server Error", null, longBody.getBytes(StandardCharsets.UTF_8), StandardCharsets.UTF_8); + + UpstreamErrorLogger.logUpstreamError(log, ex, "stream"); + + final String expectedTruncated = longBody.substring(0, 512) + + "...(truncated, total " + longBody.length() + " chars)"; + verify(log).error("[AiProxy] {} failed, status={}, upstreamBody={}", + "stream", HttpStatus.INTERNAL_SERVER_ERROR, expectedTruncated, ex); + } +} diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriberTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriberTest.java index c863ba20f48a..4827c5337240 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriberTest.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriberTest.java @@ -19,17 +19,12 @@ import org.apache.shenyu.common.dto.ProxyApiKeyData; import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.AiProxyApiKeyCache; -import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.ChatClientCache; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -import static org.mockito.Mockito.mock; - class CommonAiProxyApiKeyDataSubscriberTest { - private final ChatClientCache chatClientCache = mock(ChatClientCache.class); - @AfterEach void cleanup() { AiProxyApiKeyCache.getInstance().refresh(); @@ -37,7 +32,7 @@ void cleanup() { @Test void testOnSubscribeCachesEnabled() { - CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache); + CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(); ProxyApiKeyData data = ProxyApiKeyData.builder() .selectorId("test-selector") .proxyApiKey("proxy-sub") @@ -51,7 +46,7 @@ void testOnSubscribeCachesEnabled() { @Test void testOnSubscribeIgnoresDisabled() { - CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache); + CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(); ProxyApiKeyData data = ProxyApiKeyData.builder() .selectorId("test-selector") .proxyApiKey("proxy-disabled") @@ -65,7 +60,7 @@ void testOnSubscribeIgnoresDisabled() { @Test void testUnSubscribeRemoves() { - CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache); + CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(); ProxyApiKeyData data = ProxyApiKeyData.builder() .selectorId("test-selector") .proxyApiKey("proxy-unsub") @@ -81,7 +76,7 @@ void testUnSubscribeRemoves() { @Test void testRefreshClearsCache() { - CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache); + CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(); ProxyApiKeyData data = ProxyApiKeyData.builder() .selectorId("test-selector") .proxyApiKey("proxy-refresh") @@ -98,14 +93,14 @@ void testRefreshClearsCache() { @Test void testOnSubscribeNullDataNoThrow() { - CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache); + CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(); subscriber.onSubscribe(null); Assertions.assertEquals(0, AiProxyApiKeyCache.getInstance().size()); } @Test void testOnSubscribeNullKeyIgnored() { - CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache); + CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(); ProxyApiKeyData data = ProxyApiKeyData.builder() .selectorId("test-selector") .proxyApiKey(null) @@ -119,14 +114,14 @@ void testOnSubscribeNullKeyIgnored() { @Test void testUnSubscribeNullDataNoThrow() { - CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache); + CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(); subscriber.unSubscribe(null); Assertions.assertEquals(0, AiProxyApiKeyCache.getInstance().size()); } @Test void testUnSubscribeNullKeyNoThrow() { - CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache); + CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(); ProxyApiKeyData data = ProxyApiKeyData.builder() .selectorId("test-selector") .proxyApiKey(null) @@ -140,7 +135,7 @@ void testUnSubscribeNullKeyNoThrow() { @Test void testDuplicateSubscribeOverrides() { - CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache); + CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(); ProxyApiKeyData v1 = ProxyApiKeyData.builder() .selectorId("test-selector") .proxyApiKey("dup-key") @@ -161,48 +156,4 @@ void testDuplicateSubscribeOverrides() { subscriber.onSubscribe(v2); Assertions.assertEquals("real-2", AiProxyApiKeyCache.getInstance().getRealApiKey("test-selector", "dup-key")); } - - @Test - void testSubscribeDisabledDoesNotOverrideExisting() { - CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache); - ProxyApiKeyData enabled = ProxyApiKeyData.builder() - .selectorId("test-selector") - .proxyApiKey("keep-key") - .realApiKey("real-keep") - .enabled(Boolean.TRUE) - .namespaceId("default") - .build(); - subscriber.onSubscribe(enabled); - Assertions.assertEquals("real-keep", AiProxyApiKeyCache.getInstance().getRealApiKey("test-selector", "keep-key")); - - ProxyApiKeyData disabled = ProxyApiKeyData.builder() - .selectorId("test-selector") - .proxyApiKey("keep-key") - .realApiKey("real-new") - .enabled(Boolean.FALSE) - .namespaceId("default") - .build(); - subscriber.onSubscribe(disabled); - // disabled subscribe should not override existing cached mapping - Assertions.assertEquals("real-keep", AiProxyApiKeyCache.getInstance().getRealApiKey("test-selector", "keep-key")); - } - - @Test - void testVeryLongProxyKey() { - CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache); - StringBuilder sb = new StringBuilder(); - for (int i = 0; i < 300; i++) { - sb.append('A' + (i % 26)); - } - String longKey = sb.toString(); - ProxyApiKeyData data = ProxyApiKeyData.builder() - .selectorId("test-selector") - .proxyApiKey(longKey) - .realApiKey("real-long") - .enabled(Boolean.TRUE) - .namespaceId("default") - .build(); - subscriber.onSubscribe(data); - Assertions.assertEquals("real-long", AiProxyApiKeyCache.getInstance().getRealApiKey("test-selector", longKey)); - } -} \ No newline at end of file +} diff --git a/shenyu-spring-boot-starter/shenyu-spring-boot-starter-plugin/shenyu-spring-boot-starter-plugin-ai-proxy/src/main/java/org/apache/shenyu/springboot/starter/plugin/ai/proxy/AiProxyPluginConfiguration.java b/shenyu-spring-boot-starter/shenyu-spring-boot-starter-plugin/shenyu-spring-boot-starter-plugin-ai-proxy/src/main/java/org/apache/shenyu/springboot/starter/plugin/ai/proxy/AiProxyPluginConfiguration.java index d017cfabd8fc..2bcb70aa1dae 100644 --- a/shenyu-spring-boot-starter/shenyu-spring-boot-starter-plugin/shenyu-spring-boot-starter-plugin-ai-proxy/src/main/java/org/apache/shenyu/springboot/starter/plugin/ai/proxy/AiProxyPluginConfiguration.java +++ b/shenyu-spring-boot-starter/shenyu-spring-boot-starter-plugin/shenyu-spring-boot-starter-plugin-ai-proxy/src/main/java/org/apache/shenyu/springboot/starter/plugin/ai/proxy/AiProxyPluginConfiguration.java @@ -22,7 +22,6 @@ import org.apache.shenyu.plugin.ai.common.spring.ai.factory.OpenAiModelFactory; import org.apache.shenyu.plugin.ai.common.spring.ai.registry.AiModelFactoryRegistry; import org.apache.shenyu.plugin.ai.proxy.enhanced.AiProxyPlugin; -import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.ChatClientCache; import org.apache.shenyu.plugin.ai.proxy.enhanced.handler.AiProxyPluginHandler; import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyConfigService; import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyExecutorService; @@ -48,42 +47,30 @@ public class AiProxyPluginConfiguration { /** * Ai proxy plugin. * - * @param aiModelFactoryRegistry the aiModelFactoryRegistry * @param aiProxyConfigService the aiProxyConfigService * @param aiProxyExecutorService the aiProxyExecutorService - * @param chatClientCache the chatClientCache * @param aiProxyPluginHandler the aiProxyPluginHandler * @return the shenyu plugin */ @Bean public ShenyuPlugin aiProxyPlugin( - final AiModelFactoryRegistry aiModelFactoryRegistry, final AiProxyConfigService aiProxyConfigService, final AiProxyExecutorService aiProxyExecutorService, - final ChatClientCache chatClientCache, final AiProxyPluginHandler aiProxyPluginHandler) { return new AiProxyPlugin( - aiModelFactoryRegistry, aiProxyConfigService, aiProxyExecutorService, - chatClientCache, aiProxyPluginHandler); } /** * Ai proxy plugin handler. * - * @param chatClientCache the chatClientCache * @return the shenyu plugin handler */ @Bean - public AiProxyPluginHandler aiProxyPluginHandler(final ChatClientCache chatClientCache) { - return new AiProxyPluginHandler(chatClientCache); - } - - @Bean - public ChatClientCache chatClientCache() { - return new ChatClientCache(); + public AiProxyPluginHandler aiProxyPluginHandler() { + return new AiProxyPluginHandler(); } @Bean @@ -131,11 +118,10 @@ public DeepSeekModelFactory deepSeekModelFactory() { /** * Ai proxy api key data subscriber. * - * @param chatClientCache the chatClientCache * @return the subscriber */ @Bean - public AiProxyApiKeyDataSubscriber aiProxyApiKeyDataSubscriber(final ChatClientCache chatClientCache) { - return new CommonAiProxyApiKeyDataSubscriber(chatClientCache); + public AiProxyApiKeyDataSubscriber aiProxyApiKeyDataSubscriber() { + return new CommonAiProxyApiKeyDataSubscriber(); } }