From 03b40be500bcf07f72d5637f96165cf4d4b7c23d Mon Sep 17 00:00:00 2001 From: eye-gu <734164350@qq.com> Date: Wed, 13 May 2026 22:29:31 +0800 Subject: [PATCH 1/6] refactor: use OpenAiApi directly to return OpenAI Chat Completions format --- .../ai/common/config/AiCommonConfig.java | 2 +- .../protocol/OpenAiProtocolAdapter.java | 118 +++++++++ .../ai/common/strategy/FallbackStrategy.java | 49 ---- .../strategy/SimpleModelFallbackStrategy.java | 68 ----- .../protocol/OpenAiProtocolAdapterTest.java | 238 ++++++++++++++++++ .../ai/proxy/enhanced/AiProxyPlugin.java | 163 ++++-------- .../proxy/enhanced/cache/ChatClientCache.java | 143 ----------- .../handler/AiProxyPluginHandler.java | 13 - .../service/AiProxyExecutorService.java | 111 ++++---- .../enhanced/service/UpstreamErrorLogger.java | 53 ++++ .../CommonAiProxyApiKeyDataSubscriber.java | 8 - .../ai/proxy/enhanced/AiProxyPluginTest.java | 105 ++++---- .../service/AiProxyExecutorServiceTest.java | 134 +++++----- ...CommonAiProxyApiKeyDataSubscriberTest.java | 69 +---- .../ai/proxy/AiProxyPluginConfiguration.java | 22 +- 15 files changed, 642 insertions(+), 654 deletions(-) create mode 100644 shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapter.java delete mode 100644 shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/FallbackStrategy.java delete mode 100644 shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/SimpleModelFallbackStrategy.java create mode 100644 shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/test/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapterTest.java delete mode 100644 shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/ChatClientCache.java create mode 100644 shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLogger.java diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/config/AiCommonConfig.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/config/AiCommonConfig.java index c3d30221efe0..71a62e15636f 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/config/AiCommonConfig.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/config/AiCommonConfig.java @@ -50,7 +50,7 @@ public class AiCommonConfig { /** * temperature. */ - private Double temperature = 0.8; + private Double temperature; /** * max tokens. diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapter.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapter.java new file mode 100644 index 000000000000..ce44a5d3178b --- /dev/null +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapter.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.shenyu.plugin.ai.common.protocol; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import org.apache.shenyu.common.utils.JsonUtils; +import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; + +import java.util.Objects; + +/** + * Adapts between OpenAI Chat Completions wire format and internal representations. + */ +public final class OpenAiProtocolAdapter { + + private OpenAiProtocolAdapter() { + } + + /** + * Resolve the stream flag: client request body takes priority, + * falls back to the provided default value. + * + * @param requestBody the raw JSON request body + * @param fallbackStream the default stream value from admin config + * @return true if streaming, false otherwise + */ + public static boolean resolveStream(final String requestBody, final Boolean fallbackStream) { + if (Objects.isNull(requestBody) || requestBody.isEmpty()) { + return Boolean.TRUE.equals(fallbackStream); + } + final JsonNode root = JsonUtils.toJsonNode(requestBody); + if (Objects.isNull(root)) { + return Boolean.TRUE.equals(fallbackStream); + } + if (root.hasNonNull("stream")) { + return root.get("stream").asBoolean(); + } + return Boolean.TRUE.equals(fallbackStream); + } + + /** + * Parse raw request body directly into ChatCompletionRequest, preserving ALL fields + * including reasoning_content in assistant messages. + * + *

Spring AI's createRequest() loses reasoning_content, refusal, and annotations + * when reconstructing ChatCompletionMessage from AssistantMessage. + * This method avoids that loss by deserializing the raw JSON directly. + * + *

Also converts max_completion_tokens to max_tokens for broader API compatibility. + * + * @param requestBody the raw JSON request body in OpenAI Chat Completions format + * @param stream whether this is a streaming request (sets the stream field) + * @return a ChatCompletionRequest with all fields preserved from the original request + */ + public static ChatCompletionRequest toChatCompletionRequest(final String requestBody, final boolean stream) { + return toChatCompletionRequest(requestBody, stream, null); + } + + /** + * Parse raw request body directly into ChatCompletionRequest, preserving ALL fields. + * For model, temperature, max_tokens: client request takes priority; + * if missing, falls back to the corresponding field in fallbackConfig. + * + * @param requestBody the raw JSON request body in OpenAI Chat Completions format + * @param stream whether this is a streaming request + * @param fallbackConfig the admin config used as fallback when client omits fields + * @return a ChatCompletionRequest with all fields preserved + */ + public static ChatCompletionRequest toChatCompletionRequest(final String requestBody, + final boolean stream, final AiCommonConfig fallbackConfig) { + if (Objects.isNull(requestBody) || requestBody.isEmpty()) { + throw new IllegalArgumentException("Request body must not be empty"); + } + final JsonNode root = JsonUtils.toJsonNode(requestBody); + if (Objects.isNull(root) || !root.isObject()) { + throw new IllegalArgumentException("Invalid request body: expected a JSON object"); + } + final ObjectNode mutableRoot = (ObjectNode) root; + + if (root.hasNonNull("max_completion_tokens") && !root.hasNonNull("max_tokens")) { + mutableRoot.put("max_tokens", root.get("max_completion_tokens").asInt()); + mutableRoot.remove("max_completion_tokens"); + } + + if (Objects.nonNull(fallbackConfig)) { + if (!root.hasNonNull("model") && Objects.nonNull(fallbackConfig.getModel()) && !fallbackConfig.getModel().isEmpty()) { + mutableRoot.put("model", fallbackConfig.getModel()); + } + if (!root.hasNonNull("temperature") && Objects.nonNull(fallbackConfig.getTemperature())) { + mutableRoot.put("temperature", fallbackConfig.getTemperature()); + } + if (!root.hasNonNull("max_tokens") && Objects.nonNull(fallbackConfig.getMaxTokens())) { + mutableRoot.put("max_tokens", fallbackConfig.getMaxTokens()); + } + } + + mutableRoot.put("stream", stream); + + return JsonUtils.jsonToObject(mutableRoot.toString(), ChatCompletionRequest.class); + } +} diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/FallbackStrategy.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/FallbackStrategy.java deleted file mode 100644 index 08c68c4f46ec..000000000000 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/FallbackStrategy.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.shenyu.plugin.ai.common.strategy; - -import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.model.ChatResponse; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -/** - * Fallback strategy. - */ -public interface FallbackStrategy { - - /** - * Execute the fallback strategy. - * - * @param fallbackClient the pre-configured and cached chat client for fallback - * @param requestBody the original request body as a string - * @param originalError the original error that triggered the fallback - * @return a Mono containing the fallback ChatResponse - */ - Mono fallback(ChatClient fallbackClient, String requestBody, Throwable originalError); - - /** - * Execute the fallback strategy for stream. - * - * @param fallbackClient the pre-configured and cached chat client for fallback - * @param requestBody the original request body as a string - * @param originalError the original error that triggered the fallback - * @return a Flux containing the fallback ChatResponse - */ - Flux fallbackStream(ChatClient fallbackClient, String requestBody, Throwable originalError); -} \ No newline at end of file diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/SimpleModelFallbackStrategy.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/SimpleModelFallbackStrategy.java deleted file mode 100644 index 94b3653e654d..000000000000 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/SimpleModelFallbackStrategy.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.shenyu.plugin.ai.common.strategy; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.model.ChatResponse; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.scheduler.Schedulers; - -/** - * A fallback strategy that simply executes a call with a pre-configured client. - */ -public final class SimpleModelFallbackStrategy implements FallbackStrategy { - - /** - * The constant INSTANCE. - */ - public static final FallbackStrategy INSTANCE = new SimpleModelFallbackStrategy(); - - private static final Logger LOG = LoggerFactory.getLogger(SimpleModelFallbackStrategy.class); - - private SimpleModelFallbackStrategy() { - } - - @Override - public Mono fallback(final ChatClient fallbackClient, final String requestBody, final Throwable originalError) { - LOG.warn("Executing simple model fallback strategy due to error: {}", originalError.getMessage()); - - return Mono.fromCallable(() -> fallbackClient.prompt() - .user(requestBody) - .call() - .chatResponse()) - .subscribeOn(Schedulers.boundedElastic()) - .doOnSuccess(response -> LOG.info("Fallback call successful.")) - .doOnError(fallbackError -> LOG.error("Fallback call also failed.", fallbackError)); - } - - @Override - public Flux fallbackStream(final ChatClient fallbackClient, final String requestBody, final Throwable originalError) { - LOG.warn("Executing simple model fallback stream strategy due to error: {}", originalError.getMessage()); - - return Flux.defer(() -> fallbackClient.prompt() - .user(requestBody) - .stream() - .chatResponse()) - .subscribeOn(Schedulers.boundedElastic()) - .doOnComplete(() -> LOG.info("Fallback stream completed successfully.")) - .doOnError(fallbackError -> LOG.error("Fallback stream also failed.", fallbackError)); - } -} \ No newline at end of file diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/test/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapterTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/test/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapterTest.java new file mode 100644 index 000000000000..566824318014 --- /dev/null +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/test/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapterTest.java @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.shenyu.plugin.ai.common.protocol; + +import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig; +import org.junit.jupiter.api.Test; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public final class OpenAiProtocolAdapterTest { + + private static final String BASE_BODY = "{\"messages\":[{\"role\":\"user\",\"content\":\"hello\"}]}"; + + @Test + void testNullRequestBody() { + assertThrows(IllegalArgumentException.class, + () -> OpenAiProtocolAdapter.toChatCompletionRequest(null, false)); + } + + @Test + void testEmptyRequestBody() { + assertThrows(IllegalArgumentException.class, + () -> OpenAiProtocolAdapter.toChatCompletionRequest("", false)); + } + + @Test + void testStreamFlagSet() { + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, true); + assertNotNull(req); + assertTrue(req.stream()); + } + + @Test + void testResolveStreamClientTrueOverridesFallbackFalse() { + final String body = "{\"model\":\"m\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"stream\":true}"; + assertTrue(OpenAiProtocolAdapter.resolveStream(body, false)); + } + + @Test + void testResolveStreamClientFalseOverridesFallbackTrue() { + final String body = "{\"model\":\"m\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"stream\":false}"; + assertFalse(OpenAiProtocolAdapter.resolveStream(body, true)); + } + + @Test + void testResolveStreamFallbackWhenClientMissing() { + assertTrue(OpenAiProtocolAdapter.resolveStream(BASE_BODY, true)); + assertFalse(OpenAiProtocolAdapter.resolveStream(BASE_BODY, false)); + } + + @Test + void testResolveStreamFallbackWhenClientMissingAndConfigNull() { + assertFalse(OpenAiProtocolAdapter.resolveStream(BASE_BODY, null)); + } + + @Test + void testResolveStreamFallbackWhenBodyEmpty() { + assertTrue(OpenAiProtocolAdapter.resolveStream("", true)); + assertFalse(OpenAiProtocolAdapter.resolveStream("", false)); + } + + @Test + void testStreamFlagUnset() { + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false); + assertNotNull(req); + assertFalse(req.stream()); + } + + @Test + void testMaxCompletionTokensConvertedToMaxTokens() { + final String body = "{\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"max_completion_tokens\":100}"; + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false); + assertNotNull(req); + assertEquals(100, req.maxTokens()); + } + + @Test + void testClientModelTakesPriority() { + final String body = "{\"model\":\"client-model\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}]}"; + final AiCommonConfig config = new AiCommonConfig(); + config.setModel("admin-model"); + + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false, config); + assertEquals("client-model", req.model()); + } + + @Test + void testFallbackModelWhenClientMissing() { + final AiCommonConfig config = new AiCommonConfig(); + config.setModel("admin-model"); + + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false, config); + assertEquals("admin-model", req.model()); + } + + @Test + void testNoFallbackModelWhenConfigIsNull() { + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false, null); + assertNotNull(req); + } + + @Test + void testNoFallbackWhenConfigModelIsNull() { + final AiCommonConfig config = new AiCommonConfig(); + config.setModel(null); + + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false, config); + assertNotNull(req); + } + + @Test + void testNoFallbackWhenConfigModelIsEmpty() { + final AiCommonConfig config = new AiCommonConfig(); + config.setModel(""); + + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false, config); + assertNotNull(req); + } + + @Test + void testClientTemperatureTakesPriority() { + final String body = "{\"model\":\"m\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"temperature\":0.5}"; + final AiCommonConfig config = new AiCommonConfig(); + config.setTemperature(0.9); + + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false, config); + assertEquals(0.5, req.temperature(), 0.001); + } + + @Test + void testFallbackTemperatureWhenClientMissing() { + final AiCommonConfig config = new AiCommonConfig(); + config.setTemperature(0.7); + + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false, config); + assertEquals(0.7, req.temperature(), 0.001); + } + + @Test + void testNoFallbackTemperatureWhenConfigNull() { + final AiCommonConfig config = new AiCommonConfig(); + config.setTemperature(null); + + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false, config); + assertNotNull(req); + } + + @Test + void testClientMaxTokensTakesPriority() { + final String body = "{\"model\":\"m\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"max_tokens\":200}"; + final AiCommonConfig config = new AiCommonConfig(); + config.setMaxTokens(500); + + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false, config); + assertEquals(200, req.maxTokens()); + } + + @Test + void testFallbackMaxTokensWhenClientMissing() { + final AiCommonConfig config = new AiCommonConfig(); + config.setMaxTokens(500); + + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false, config); + assertEquals(500, req.maxTokens()); + } + + @Test + void testNoFallbackMaxTokensWhenConfigNull() { + final AiCommonConfig config = new AiCommonConfig(); + config.setMaxTokens(null); + + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false, config); + assertNotNull(req); + } + + @Test + void testAllFallbackFieldsAppliedWhenClientMissingAll() { + final AiCommonConfig config = new AiCommonConfig(); + config.setModel("admin-model"); + config.setTemperature(0.3); + config.setMaxTokens(1024); + + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, true, config); + assertEquals("admin-model", req.model()); + assertEquals(0.3, req.temperature(), 0.001); + assertEquals(1024, req.maxTokens()); + assertTrue(req.stream()); + } + + @Test + void testAllClientFieldsTakePriority() { + final String body = "{\"model\":\"client-model\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}]," + + "\"temperature\":0.1,\"max_tokens\":50}"; + final AiCommonConfig config = new AiCommonConfig(); + config.setModel("admin-model"); + config.setTemperature(0.9); + config.setMaxTokens(9999); + + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false, config); + assertEquals("client-model", req.model()); + assertEquals(0.1, req.temperature(), 0.001); + assertEquals(50, req.maxTokens()); + } + + @Test + void testPartialClientFieldsWithPartialFallback() { + final String body = "{\"model\":\"client-model\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"temperature\":0.1}"; + final AiCommonConfig config = new AiCommonConfig(); + config.setModel("admin-model"); + config.setTemperature(0.9); + config.setMaxTokens(2048); + + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false, config); + assertEquals("client-model", req.model()); + assertEquals(0.1, req.temperature(), 0.001); + assertEquals(2048, req.maxTokens()); + } +} diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java index 687c51233f95..d9e72c2b67f5 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java @@ -21,25 +21,24 @@ import org.apache.shenyu.common.dto.RuleData; import org.apache.shenyu.common.dto.SelectorData; import org.apache.shenyu.common.dto.convert.rule.AiProxyHandle; -import org.apache.shenyu.common.enums.AiModelProviderEnum; import org.apache.shenyu.common.enums.PluginEnum; import org.apache.shenyu.common.utils.JsonUtils; import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig; -import org.apache.shenyu.plugin.ai.common.spring.ai.registry.AiModelFactoryRegistry; +import org.apache.shenyu.plugin.ai.common.protocol.OpenAiProtocolAdapter; import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.AiProxyApiKeyCache; -import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.ChatClientCache; import org.apache.shenyu.plugin.ai.proxy.enhanced.handler.AiProxyPluginHandler; import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyConfigService; import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyExecutorService; +import org.apache.shenyu.plugin.ai.proxy.enhanced.service.UpstreamErrorLogger; import org.apache.shenyu.plugin.api.ShenyuPluginChain; import org.apache.shenyu.plugin.api.utils.WebFluxResultUtils; import org.apache.shenyu.plugin.base.AbstractShenyuPlugin; import org.apache.shenyu.plugin.base.utils.CacheKeyUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferUtils; import org.springframework.http.HttpHeaders; @@ -65,26 +64,18 @@ public class AiProxyPlugin extends AbstractShenyuPlugin { */ private static final long MAX_REQUEST_BODY_SIZE_BYTES = 5 * 1024 * 1024L; - private final AiModelFactoryRegistry aiModelFactoryRegistry; - private final AiProxyConfigService aiProxyConfigService; private final AiProxyExecutorService aiProxyExecutorService; - private final ChatClientCache chatClientCache; - private final AiProxyPluginHandler aiProxyPluginHandler; public AiProxyPlugin( - final AiModelFactoryRegistry aiModelFactoryRegistry, final AiProxyConfigService aiProxyConfigService, final AiProxyExecutorService aiProxyExecutorService, - final ChatClientCache chatClientCache, final AiProxyPluginHandler aiProxyPluginHandler) { - this.aiModelFactoryRegistry = aiModelFactoryRegistry; this.aiProxyConfigService = aiProxyConfigService; this.aiProxyExecutorService = aiProxyExecutorService; - this.chatClientCache = chatClientCache; this.aiProxyPluginHandler = aiProxyPluginHandler; } @@ -148,7 +139,7 @@ protected Mono doExecute( } } - if (Boolean.TRUE.equals(primaryConfig.getStream())) { + if (OpenAiProtocolAdapter.resolveStream(requestBody, primaryConfig.getStream())) { return handleStreamRequest(exchange, selector, requestBody, primaryConfig, selectorHandle); } return handleNonStreamRequest(exchange, selector, requestBody, primaryConfig, selectorHandle); @@ -161,23 +152,26 @@ private Mono handleStreamRequest( final String requestBody, final AiCommonConfig primaryConfig, final AiProxyHandle selectorHandle) { - final ChatClient mainClient = createMainChatClient(selector.getId(), primaryConfig); - final String prompt = aiProxyConfigService.extractPrompt(requestBody); - final Optional fallbackClient = resolveFallbackClient(primaryConfig, selectorHandle, - selector.getId(), requestBody); + final OpenAiApi mainApi = createOpenAiApi(primaryConfig); + final ChatCompletionRequest request = OpenAiProtocolAdapter.toChatCompletionRequest(requestBody, true, primaryConfig); + final Optional fallbackApi = resolveFallbackOpenAiApi(primaryConfig, selectorHandle, + requestBody); final ServerHttpResponse response = exchange.getResponse(); response.getHeaders().setContentType(MediaType.TEXT_EVENT_STREAM); - final Flux chatResponseFlux = aiProxyExecutorService.executeStream(mainClient, fallbackClient, - prompt); + final Flux chunkFlux = aiProxyExecutorService.executeDirectStream( + mainApi, fallbackApi, request); - final Flux sseFlux = chatResponseFlux.map( - chatResponse -> { - final String json = JsonUtils.toJson(chatResponse); + final Flux sseFlux = chunkFlux.map( + chunk -> { + final String json = JsonUtils.toJson(chunk); final String sseData = "data: " + json + "\n\n"; return response.bufferFactory() .wrap(sseData.getBytes(StandardCharsets.UTF_8)); - }); + }).concatWith(Mono.fromSupplier(() -> + response.bufferFactory() + .wrap("data: [DONE]\n\n".getBytes(StandardCharsets.UTF_8)))) + .doOnError(e -> logUpstreamError(e, "stream")); return response.writeWith(sseFlux); } @@ -188,24 +182,25 @@ private Mono handleNonStreamRequest( final String requestBody, final AiCommonConfig primaryConfig, final AiProxyHandle selectorHandle) { - final ChatClient mainClient = createMainChatClient(selector.getId(), primaryConfig); - final String prompt = aiProxyConfigService.extractPrompt(requestBody); - final Optional fallbackClient = resolveFallbackClient(primaryConfig, selectorHandle, - selector.getId(), requestBody); + final OpenAiApi mainApi = createOpenAiApi(primaryConfig); + final ChatCompletionRequest request = OpenAiProtocolAdapter.toChatCompletionRequest(requestBody, false, primaryConfig); + final Optional fallbackApi = resolveFallbackOpenAiApi(primaryConfig, selectorHandle, + requestBody); return aiProxyExecutorService - .execute(mainClient, fallbackClient, prompt) + .executeDirectCall(mainApi, fallbackApi, request) .flatMap( - response -> { - byte[] jsonBytes = JsonUtils.toJson(response).getBytes(StandardCharsets.UTF_8); + responseEntity -> { + final String responseJson = JsonUtils.toJson(responseEntity.getBody()); + byte[] jsonBytes = responseJson.getBytes(StandardCharsets.UTF_8); return WebFluxResultUtils.result(exchange, jsonBytes); - }); + }) + .doOnError(e -> logUpstreamError(e, "non-stream")); } - private Optional resolveFallbackClient( + private Optional resolveFallbackOpenAiApi( final AiCommonConfig primaryConfig, final AiProxyHandle selectorHandle, - final String selectorId, final String requestBody) { return aiProxyConfigService .resolveDynamicFallbackConfig(primaryConfig, requestBody) @@ -214,86 +209,30 @@ private Optional resolveFallbackClient( if (LOG.isDebugEnabled()) { LOG.debug("[AiProxy] dynamic fallback config: {}", cfg); } - return createDynamicFallbackClient(cfg); + return createOpenAiApi(cfg); }) - .or( - () -> aiProxyConfigService - .resolveAdminFallbackConfig(primaryConfig, selectorHandle) - .map(adminFallbackConfig -> { - LOG.info("[AiProxy] use admin fallback"); - if (LOG.isDebugEnabled()) { - LOG.debug("[AiProxy] admin fallback config: {}", adminFallbackConfig); - } - return createAdminFallbackClient(selectorId, adminFallbackConfig); - })); - } - - /** - * Generate cache key based on config fields excluding apiKey. - * This ensures cache consistency even when apiKey is updated at runtime. - * - * @param config the config - * @return cache key hash - */ - private int generateConfigCacheKey(final AiCommonConfig config) { - return Objects.hash( - config.getProvider(), - config.getBaseUrl(), - config.getModel(), - config.getTemperature(), - config.getMaxTokens(), - config.getStream() - // Explicitly exclude apiKey to avoid cache misses when apiKey changes - ); - } - - private ChatClient createMainChatClient(final String selectorId, final AiCommonConfig config) { - final int configHash = generateConfigCacheKey(config); - final String cacheKey = selectorId + "|main_" + configHash; - return chatClientCache.computeIfAbsent( - cacheKey, - () -> { - LOG.info("Creating and caching main model for selector: {}, key: {}", selectorId, cacheKey); - return createChatModel(config); - }); - } - - private ChatClient createAdminFallbackClient( - final String selectorId, final AiCommonConfig fallbackConfig) { - final int configHash = generateConfigCacheKey(fallbackConfig); - final String fallbackCacheKey = selectorId + "|adminFallback_" + configHash; - return chatClientCache.computeIfAbsent( - fallbackCacheKey, - () -> { - LOG.info( - "Creating and caching admin fallback model for selector: {}, key: {}", - selectorId, fallbackCacheKey); - return createChatModel(fallbackConfig); - }); - } - - private ChatClient createDynamicFallbackClient(final AiCommonConfig fallbackConfig) { - LOG.info("Creating non-cached dynamic fallback model."); - return ChatClient.builder(createChatModel(fallbackConfig)).build(); + .or(() -> aiProxyConfigService + .resolveAdminFallbackConfig(primaryConfig, selectorHandle) + .map(adminFallbackConfig -> { + LOG.info("[AiProxy] use admin fallback"); + if (LOG.isDebugEnabled()) { + LOG.debug("[AiProxy] admin fallback config: {}", adminFallbackConfig); + } + return createOpenAiApi(adminFallbackConfig); + })); } - private ChatModel createChatModel(final AiCommonConfig config) { - if (LOG.isDebugEnabled()) { - LOG.debug("Creating chat model with config: {}", config); - } - final AiModelProviderEnum provider = AiModelProviderEnum.getByName(config.getProvider()); - if (Objects.isNull(provider)) { - throw new IllegalArgumentException( - "Invalid AI model provider in config: " + config.getProvider()); + private OpenAiApi createOpenAiApi(final AiCommonConfig config) { + if (Objects.isNull(config.getBaseUrl()) || config.getBaseUrl().isEmpty()) { + throw new IllegalArgumentException("baseUrl must not be empty"); } - final var factory = aiModelFactoryRegistry.getFactory(provider); - if (Objects.isNull(factory)) { - throw new IllegalArgumentException( - "AI model factory not found for provider: " + provider.getName()); + if (Objects.isNull(config.getApiKey()) || config.getApiKey().isEmpty()) { + throw new IllegalArgumentException("apiKey must not be empty"); } - return Objects.requireNonNull( - factory.createAiModel(config), - "The AI model created by the factory must not be null"); + return OpenAiApi.builder() + .baseUrl(config.getBaseUrl()) + .apiKey(config.getApiKey()) + .build(); } @Override @@ -305,4 +244,8 @@ public int getOrder() { public String named() { return PluginEnum.AI_PROXY.getName(); } + + private void logUpstreamError(final Throwable e, final String mode) { + UpstreamErrorLogger.logUpstreamError(LOG, e, mode); + } } diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/ChatClientCache.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/ChatClientCache.java deleted file mode 100644 index 473ace98fa91..000000000000 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/ChatClientCache.java +++ /dev/null @@ -1,143 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.shenyu.plugin.ai.proxy.enhanced.cache; - -import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.model.ChatModel; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.Map; -import java.util.Objects; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Supplier; - -/** - * This is ChatClient cache. - */ -public final class ChatClientCache { - - private static final Logger LOG = LoggerFactory.getLogger(ChatClientCache.class); - - private static final int MAX_CACHE_SIZE = getCacheSize(); - - private final Map chatClientMap = new ConcurrentHashMap<>(); - - private final AtomicBoolean evictionInProgress = new AtomicBoolean(false); - - /** - * Instantiates a new Chat client cache. - */ - public ChatClientCache() { - } - - private static int getCacheSize() { - String value = System.getProperty("shenyu.plugin.ai.proxy.enhanced.cache.maxSize", - System.getenv("SHENYU_PLUGIN_AI_PROXY_ENHANCED_CACHE_MAXSIZE")); - if (Objects.nonNull(value)) { - try { - return Integer.parseInt(value); - } catch (NumberFormatException e) { - LoggerFactory.getLogger(ChatClientCache.class) - .warn("[ChatClientCache] Invalid cache size '{}', using default 500.", value); - } - } - return 500; - } - - /** - * Gets client or compute if absent. - * - * @param key the key - * @param chatModelSupplier the chat model supplier - * @return the chat client - */ - public ChatClient computeIfAbsent(final String key, final Supplier chatModelSupplier) { - // Check size before computing, but use synchronized block to prevent race conditions - final int currentSize = chatClientMap.size(); - if (currentSize > MAX_CACHE_SIZE) { - // Use atomic flag to ensure only one thread performs eviction - if (evictionInProgress.compareAndSet(false, true)) { - try { - synchronized (chatClientMap) { - // Double-check after acquiring lock - if (chatClientMap.size() > MAX_CACHE_SIZE) { - evictOldestEntries(); - } - } - } finally { - evictionInProgress.set(false); - } - } - } - return chatClientMap.computeIfAbsent(key, k -> ChatClient.builder(chatModelSupplier.get()).build()); - } - - /** - * Evict oldest entries when cache size exceeds limit. - * Removes approximately 25% of entries to avoid thundering herd problem. - */ - private void evictOldestEntries() { - final int currentSize = chatClientMap.size(); - if (currentSize <= MAX_CACHE_SIZE) { - return; - } - - // Evict 25% of entries, but at least 10 entries - final int evictCount = Math.max(10, currentSize / 4); - LOG.warn("[ChatClientCache] Cache size {} exceeded limit {}, evicting {} oldest entries", - currentSize, MAX_CACHE_SIZE, evictCount); - - // Since ConcurrentHashMap doesn't maintain insertion order, - // we evict entries based on iteration order (which is somewhat arbitrary but better than clearing all) - int removed = 0; - for (final String key : chatClientMap.keySet()) { - if (removed >= evictCount) { - break; - } - chatClientMap.remove(key); - removed++; - } - - LOG.info("[ChatClientCache] Evicted {} entries, cache size now: {}", removed, chatClientMap.size()); - } - - /** - * Removes all cached clients associated with a selector ID (by prefix matching - * "selectorId|"). - * - * @param selectorId the selector id - */ - public void remove(final String selectorId) { - if (java.util.Objects.isNull(selectorId)) { - return; - } - final String prefix = selectorId + "|"; - chatClientMap.keySet().removeIf(k -> k.equals(selectorId) || k.startsWith(prefix)); - LOG.info("[ChatClientCache] invalidate selectorId={} (by prefix)", selectorId); - } - - /** - * Clear all cached clients. - */ - public void clearAll() { - chatClientMap.clear(); - LOG.info("[ChatClientCache] cleared all cached clients"); - } -} \ No newline at end of file diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java index e4e3f470ed7c..8b8868a04560 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java @@ -23,7 +23,6 @@ import org.apache.shenyu.common.dto.convert.rule.AiProxyHandle; import org.apache.shenyu.common.enums.PluginEnum; import org.apache.shenyu.common.utils.GsonUtils; -import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.ChatClientCache; import org.apache.shenyu.plugin.base.cache.CommonHandleCache; import org.apache.shenyu.plugin.base.handler.PluginDataHandler; import org.apache.shenyu.plugin.base.utils.CacheKeyUtils; @@ -37,12 +36,6 @@ public class AiProxyPluginHandler implements PluginDataHandler { private final CommonHandleCache selectorCachedHandle = new CommonHandleCache<>(); - private final ChatClientCache chatClientCache; - - public AiProxyPluginHandler(final ChatClientCache chatClientCache) { - this.chatClientCache = chatClientCache; - } - @Override public void handlerPlugin(final PluginData pluginData) { // Note: The logic for handling global plugin configuration with Singleton has been removed @@ -52,10 +45,6 @@ public void handlerPlugin(final PluginData pluginData) { @Override public void handlerSelector(final SelectorData selectorData) { - // Invalidate the cache first when the selector is updated. - chatClientCache.remove(selectorData.getId()); - // Do NOT remove AiProxyApiKeyCache here. Admin will push updated AI_PROXY_API_KEY events - // with refreshed realApiKey after selector changes. Removing here introduces a window of misses. if (Objects.isNull(selectorData.getHandle())) { return; } @@ -67,8 +56,6 @@ public void handlerSelector(final SelectorData selectorData) { @Override public void removeSelector(final SelectorData selectorData) { - // Invalidate the cache when the selector is removed. - chatClientCache.remove(selectorData.getId()); selectorCachedHandle .removeHandle(CacheKeyUtils.INST.getKey(selectorData.getId(), Constants.DEFAULT_RULE)); } diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java index 71ab0c4c49d1..8502a772f43e 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java @@ -17,12 +17,14 @@ package org.apache.shenyu.plugin.ai.proxy.enhanced.service; -import org.apache.shenyu.plugin.ai.common.strategy.SimpleModelFallbackStrategy; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; import org.springframework.ai.retry.NonTransientAiException; +import org.springframework.http.ResponseEntity; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; @@ -39,78 +41,77 @@ public class AiProxyExecutorService { private static final Logger LOG = LoggerFactory.getLogger(AiProxyExecutorService.class); /** - * Execute the AI call with retry and fallback. + * Execute a streaming AI call directly via {@link OpenAiApi}, bypassing Spring AI's + * {@code createRequest()} which loses fields like {@code reasoning_content}. * - * @param mainClient the main chat client - * @param fallbackClientOpt the optional fallback chat client - * @param requestBody the request body - * @return a Mono containing the ChatResponse + * @param mainApi the main OpenAiApi + * @param fallbackApiOpt the optional fallback OpenAiApi + * @param request the ChatCompletionRequest with all fields preserved + * @return a Flux of ChatCompletionChunk */ - public Mono execute(final ChatClient mainClient, final Optional fallbackClientOpt, final String requestBody) { - final Mono mainCall = doChatCall(mainClient, requestBody); - - return mainCall - .retryWhen(Retry.backoff(3, Duration.ofSeconds(1)) - .filter(throwable -> !(throwable instanceof NonTransientAiException)) + public Flux executeDirectStream(final OpenAiApi mainApi, + final Optional fallbackApiOpt, final ChatCompletionRequest request) { + return mainApi.chatCompletionStream(request) + .doOnError(e -> UpstreamErrorLogger.logUpstreamError(LOG, e, "direct stream")) + .retryWhen(Retry.max(1) .onRetryExhaustedThrow((retryBackoffSpec, retrySignal) -> { - LOG.warn("Retries exhausted for AI call after {} attempts.", - retrySignal.totalRetries(), retrySignal.failure()); - return new NonTransientAiException("Retries exhausted. Triggering fallback.", + LOG.warn("Direct stream retry exhausted. Triggering fallback.", + retrySignal.failure()); + return new NonTransientAiException( + "Direct stream failed after 1 retry. Triggering fallback.", retrySignal.failure()); })) .onErrorResume(NonTransientAiException.class, - throwable -> handleFallback(throwable, fallbackClientOpt, requestBody)); - } - - protected Mono doChatCall(final ChatClient client, final String requestBody) { - return Mono.fromCallable(() -> client.prompt().user(requestBody).call().chatResponse()) - .subscribeOn(Schedulers.boundedElastic()); - } - - private Mono handleFallback(final Throwable throwable, final Optional fallbackClientOpt, final String requestBody) { - LOG.warn("AI main call failed or retries exhausted, attempting to fallback...", throwable); - - if (fallbackClientOpt.isEmpty()) { - return Mono.error(throwable); - } - - return SimpleModelFallbackStrategy.INSTANCE.fallback(fallbackClientOpt.get(), requestBody, throwable); + throwable -> handleDirectFallbackStream(throwable, fallbackApiOpt, request)); } /** - * Execute the AI call with retry and fallback. + * Execute a non-streaming AI call directly via {@link OpenAiApi}. * - * @param mainClient the main chat client - * @param fallbackClientOpt the optional fallback chat client - * @param requestBody the request body - * @return a Flux containing the ChatResponse + * @param mainApi the main OpenAiApi + * @param fallbackApiOpt the optional fallback OpenAiApi + * @param request the ChatCompletionRequest with all fields preserved + * @return a Mono of ResponseEntity containing ChatCompletion */ - public Flux executeStream(final ChatClient mainClient, final Optional fallbackClientOpt, final String requestBody) { - final Flux mainStream = doChatStream(mainClient, requestBody); - - return mainStream - .retryWhen(Retry.max(1) + public Mono> executeDirectCall(final OpenAiApi mainApi, + final Optional fallbackApiOpt, final ChatCompletionRequest request) { + return Mono.fromCallable(() -> mainApi.chatCompletionEntity(request)) + .subscribeOn(Schedulers.boundedElastic()) + .doOnError(e -> UpstreamErrorLogger.logUpstreamError(LOG, e, "direct call")) + .retryWhen(Retry.backoff(3, Duration.ofSeconds(1)) + .filter(throwable -> !(throwable instanceof NonTransientAiException)) .onRetryExhaustedThrow((retryBackoffSpec, retrySignal) -> { - LOG.warn("Retrying stream once failed. Attempts: {}. Triggering fallback.", + LOG.warn("Direct call retries exhausted after {} attempts. Triggering fallback.", retrySignal.totalRetries(), retrySignal.failure()); - return new NonTransientAiException("Stream failed after 1 retry. Triggering fallback.", retrySignal.failure()); + return new NonTransientAiException("Direct call retries exhausted. Triggering fallback.", + retrySignal.failure()); })) .onErrorResume(NonTransientAiException.class, - throwable -> handleFallbackStream(throwable, fallbackClientOpt, requestBody)); + throwable -> handleDirectFallbackCall(throwable, fallbackApiOpt, request)); } - protected Flux doChatStream(final ChatClient client, final String requestBody) { - return Flux.defer(() -> client.prompt().user(requestBody).stream().chatResponse()) - .subscribeOn(Schedulers.boundedElastic()); + private Flux handleDirectFallbackStream(final Throwable throwable, + final Optional fallbackApiOpt, final ChatCompletionRequest request) { + LOG.warn("Main direct stream failed, attempting fallback...", throwable); + + if (fallbackApiOpt.isEmpty()) { + return Flux.error(throwable); + } + + LOG.info("Using fallback OpenAiApi for direct stream"); + return fallbackApiOpt.get().chatCompletionStream(request); } - private Flux handleFallbackStream(final Throwable throwable, final Optional fallbackClientOpt, final String requestBody) { - LOG.warn("AI main stream failed or retries exhausted, attempting to fallback...", throwable); + private Mono> handleDirectFallbackCall(final Throwable throwable, + final Optional fallbackApiOpt, final ChatCompletionRequest request) { + LOG.warn("Main direct call failed, attempting fallback...", throwable); - if (fallbackClientOpt.isEmpty()) { - return Flux.error(throwable); + if (fallbackApiOpt.isEmpty()) { + return Mono.error(throwable); } - return SimpleModelFallbackStrategy.INSTANCE.fallbackStream(fallbackClientOpt.get(), requestBody, throwable); + LOG.info("Using fallback OpenAiApi for direct call"); + return Mono.fromCallable(() -> fallbackApiOpt.get().chatCompletionEntity(request)) + .subscribeOn(Schedulers.boundedElastic()); } -} \ No newline at end of file +} diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLogger.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLogger.java new file mode 100644 index 000000000000..bee49aab7d0c --- /dev/null +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLogger.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.shenyu.plugin.ai.proxy.enhanced.service; + +import org.slf4j.Logger; +import org.springframework.web.reactive.function.client.WebClientResponseException; + +import java.util.Objects; + +/** + * Shared utility for logging upstream AI service errors with WebClientResponseException details. + */ +public final class UpstreamErrorLogger { + + private UpstreamErrorLogger() { + } + + public static void logUpstreamError(final Logger log, final Throwable e, final String mode) { + final WebClientResponseException webClientEx = findWebClientResponseException(e); + if (Objects.nonNull(webClientEx)) { + log.error("[AiProxy] {} failed, status={}, upstreamBody={}", + mode, webClientEx.getStatusCode(), webClientEx.getResponseBodyAsString(), e); + } else { + log.error("[AiProxy] {} failed", mode, e); + } + } + + private static WebClientResponseException findWebClientResponseException(final Throwable e) { + Throwable current = e; + while (Objects.nonNull(current)) { + if (current instanceof WebClientResponseException ex) { + return ex; + } + current = current.getCause(); + } + return null; + } +} diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriber.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriber.java index 4f63f5b53bb8..d25434881658 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriber.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriber.java @@ -19,7 +19,6 @@ import org.apache.shenyu.common.dto.ProxyApiKeyData; import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.AiProxyApiKeyCache; -import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.ChatClientCache; import org.apache.shenyu.sync.data.api.AiProxyApiKeyDataSubscriber; import java.util.Objects; @@ -29,12 +28,6 @@ */ public final class CommonAiProxyApiKeyDataSubscriber implements AiProxyApiKeyDataSubscriber { - private final ChatClientCache chatClientCache; - - public CommonAiProxyApiKeyDataSubscriber(final ChatClientCache chatClientCache) { - this.chatClientCache = chatClientCache; - } - @Override public void onSubscribe(final ProxyApiKeyData data) { if (Objects.isNull(data) || Objects.isNull(data.getProxyApiKey())) { @@ -54,6 +47,5 @@ public void unSubscribe(final ProxyApiKeyData data) { @Override public void refresh() { AiProxyApiKeyCache.getInstance().refresh(); - chatClientCache.clearAll(); } } diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java index 30e84abc0809..dfbf95a1b0c9 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java @@ -23,10 +23,7 @@ import org.apache.shenyu.common.dto.convert.rule.AiProxyHandle; import org.apache.shenyu.common.enums.AiModelProviderEnum; import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig; -import org.apache.shenyu.plugin.ai.common.spring.ai.AiModelFactory; -import org.apache.shenyu.plugin.ai.common.spring.ai.registry.AiModelFactoryRegistry; import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.AiProxyApiKeyCache; -import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.ChatClientCache; import org.apache.shenyu.plugin.ai.proxy.enhanced.handler.AiProxyPluginHandler; import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyConfigService; import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyExecutorService; @@ -43,21 +40,22 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoSettings; import org.mockito.quality.Strictness; -import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; import org.springframework.context.ApplicationContext; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.web.server.ServerWebExchange; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; import java.util.Optional; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -75,27 +73,12 @@ public class AiProxyPluginTest { private static final String REQUEST_BODY = "{\"messages\":[{\"role\":\"user\",\"content\":\"Hello\"}]}"; - @Mock - private AiModelFactoryRegistry registry; - @Mock private AiProxyConfigService configService; @Mock private AiProxyExecutorService executorService; - @Mock - private ChatClientCache chatClientCache; - - @Mock - private AiModelFactory modelFactory; - - @Mock - private ChatModel chatModel; - - @Mock - private ChatClient chatClient; - private AiProxyPluginHandler aiProxyPluginHandler; private AiProxyPlugin plugin; @@ -110,8 +93,8 @@ public class AiProxyPluginTest { @BeforeEach public void setUp() { - aiProxyPluginHandler = new AiProxyPluginHandler(chatClientCache); - plugin = new AiProxyPlugin(registry, configService, executorService, chatClientCache, aiProxyPluginHandler); + aiProxyPluginHandler = new AiProxyPluginHandler(); + plugin = new AiProxyPlugin(configService, executorService, aiProxyPluginHandler); selector = new SelectorData(); selector.setId(SELECTOR_ID); @@ -129,14 +112,6 @@ public void setUp() { when(applicationContext.getBean(ShenyuResult.class)).thenReturn(resultMock); SpringBeanUtils.getInstance().setApplicationContext(applicationContext); - // Common mock behavior for model creation and cache - when(registry.getFactory(any(AiModelProviderEnum.class))).thenReturn(modelFactory); - when(modelFactory.createAiModel(any(AiCommonConfig.class))).thenReturn(chatModel); - when(chatClientCache.computeIfAbsent(anyString(), any())).thenAnswer(invocation -> { - // Always return the mock ChatClient to avoid any UnsupportedOperationException - return chatClient; - }); - // mock static apiKeyCacheMockedStatic = mockStatic(AiProxyApiKeyCache.class); } @@ -150,13 +125,14 @@ public void tearDown() { private void setupSuccessMocks(final AiProxyHandle handle, final AiCommonConfig primaryConfig, final Optional fallbackConfig) { aiProxyPluginHandler.getSelectorCachedHandle().cachedHandle(CacheKeyUtils.INST.getKey(SELECTOR_ID, Constants.DEFAULT_RULE), handle); - final ChatResponse chatResponse = mock(ChatResponse.class); + final ChatCompletion chatCompletion = mock(ChatCompletion.class); + final ResponseEntity responseEntity = ResponseEntity.ok(chatCompletion); when(configService.resolvePrimaryConfig(handle)).thenReturn(primaryConfig); when(configService.resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY)).thenReturn(fallbackConfig); when(configService.resolveAdminFallbackConfig(primaryConfig, handle)).thenReturn(fallbackConfig); - when(configService.extractPrompt(anyString())).thenAnswer(invocation -> invocation.getArgument(0)); - when(executorService.execute(any(), any(), any())).thenReturn(Mono.just(chatResponse)); + when(executorService.executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class))).thenReturn(Mono.just(responseEntity)); + when(executorService.executeDirectStream(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class))).thenReturn(Flux.empty()); } @Test @@ -164,6 +140,8 @@ public void testExecuteSuccess() { final AiProxyHandle handle = new AiProxyHandle(); final AiCommonConfig primaryConfig = new AiCommonConfig(); primaryConfig.setProvider(AiModelProviderEnum.OPEN_AI.getName()); + primaryConfig.setBaseUrl("https://api.openai.com"); + primaryConfig.setApiKey("test-key"); setupSuccessMocks(handle, primaryConfig, Optional.empty()); StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule)) @@ -173,7 +151,7 @@ public void testExecuteSuccess() { verify(configService).resolvePrimaryConfig(handle); verify(configService).resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY); verify(configService).resolveAdminFallbackConfig(primaryConfig, handle); - verify(executorService).execute(any(), any(), any()); + verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class)); } @Test @@ -181,8 +159,12 @@ public void testExecuteWithDynamicFallback() { final AiProxyHandle handle = new AiProxyHandle(); final AiCommonConfig primaryConfig = new AiCommonConfig(); primaryConfig.setProvider(AiModelProviderEnum.OPEN_AI.getName()); + primaryConfig.setBaseUrl("https://api.openai.com"); + primaryConfig.setApiKey("test-key"); final AiCommonConfig fallbackConfig = new AiCommonConfig(); fallbackConfig.setProvider(AiModelProviderEnum.DEEP_SEEK.getName()); + fallbackConfig.setBaseUrl("https://api.deepseek.com"); + fallbackConfig.setApiKey("fallback-key"); setupSuccessMocks(handle, primaryConfig, Optional.of(fallbackConfig)); when(configService.resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY)).thenReturn(Optional.of(fallbackConfig)); @@ -190,7 +172,7 @@ public void testExecuteWithDynamicFallback() { StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule)) .verifyComplete(); - verify(executorService).execute(any(ChatClient.class), any(Optional.class), any()); + verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class)); } @Test @@ -199,6 +181,7 @@ public void testExecuteWithValidProxyApiKey() { handle.setProxyEnabled("true"); final AiCommonConfig primaryConfig = new AiCommonConfig(); primaryConfig.setProvider(AiModelProviderEnum.OPEN_AI.getName()); + primaryConfig.setBaseUrl("https://api.openai.com"); primaryConfig.setApiKey("original-key"); // setup request with proxy key @@ -225,6 +208,8 @@ public void testExecuteWithInvalidProxyApiKey() { handle.setProxyEnabled("true"); final AiCommonConfig primaryConfig = new AiCommonConfig(); primaryConfig.setProvider(AiModelProviderEnum.OPEN_AI.getName()); + primaryConfig.setBaseUrl("https://api.openai.com"); + primaryConfig.setApiKey("test-key"); // setup request with proxy key exchange = MockServerWebExchange.from(MockServerHttpRequest.post("/test").header(Constants.X_API_KEY, "proxy-key-invalid").body(REQUEST_BODY)); @@ -252,8 +237,12 @@ public void testExecuteWithAdminFallback() { final AiProxyHandle handle = new AiProxyHandle(); final AiCommonConfig primaryConfig = new AiCommonConfig(); primaryConfig.setProvider(AiModelProviderEnum.OPEN_AI.getName()); + primaryConfig.setBaseUrl("https://api.openai.com"); + primaryConfig.setApiKey("test-key"); final AiCommonConfig fallbackConfig = new AiCommonConfig(); fallbackConfig.setProvider(AiModelProviderEnum.DEEP_SEEK.getName()); + fallbackConfig.setBaseUrl("https://api.deepseek.com"); + fallbackConfig.setApiKey("fallback-key"); setupSuccessMocks(handle, primaryConfig, Optional.empty()); when(configService.resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY)).thenReturn(Optional.empty()); @@ -262,7 +251,7 @@ public void testExecuteWithAdminFallback() { StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule)) .verifyComplete(); - verify(executorService).execute(any(ChatClient.class), any(Optional.class), any()); + verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class)); } @Test @@ -270,53 +259,49 @@ public void testCacheIsUsedForAdminFallbackClient() { final AiProxyHandle handle = new AiProxyHandle(); final AiCommonConfig primaryConfig = new AiCommonConfig(); primaryConfig.setProvider(AiModelProviderEnum.OPEN_AI.getName()); + primaryConfig.setBaseUrl("https://api.openai.com"); + primaryConfig.setApiKey("test-key"); final AiCommonConfig fallbackConfig = new AiCommonConfig(); fallbackConfig.setProvider(AiModelProviderEnum.DEEP_SEEK.getName()); + fallbackConfig.setBaseUrl("https://api.deepseek.com"); + fallbackConfig.setApiKey("fallback-key"); // Cache the handle for the test aiProxyPluginHandler.getSelectorCachedHandle().cachedHandle(CacheKeyUtils.INST.getKey(SELECTOR_ID, Constants.DEFAULT_RULE), handle); - final ChatResponse chatResponse = mock(ChatResponse.class); + final ChatCompletion chatCompletion = mock(ChatCompletion.class); + final ResponseEntity responseEntity = ResponseEntity.ok(chatCompletion); // Setup all necessary mocks when(configService.resolvePrimaryConfig(handle)).thenReturn(primaryConfig); when(configService.resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY)).thenReturn(Optional.empty()); when(configService.resolveAdminFallbackConfig(primaryConfig, handle)).thenReturn(Optional.of(fallbackConfig)); - when(configService.extractPrompt(anyString())).thenAnswer(invocation -> invocation.getArgument(0)); - when(executorService.execute(any(), any(), any())).thenReturn(Mono.just(chatResponse)); + when(executorService.executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class))).thenReturn(Mono.just(responseEntity)); // Execute the test - focus on successful execution rather than cache verification StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule)) .verifyComplete(); - + // Verify that the configuration methods were called correctly verify(configService).resolvePrimaryConfig(handle); verify(configService).resolveAdminFallbackConfig(primaryConfig, handle); - verify(executorService).execute(any(), any(), any()); + verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class)); } @Test - public void testCreateChatModelThrowsException() { + public void testCreateOpenAiApiThrowsException() { final AiProxyHandle handle = new AiProxyHandle(); final AiCommonConfig primaryConfig = new AiCommonConfig(); - primaryConfig.setProvider("InvalidProvider"); - - // Cache the handle for the test + primaryConfig.setBaseUrl(null); + aiProxyPluginHandler.getSelectorCachedHandle().cachedHandle(CacheKeyUtils.INST.getKey(SELECTOR_ID, Constants.DEFAULT_RULE), handle); - - // Mock config service to return the invalid config + when(configService.resolvePrimaryConfig(handle)).thenReturn(primaryConfig); when(configService.resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY)).thenReturn(Optional.empty()); when(configService.resolveAdminFallbackConfig(primaryConfig, handle)).thenReturn(Optional.empty()); - when(configService.extractPrompt(anyString())).thenAnswer(invocation -> invocation.getArgument(0)); - - // Mock registry to return null factory for invalid provider - this should cause IllegalArgumentException - when(registry.getFactory(any())).thenReturn(null); - - // Mock executorService to return a proper Mono to avoid NullPointerException - when(executorService.execute(any(), any(), any())).thenReturn(Mono.error(new IllegalArgumentException("AI model factory not found"))); StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule)) - .expectError(IllegalArgumentException.class) + .expectErrorMatches(e -> e instanceof IllegalArgumentException + && e.getMessage().contains("baseUrl must not be empty")) .verify(); } @@ -325,10 +310,12 @@ public void testExecutorServiceError() { final AiProxyHandle handle = new AiProxyHandle(); final AiCommonConfig primaryConfig = new AiCommonConfig(); primaryConfig.setProvider(AiModelProviderEnum.OPEN_AI.getName()); + primaryConfig.setBaseUrl("https://api.openai.com"); + primaryConfig.setApiKey("test-key"); final RuntimeException exception = new RuntimeException("AI execution failed"); setupSuccessMocks(handle, primaryConfig, Optional.empty()); - when(executorService.execute(any(), any(), any())).thenReturn(Mono.error(exception)); + when(executorService.executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class))).thenReturn(Mono.error(exception)); StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule)) .expectErrorMatches(exception::equals) @@ -337,6 +324,6 @@ public void testExecutorServiceError() { verify(configService).resolvePrimaryConfig(handle); verify(configService).resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY); verify(configService).resolveAdminFallbackConfig(primaryConfig, handle); - verify(executorService).execute(any(), any(), any()); + verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class)); } } \ No newline at end of file diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorServiceTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorServiceTest.java index 7bd2bbbdac0d..46d3719cab2a 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorServiceTest.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorServiceTest.java @@ -20,21 +20,19 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; -import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; import org.springframework.ai.retry.NonTransientAiException; -import org.springframework.web.client.RestClientException; +import org.springframework.http.ResponseEntity; +import reactor.core.publisher.Flux; import reactor.test.StepVerifier; import java.util.Optional; -import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -42,102 +40,96 @@ @ExtendWith(MockitoExtension.class) public class AiProxyExecutorServiceTest { - @Mock - private ChatModel mainChatModel; - - @Mock - private ChatModel fallbackChatModel; - - private ChatClient mainClient; - - private ChatClient fallbackClient; - private AiProxyExecutorService executorService; @BeforeEach void setUp() { - executorService = spy(new AiProxyExecutorService()); - mainClient = ChatClient.create(mainChatModel); - fallbackClient = ChatClient.create(fallbackChatModel); + executorService = new AiProxyExecutorService(); } @Test - void testExecuteSuccessOnFirstAttempt() { - final ChatResponse successResponse = mock(ChatResponse.class); - when(mainChatModel.call(any(Prompt.class))).thenReturn(successResponse); - - StepVerifier.create(executorService.execute(mainClient, Optional.empty(), "request")) - .expectNext(successResponse) + void testExecuteDirectStreamSuccess() { + final OpenAiApi mainApi = mock(OpenAiApi.class); + final ChatCompletionChunk chunk = mock(ChatCompletionChunk.class); + final ChatCompletionRequest request = mock(ChatCompletionRequest.class); + when(mainApi.chatCompletionStream(request)).thenReturn(Flux.just(chunk)); + + StepVerifier.create(executorService.executeDirectStream(mainApi, Optional.empty(), request)) + .expectNext(chunk) .verifyComplete(); - verify(mainChatModel, times(1)).call(any(Prompt.class)); + verify(mainApi, times(1)).chatCompletionStream(request); } @Test - void testExecuteRetryOnRestClientExceptionAndSucceed() { - final ChatResponse successResponse = mock(ChatResponse.class); - when(mainChatModel.call(any(Prompt.class))) - .thenThrow(new RestClientException("transient error")) - .thenReturn(successResponse); - - StepVerifier.create(executorService.execute(mainClient, Optional.empty(), "request")) - .expectNext(successResponse) - .verifyComplete(); + void testExecuteDirectStreamErrorWithNoFallback() { + final OpenAiApi mainApi = mock(OpenAiApi.class); + final ChatCompletionRequest request = mock(ChatCompletionRequest.class); + when(mainApi.chatCompletionStream(request)).thenAnswer(inv -> Flux.error(new RuntimeException("upstream error"))); - verify(mainChatModel, times(2)).call(any(Prompt.class)); + StepVerifier.create(executorService.executeDirectStream(mainApi, Optional.empty(), request)) + .expectError(NonTransientAiException.class) + .verify(); } @Test - void testExecuteRetryExhaustedThenFallback() { - final ChatResponse fallbackResponse = mock(ChatResponse.class); - when(mainChatModel.call(any(Prompt.class))).thenThrow(new RestClientException("transient error")); - when(fallbackChatModel.call(any(Prompt.class))).thenReturn(fallbackResponse); + void testExecuteDirectStreamFallbackSuccess() { + final OpenAiApi mainApi = mock(OpenAiApi.class); + final OpenAiApi fallbackApi = mock(OpenAiApi.class); + final ChatCompletionRequest request = mock(ChatCompletionRequest.class); + final ChatCompletionChunk fallbackChunk = mock(ChatCompletionChunk.class); - StepVerifier.create(executorService.execute(mainClient, Optional.of(fallbackClient), "request")) - .expectNext(fallbackResponse) + when(mainApi.chatCompletionStream(request)).thenAnswer(inv -> Flux.error(new RuntimeException("upstream error"))); + when(fallbackApi.chatCompletionStream(request)).thenReturn(Flux.just(fallbackChunk)); + + StepVerifier.create(executorService.executeDirectStream(mainApi, Optional.of(fallbackApi), request)) + .expectNext(fallbackChunk) .verifyComplete(); - verify(mainChatModel, times(4)).call(any(Prompt.class)); - verify(fallbackChatModel, times(1)).call(any(Prompt.class)); + verify(fallbackApi, times(1)).chatCompletionStream(request); } @Test - void testExecuteNonTransientExceptionTriggersFallbackDirectly() { - final ChatResponse fallbackResponse = mock(ChatResponse.class); - when(mainChatModel.call(any(Prompt.class))).thenThrow(new NonTransientAiException("non-transient error")); - when(fallbackChatModel.call(any(Prompt.class))).thenReturn(fallbackResponse); - - StepVerifier.create(executorService.execute(mainClient, Optional.of(fallbackClient), "request")) - .expectNext(fallbackResponse) + void testExecuteDirectCallSuccess() { + final OpenAiApi mainApi = mock(OpenAiApi.class); + final ChatCompletionRequest request = mock(ChatCompletionRequest.class); + final ChatCompletion completion = mock(ChatCompletion.class); + final ResponseEntity responseEntity = ResponseEntity.ok(completion); + when(mainApi.chatCompletionEntity(request)).thenReturn(responseEntity); + + StepVerifier.create(executorService.executeDirectCall(mainApi, Optional.empty(), request)) + .expectNext(responseEntity) .verifyComplete(); - verify(mainChatModel, times(1)).call(any(Prompt.class)); - verify(fallbackChatModel, times(1)).call(any(Prompt.class)); + verify(mainApi, times(1)).chatCompletionEntity(request); } @Test - void testExecuteFallbackFails() { - final RestClientException fallbackException = new RestClientException("fallback failed"); - when(mainChatModel.call(any(Prompt.class))).thenThrow(new NonTransientAiException("non-transient error")); - when(fallbackChatModel.call(any(Prompt.class))).thenThrow(fallbackException); + void testExecuteDirectCallErrorWithNoFallback() { + final OpenAiApi mainApi = mock(OpenAiApi.class); + final ChatCompletionRequest request = mock(ChatCompletionRequest.class); + when(mainApi.chatCompletionEntity(request)).thenThrow(new RuntimeException("upstream error")); - StepVerifier.create(executorService.execute(mainClient, Optional.of(fallbackClient), "request")) - .expectErrorMatches(e -> e == fallbackException) + StepVerifier.create(executorService.executeDirectCall(mainApi, Optional.empty(), request)) + .expectError(NonTransientAiException.class) .verify(); - - verify(mainChatModel, times(1)).call(any(Prompt.class)); - verify(fallbackChatModel, times(1)).call(any(Prompt.class)); } @Test - void testExecuteNoFallbackProvided() { - final NonTransientAiException mainException = new NonTransientAiException("non-transient error"); - when(mainChatModel.call(any(Prompt.class))).thenThrow(mainException); + void testExecuteDirectCallFallbackSuccess() { + final OpenAiApi mainApi = mock(OpenAiApi.class); + final OpenAiApi fallbackApi = mock(OpenAiApi.class); + final ChatCompletionRequest request = mock(ChatCompletionRequest.class); + final ChatCompletion fallbackCompletion = mock(ChatCompletion.class); + final ResponseEntity fallbackResponse = ResponseEntity.ok(fallbackCompletion); - StepVerifier.create(executorService.execute(mainClient, Optional.empty(), "request")) - .expectErrorMatches(e -> e == mainException) - .verify(); + when(mainApi.chatCompletionEntity(request)).thenThrow(new RuntimeException("upstream error")); + when(fallbackApi.chatCompletionEntity(request)).thenReturn(fallbackResponse); + + StepVerifier.create(executorService.executeDirectCall(mainApi, Optional.of(fallbackApi), request)) + .expectNext(fallbackResponse) + .verifyComplete(); - verify(mainChatModel, times(1)).call(any(Prompt.class)); + verify(fallbackApi, times(1)).chatCompletionEntity(request); } } diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriberTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriberTest.java index c863ba20f48a..4827c5337240 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriberTest.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriberTest.java @@ -19,17 +19,12 @@ import org.apache.shenyu.common.dto.ProxyApiKeyData; import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.AiProxyApiKeyCache; -import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.ChatClientCache; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -import static org.mockito.Mockito.mock; - class CommonAiProxyApiKeyDataSubscriberTest { - private final ChatClientCache chatClientCache = mock(ChatClientCache.class); - @AfterEach void cleanup() { AiProxyApiKeyCache.getInstance().refresh(); @@ -37,7 +32,7 @@ void cleanup() { @Test void testOnSubscribeCachesEnabled() { - CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache); + CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(); ProxyApiKeyData data = ProxyApiKeyData.builder() .selectorId("test-selector") .proxyApiKey("proxy-sub") @@ -51,7 +46,7 @@ void testOnSubscribeCachesEnabled() { @Test void testOnSubscribeIgnoresDisabled() { - CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache); + CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(); ProxyApiKeyData data = ProxyApiKeyData.builder() .selectorId("test-selector") .proxyApiKey("proxy-disabled") @@ -65,7 +60,7 @@ void testOnSubscribeIgnoresDisabled() { @Test void testUnSubscribeRemoves() { - CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache); + CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(); ProxyApiKeyData data = ProxyApiKeyData.builder() .selectorId("test-selector") .proxyApiKey("proxy-unsub") @@ -81,7 +76,7 @@ void testUnSubscribeRemoves() { @Test void testRefreshClearsCache() { - CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache); + CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(); ProxyApiKeyData data = ProxyApiKeyData.builder() .selectorId("test-selector") .proxyApiKey("proxy-refresh") @@ -98,14 +93,14 @@ void testRefreshClearsCache() { @Test void testOnSubscribeNullDataNoThrow() { - CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache); + CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(); subscriber.onSubscribe(null); Assertions.assertEquals(0, AiProxyApiKeyCache.getInstance().size()); } @Test void testOnSubscribeNullKeyIgnored() { - CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache); + CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(); ProxyApiKeyData data = ProxyApiKeyData.builder() .selectorId("test-selector") .proxyApiKey(null) @@ -119,14 +114,14 @@ void testOnSubscribeNullKeyIgnored() { @Test void testUnSubscribeNullDataNoThrow() { - CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache); + CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(); subscriber.unSubscribe(null); Assertions.assertEquals(0, AiProxyApiKeyCache.getInstance().size()); } @Test void testUnSubscribeNullKeyNoThrow() { - CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache); + CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(); ProxyApiKeyData data = ProxyApiKeyData.builder() .selectorId("test-selector") .proxyApiKey(null) @@ -140,7 +135,7 @@ void testUnSubscribeNullKeyNoThrow() { @Test void testDuplicateSubscribeOverrides() { - CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache); + CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(); ProxyApiKeyData v1 = ProxyApiKeyData.builder() .selectorId("test-selector") .proxyApiKey("dup-key") @@ -161,48 +156,4 @@ void testDuplicateSubscribeOverrides() { subscriber.onSubscribe(v2); Assertions.assertEquals("real-2", AiProxyApiKeyCache.getInstance().getRealApiKey("test-selector", "dup-key")); } - - @Test - void testSubscribeDisabledDoesNotOverrideExisting() { - CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache); - ProxyApiKeyData enabled = ProxyApiKeyData.builder() - .selectorId("test-selector") - .proxyApiKey("keep-key") - .realApiKey("real-keep") - .enabled(Boolean.TRUE) - .namespaceId("default") - .build(); - subscriber.onSubscribe(enabled); - Assertions.assertEquals("real-keep", AiProxyApiKeyCache.getInstance().getRealApiKey("test-selector", "keep-key")); - - ProxyApiKeyData disabled = ProxyApiKeyData.builder() - .selectorId("test-selector") - .proxyApiKey("keep-key") - .realApiKey("real-new") - .enabled(Boolean.FALSE) - .namespaceId("default") - .build(); - subscriber.onSubscribe(disabled); - // disabled subscribe should not override existing cached mapping - Assertions.assertEquals("real-keep", AiProxyApiKeyCache.getInstance().getRealApiKey("test-selector", "keep-key")); - } - - @Test - void testVeryLongProxyKey() { - CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache); - StringBuilder sb = new StringBuilder(); - for (int i = 0; i < 300; i++) { - sb.append('A' + (i % 26)); - } - String longKey = sb.toString(); - ProxyApiKeyData data = ProxyApiKeyData.builder() - .selectorId("test-selector") - .proxyApiKey(longKey) - .realApiKey("real-long") - .enabled(Boolean.TRUE) - .namespaceId("default") - .build(); - subscriber.onSubscribe(data); - Assertions.assertEquals("real-long", AiProxyApiKeyCache.getInstance().getRealApiKey("test-selector", longKey)); - } -} \ No newline at end of file +} diff --git a/shenyu-spring-boot-starter/shenyu-spring-boot-starter-plugin/shenyu-spring-boot-starter-plugin-ai-proxy/src/main/java/org/apache/shenyu/springboot/starter/plugin/ai/proxy/AiProxyPluginConfiguration.java b/shenyu-spring-boot-starter/shenyu-spring-boot-starter-plugin/shenyu-spring-boot-starter-plugin-ai-proxy/src/main/java/org/apache/shenyu/springboot/starter/plugin/ai/proxy/AiProxyPluginConfiguration.java index d017cfabd8fc..2bcb70aa1dae 100644 --- a/shenyu-spring-boot-starter/shenyu-spring-boot-starter-plugin/shenyu-spring-boot-starter-plugin-ai-proxy/src/main/java/org/apache/shenyu/springboot/starter/plugin/ai/proxy/AiProxyPluginConfiguration.java +++ b/shenyu-spring-boot-starter/shenyu-spring-boot-starter-plugin/shenyu-spring-boot-starter-plugin-ai-proxy/src/main/java/org/apache/shenyu/springboot/starter/plugin/ai/proxy/AiProxyPluginConfiguration.java @@ -22,7 +22,6 @@ import org.apache.shenyu.plugin.ai.common.spring.ai.factory.OpenAiModelFactory; import org.apache.shenyu.plugin.ai.common.spring.ai.registry.AiModelFactoryRegistry; import org.apache.shenyu.plugin.ai.proxy.enhanced.AiProxyPlugin; -import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.ChatClientCache; import org.apache.shenyu.plugin.ai.proxy.enhanced.handler.AiProxyPluginHandler; import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyConfigService; import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyExecutorService; @@ -48,42 +47,30 @@ public class AiProxyPluginConfiguration { /** * Ai proxy plugin. * - * @param aiModelFactoryRegistry the aiModelFactoryRegistry * @param aiProxyConfigService the aiProxyConfigService * @param aiProxyExecutorService the aiProxyExecutorService - * @param chatClientCache the chatClientCache * @param aiProxyPluginHandler the aiProxyPluginHandler * @return the shenyu plugin */ @Bean public ShenyuPlugin aiProxyPlugin( - final AiModelFactoryRegistry aiModelFactoryRegistry, final AiProxyConfigService aiProxyConfigService, final AiProxyExecutorService aiProxyExecutorService, - final ChatClientCache chatClientCache, final AiProxyPluginHandler aiProxyPluginHandler) { return new AiProxyPlugin( - aiModelFactoryRegistry, aiProxyConfigService, aiProxyExecutorService, - chatClientCache, aiProxyPluginHandler); } /** * Ai proxy plugin handler. * - * @param chatClientCache the chatClientCache * @return the shenyu plugin handler */ @Bean - public AiProxyPluginHandler aiProxyPluginHandler(final ChatClientCache chatClientCache) { - return new AiProxyPluginHandler(chatClientCache); - } - - @Bean - public ChatClientCache chatClientCache() { - return new ChatClientCache(); + public AiProxyPluginHandler aiProxyPluginHandler() { + return new AiProxyPluginHandler(); } @Bean @@ -131,11 +118,10 @@ public DeepSeekModelFactory deepSeekModelFactory() { /** * Ai proxy api key data subscriber. * - * @param chatClientCache the chatClientCache * @return the subscriber */ @Bean - public AiProxyApiKeyDataSubscriber aiProxyApiKeyDataSubscriber(final ChatClientCache chatClientCache) { - return new CommonAiProxyApiKeyDataSubscriber(chatClientCache); + public AiProxyApiKeyDataSubscriber aiProxyApiKeyDataSubscriber() { + return new CommonAiProxyApiKeyDataSubscriber(); } } From 69d8ff6b394b102c0264c9eb45245789df9936f9 Mon Sep 17 00:00:00 2001 From: eye-gu <734164350@qq.com> Date: Wed, 13 May 2026 23:28:18 +0800 Subject: [PATCH 2/6] add test --- .../handler/AiProxyPluginHandlerTest.java | 119 ++++++++++++++++++ .../service/UpstreamErrorLoggerTest.java | 73 +++++++++++ 2 files changed, 192 insertions(+) create mode 100644 shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandlerTest.java create mode 100644 shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLoggerTest.java diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandlerTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandlerTest.java new file mode 100644 index 000000000000..248ae16ac883 --- /dev/null +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandlerTest.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.shenyu.plugin.ai.proxy.enhanced.handler; + +import org.apache.shenyu.common.constant.Constants; +import org.apache.shenyu.common.dto.SelectorData; +import org.apache.shenyu.common.dto.convert.rule.AiProxyHandle; +import org.apache.shenyu.common.enums.PluginEnum; +import org.apache.shenyu.plugin.base.utils.CacheKeyUtils; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; + +class AiProxyPluginHandlerTest { + + private AiProxyPluginHandler handler; + + @BeforeEach + void setUp() { + handler = new AiProxyPluginHandler(); + } + + @Test + void testPluginNamed() { + assertEquals(PluginEnum.AI_PROXY.getName(), handler.pluginNamed()); + } + + @Test + void testHandlerSelectorCachesHandle() { + SelectorData selector = new SelectorData(); + selector.setId("sel-1"); + selector.setHandle("{\"provider\":\"open_ai\",\"baseUrl\":\"https://api.openai.com\",\"apiKey\":\"sk-test\",\"model\":\"gpt-4\"}"); + + handler.handlerSelector(selector); + + String key = CacheKeyUtils.INST.getKey("sel-1", Constants.DEFAULT_RULE); + AiProxyHandle cached = handler.getSelectorCachedHandle().obtainHandle(key); + assertNotNull(cached); + assertEquals("open_ai", cached.getProvider()); + assertEquals("https://api.openai.com", cached.getBaseUrl()); + assertEquals("sk-test", cached.getApiKey()); + assertEquals("gpt-4", cached.getModel()); + } + + @Test + void testHandlerSelectorWithNullHandle() { + SelectorData selector = new SelectorData(); + selector.setId("sel-2"); + selector.setHandle(null); + + handler.handlerSelector(selector); + + String key = CacheKeyUtils.INST.getKey("sel-2", Constants.DEFAULT_RULE); + assertNull(handler.getSelectorCachedHandle().obtainHandle(key)); + } + + @Test + void testHandlerSelectorWithEmptyHandle() { + SelectorData selector = new SelectorData(); + selector.setId("sel-3"); + selector.setHandle(""); + + String key = CacheKeyUtils.INST.getKey("sel-3", Constants.DEFAULT_RULE); + assertNull(handler.getSelectorCachedHandle().obtainHandle(key)); + } + + @Test + void testHandlerSelectorWithFallbackNormalize() { + SelectorData selector = new SelectorData(); + selector.setId("sel-4"); + selector.setHandle("{\"provider\":\"open_ai\",\"fallbackEnabled\":\"true\",\"fallbackModel\":\"fallback-model\"}"); + + handler.handlerSelector(selector); + + String key = CacheKeyUtils.INST.getKey("sel-4", Constants.DEFAULT_RULE); + AiProxyHandle cached = handler.getSelectorCachedHandle().obtainHandle(key); + assertNotNull(cached); + assertNotNull(cached.getFallbackConfig()); + assertEquals("fallback-model", cached.getFallbackConfig().getModel()); + } + + @Test + void testRemoveSelector() { + SelectorData selector = new SelectorData(); + selector.setId("sel-5"); + selector.setHandle("{\"provider\":\"open_ai\",\"model\":\"gpt-4\"}"); + + handler.handlerSelector(selector); + + String key = CacheKeyUtils.INST.getKey("sel-5", Constants.DEFAULT_RULE); + assertNotNull(handler.getSelectorCachedHandle().obtainHandle(key)); + + handler.removeSelector(selector); + assertNull(handler.getSelectorCachedHandle().obtainHandle(key)); + } + + @Test + void testHandlerPluginDoesNothing() { + handler.handlerPlugin(null); + } +} diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLoggerTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLoggerTest.java new file mode 100644 index 000000000000..46d21f43c9a1 --- /dev/null +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLoggerTest.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.shenyu.plugin.ai.proxy.enhanced.service; + +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.springframework.http.HttpStatus; +import org.springframework.web.reactive.function.client.WebClientResponseException; + +import java.nio.charset.StandardCharsets; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +class UpstreamErrorLoggerTest { + + @Test + void testLogWithWebClientResponseException() { + Logger log = mock(Logger.class); + WebClientResponseException ex = WebClientResponseException.create( + 429, "Too Many Requests", null, "rate limited".getBytes(StandardCharsets.UTF_8), StandardCharsets.UTF_8); + + UpstreamErrorLogger.logUpstreamError(log, ex, "stream"); + + verify(log).error("[AiProxy] {} failed, status={}, upstreamBody={}", + "stream", HttpStatus.TOO_MANY_REQUESTS, "rate limited", ex); + } + + @Test + void testLogWithWrappedWebClientResponseException() { + Logger log = mock(Logger.class); + WebClientResponseException inner = WebClientResponseException.create( + 500, "Internal Server Error", null, "server error".getBytes(StandardCharsets.UTF_8), StandardCharsets.UTF_8); + RuntimeException wrapped = new RuntimeException("call failed", inner); + + UpstreamErrorLogger.logUpstreamError(log, wrapped, "non-stream"); + + verify(log).error("[AiProxy] {} failed, status={}, upstreamBody={}", + "non-stream", HttpStatus.INTERNAL_SERVER_ERROR, "server error", wrapped); + } + + @Test + void testLogWithGenericException() { + Logger log = mock(Logger.class); + RuntimeException ex = new RuntimeException("generic error"); + + UpstreamErrorLogger.logUpstreamError(log, ex, "stream"); + + verify(log).error("[AiProxy] {} failed", "stream", ex); + } + + @Test + void testLogWithNullExceptionDoesNotThrow() { + Logger log = mock(Logger.class); + assertDoesNotThrow(() -> UpstreamErrorLogger.logUpstreamError(log, null, "stream")); + } +} From b3aa313d61745b53a9da31cff31c69838191a130 Mon Sep 17 00:00:00 2001 From: eye-gu <734164350@qq.com> Date: Thu, 14 May 2026 18:19:51 +0800 Subject: [PATCH 3/6] use openai cache --- .../src/main/resources/application.yml | 2 +- .../ai/proxy/enhanced/AiProxyPlugin.java | 19 ++- .../proxy/enhanced/cache/OpenAiApiCache.java | 147 ++++++++++++++++++ .../handler/AiProxyPluginHandler.java | 12 ++ .../CommonAiProxyApiKeyDataSubscriber.java | 2 + 5 files changed, 175 insertions(+), 7 deletions(-) create mode 100644 shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCache.java diff --git a/shenyu-admin/src/main/resources/application.yml b/shenyu-admin/src/main/resources/application.yml index 6d2c3af3bfe2..47b4ab7cf3e2 100755 --- a/shenyu-admin/src/main/resources/application.yml +++ b/shenyu-admin/src/main/resources/application.yml @@ -21,7 +21,7 @@ spring: application: name: shenyu-admin profiles: - active: h2 + active: mysql thymeleaf: cache: true encoding: utf-8 diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java index d9e72c2b67f5..b1650f37f3f3 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java @@ -26,6 +26,7 @@ import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig; import org.apache.shenyu.plugin.ai.common.protocol.OpenAiProtocolAdapter; import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.AiProxyApiKeyCache; +import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.OpenAiApiCache; import org.apache.shenyu.plugin.ai.proxy.enhanced.handler.AiProxyPluginHandler; import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyConfigService; import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyExecutorService; @@ -152,9 +153,9 @@ private Mono handleStreamRequest( final String requestBody, final AiCommonConfig primaryConfig, final AiProxyHandle selectorHandle) { - final OpenAiApi mainApi = createOpenAiApi(primaryConfig); + final OpenAiApi mainApi = getCachedOpenAiApi(selector.getId(), "main", primaryConfig); final ChatCompletionRequest request = OpenAiProtocolAdapter.toChatCompletionRequest(requestBody, true, primaryConfig); - final Optional fallbackApi = resolveFallbackOpenAiApi(primaryConfig, selectorHandle, + final Optional fallbackApi = resolveFallbackOpenAiApi(selector.getId(), primaryConfig, selectorHandle, requestBody); final ServerHttpResponse response = exchange.getResponse(); response.getHeaders().setContentType(MediaType.TEXT_EVENT_STREAM); @@ -182,9 +183,9 @@ private Mono handleNonStreamRequest( final String requestBody, final AiCommonConfig primaryConfig, final AiProxyHandle selectorHandle) { - final OpenAiApi mainApi = createOpenAiApi(primaryConfig); + final OpenAiApi mainApi = getCachedOpenAiApi(selector.getId(), "main", primaryConfig); final ChatCompletionRequest request = OpenAiProtocolAdapter.toChatCompletionRequest(requestBody, false, primaryConfig); - final Optional fallbackApi = resolveFallbackOpenAiApi(primaryConfig, selectorHandle, + final Optional fallbackApi = resolveFallbackOpenAiApi(selector.getId(), primaryConfig, selectorHandle, requestBody); return aiProxyExecutorService @@ -199,6 +200,7 @@ private Mono handleNonStreamRequest( } private Optional resolveFallbackOpenAiApi( + final String selectorId, final AiCommonConfig primaryConfig, final AiProxyHandle selectorHandle, final String requestBody) { @@ -209,7 +211,7 @@ private Optional resolveFallbackOpenAiApi( if (LOG.isDebugEnabled()) { LOG.debug("[AiProxy] dynamic fallback config: {}", cfg); } - return createOpenAiApi(cfg); + return getCachedOpenAiApi(selectorId, "dynamicFallback", cfg); }) .or(() -> aiProxyConfigService .resolveAdminFallbackConfig(primaryConfig, selectorHandle) @@ -218,10 +220,15 @@ private Optional resolveFallbackOpenAiApi( if (LOG.isDebugEnabled()) { LOG.debug("[AiProxy] admin fallback config: {}", adminFallbackConfig); } - return createOpenAiApi(adminFallbackConfig); + return getCachedOpenAiApi(selectorId, "adminFallback", adminFallbackConfig); })); } + private OpenAiApi getCachedOpenAiApi(final String selectorId, final String type, final AiCommonConfig config) { + final String cacheKey = selectorId + "|" + type + "_" + config.hashCode(); + return OpenAiApiCache.getInstance().computeIfAbsent(cacheKey, () -> createOpenAiApi(config)); + } + private OpenAiApi createOpenAiApi(final AiCommonConfig config) { if (Objects.isNull(config.getBaseUrl()) || config.getBaseUrl().isEmpty()) { throw new IllegalArgumentException("baseUrl must not be empty"); diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCache.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCache.java new file mode 100644 index 000000000000..886b746ad94b --- /dev/null +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCache.java @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.shenyu.plugin.ai.proxy.enhanced.cache; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.openai.api.OpenAiApi; + +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; + +/** + * This is OpenAiApi cache. + * Caches OpenAiApi instances by config key to avoid recreating on every request. + */ +public final class OpenAiApiCache { + + private static final Logger LOG = LoggerFactory.getLogger(OpenAiApiCache.class); + + private static final OpenAiApiCache INSTANCE = new OpenAiApiCache(); + + private static final int MAX_CACHE_SIZE = getCacheSize(); + + private final Map openAiApiMap = new ConcurrentHashMap<>(); + + private final AtomicBoolean evictionInProgress = new AtomicBoolean(false); + + /** + * Instantiates a new OpenAiApi cache. + */ + public OpenAiApiCache() { + } + + /** + * Gets instance. + * + * @return singleton + */ + public static OpenAiApiCache getInstance() { + return INSTANCE; + } + + private static int getCacheSize() { + String value = System.getProperty("shenyu.plugin.ai.proxy.enhanced.cache.maxSize", + System.getenv("SHENYU_PLUGIN_AI_PROXY_ENHANCED_CACHE_MAXSIZE")); + if (Objects.nonNull(value)) { + try { + return Integer.parseInt(value); + } catch (NumberFormatException e) { + LOG.warn("[OpenAiApiCache] Invalid cache size '{}', using default 500.", value); + } + } + return 500; + } + + /** + * Gets OpenAiApi or compute if absent. + * + * @param key the cache key (typically selectorId|configHash) + * @param openAiApiSupplier the supplier to create OpenAiApi if absent + * @return the cached or newly created OpenAiApi + */ + public OpenAiApi computeIfAbsent(final String key, final Supplier openAiApiSupplier) { + final int currentSize = openAiApiMap.size(); + if (currentSize > MAX_CACHE_SIZE) { + if (evictionInProgress.compareAndSet(false, true)) { + try { + synchronized (openAiApiMap) { + if (openAiApiMap.size() > MAX_CACHE_SIZE) { + evictOldestEntries(); + } + } + } finally { + evictionInProgress.set(false); + } + } + } + return openAiApiMap.computeIfAbsent(key, k -> openAiApiSupplier.get()); + } + + /** + * Evict oldest entries when cache size exceeds limit. + * Removes approximately 25% of entries to avoid thundering herd problem. + */ + private void evictOldestEntries() { + final int currentSize = openAiApiMap.size(); + if (currentSize <= MAX_CACHE_SIZE) { + return; + } + + final int evictCount = Math.max(10, currentSize / 4); + LOG.warn("[OpenAiApiCache] Cache size {} exceeded limit {}, evicting {} oldest entries", + currentSize, MAX_CACHE_SIZE, evictCount); + + int removed = 0; + for (final String key : openAiApiMap.keySet()) { + if (removed >= evictCount) { + break; + } + openAiApiMap.remove(key); + removed++; + } + + LOG.info("[OpenAiApiCache] Evicted {} entries, cache size now: {}", removed, openAiApiMap.size()); + } + + /** + * Removes all cached OpenAiApi instances associated with a selector ID + * (by prefix matching "selectorId|"). + * + * @param selectorId the selector id + */ + public void remove(final String selectorId) { + if (Objects.isNull(selectorId)) { + return; + } + final String prefix = selectorId + "|"; + openAiApiMap.keySet().removeIf(k -> k.equals(selectorId) || k.startsWith(prefix)); + LOG.info("[OpenAiApiCache] invalidate selectorId={} (by prefix)", selectorId); + } + + /** + * Clear all cached OpenAiApi instances. + */ + public void clearAll() { + openAiApiMap.clear(); + LOG.info("[OpenAiApiCache] cleared all cached OpenAiApi instances"); + } +} diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java index 8b8868a04560..df9dc1a8cb44 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java @@ -23,6 +23,8 @@ import org.apache.shenyu.common.dto.convert.rule.AiProxyHandle; import org.apache.shenyu.common.enums.PluginEnum; import org.apache.shenyu.common.utils.GsonUtils; +import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.AiProxyApiKeyCache; +import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.OpenAiApiCache; import org.apache.shenyu.plugin.base.cache.CommonHandleCache; import org.apache.shenyu.plugin.base.handler.PluginDataHandler; import org.apache.shenyu.plugin.base.utils.CacheKeyUtils; @@ -36,6 +38,9 @@ public class AiProxyPluginHandler implements PluginDataHandler { private final CommonHandleCache selectorCachedHandle = new CommonHandleCache<>(); + public AiProxyPluginHandler() { + } + @Override public void handlerPlugin(final PluginData pluginData) { // Note: The logic for handling global plugin configuration with Singleton has been removed @@ -45,6 +50,10 @@ public void handlerPlugin(final PluginData pluginData) { @Override public void handlerSelector(final SelectorData selectorData) { + // Invalidate the cache first when the selector is updated. + OpenAiApiCache.getInstance().remove(selectorData.getId()); + // Do NOT remove AiProxyApiKeyCache here. Admin will push updated AI_PROXY_API_KEY events + // with refreshed realApiKey after selector changes. Removing here introduces a window of misses. if (Objects.isNull(selectorData.getHandle())) { return; } @@ -56,6 +65,9 @@ public void handlerSelector(final SelectorData selectorData) { @Override public void removeSelector(final SelectorData selectorData) { + // Invalidate the cache when the selector is removed. + OpenAiApiCache.getInstance().remove(selectorData.getId()); + AiProxyApiKeyCache.getInstance().removeBySelectorId(selectorData.getId()); selectorCachedHandle .removeHandle(CacheKeyUtils.INST.getKey(selectorData.getId(), Constants.DEFAULT_RULE)); } diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriber.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriber.java index d25434881658..bde74f48668e 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriber.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriber.java @@ -19,6 +19,7 @@ import org.apache.shenyu.common.dto.ProxyApiKeyData; import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.AiProxyApiKeyCache; +import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.OpenAiApiCache; import org.apache.shenyu.sync.data.api.AiProxyApiKeyDataSubscriber; import java.util.Objects; @@ -47,5 +48,6 @@ public void unSubscribe(final ProxyApiKeyData data) { @Override public void refresh() { AiProxyApiKeyCache.getInstance().refresh(); + OpenAiApiCache.getInstance().clearAll(); } } From 225cf0aa7a16d416d6da2534e2503545087e0d61 Mon Sep 17 00:00:00 2001 From: eye-gu <734164350@qq.com> Date: Thu, 14 May 2026 18:52:41 +0800 Subject: [PATCH 4/6] fix --- shenyu-admin/src/main/resources/application.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shenyu-admin/src/main/resources/application.yml b/shenyu-admin/src/main/resources/application.yml index 47b4ab7cf3e2..6d2c3af3bfe2 100755 --- a/shenyu-admin/src/main/resources/application.yml +++ b/shenyu-admin/src/main/resources/application.yml @@ -21,7 +21,7 @@ spring: application: name: shenyu-admin profiles: - active: mysql + active: h2 thymeleaf: cache: true encoding: utf-8 From 37de2d9e0f57844353713886530be90e0683b122 Mon Sep 17 00:00:00 2001 From: eye-gu <734164350@qq.com> Date: Fri, 15 May 2026 13:25:56 +0800 Subject: [PATCH 5/6] fix review --- .../protocol/OpenAiProtocolAdapter.java | 41 +++++++-- .../protocol/OpenAiProtocolAdapterTest.java | 21 ++++- .../ai/proxy/enhanced/AiProxyPlugin.java | 35 +++++--- .../proxy/enhanced/cache/OpenAiApiCache.java | 8 +- .../handler/AiProxyPluginHandler.java | 2 +- .../service/AiProxyExecutorService.java | 85 +++++++++++++++---- .../enhanced/service/UpstreamErrorLogger.java | 17 +++- .../ai/proxy/enhanced/AiProxyPluginTest.java | 18 ++-- .../handler/AiProxyPluginHandlerTest.java | 2 + .../service/AiProxyExecutorServiceTest.java | 32 ++++--- 10 files changed, 201 insertions(+), 60 deletions(-) diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapter.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapter.java index ce44a5d3178b..a23eb9cd2ad2 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapter.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapter.java @@ -17,6 +17,7 @@ package org.apache.shenyu.plugin.ai.common.protocol; +import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.node.ObjectNode; import org.apache.shenyu.common.utils.JsonUtils; @@ -27,9 +28,18 @@ /** * Adapts between OpenAI Chat Completions wire format and internal representations. + * + *

Note: deserialization into Spring AI's {@link ChatCompletionRequest} preserves fields + * modeled by that class (messages, model, temperature, maxTokens, stream, tools, etc.). + * Provider-specific extension fields not modeled by {@code ChatCompletionRequest} are dropped + * during deserialization. For the OpenAI-compatible providers this plugin currently supports, + * all required fields are covered. */ public final class OpenAiProtocolAdapter { + private static final com.fasterxml.jackson.databind.ObjectMapper MAPPER = + new com.fasterxml.jackson.databind.ObjectMapper(); + private OpenAiProtocolAdapter() { } @@ -45,7 +55,7 @@ public static boolean resolveStream(final String requestBody, final Boolean fall if (Objects.isNull(requestBody) || requestBody.isEmpty()) { return Boolean.TRUE.equals(fallbackStream); } - final JsonNode root = JsonUtils.toJsonNode(requestBody); + final JsonNode root = parseStrict(requestBody); if (Objects.isNull(root)) { return Boolean.TRUE.equals(fallbackStream); } @@ -56,8 +66,8 @@ public static boolean resolveStream(final String requestBody, final Boolean fall } /** - * Parse raw request body directly into ChatCompletionRequest, preserving ALL fields - * including reasoning_content in assistant messages. + * Parse raw request body directly into ChatCompletionRequest, preserving fields + * modeled by Spring AI (including reasoning_content in assistant messages). * *

Spring AI's createRequest() loses reasoning_content, refusal, and annotations * when reconstructing ChatCompletionMessage from AssistantMessage. @@ -74,7 +84,7 @@ public static ChatCompletionRequest toChatCompletionRequest(final String request } /** - * Parse raw request body directly into ChatCompletionRequest, preserving ALL fields. + * Parse raw request body directly into ChatCompletionRequest, preserving modeled fields. * For model, temperature, max_tokens: client request takes priority; * if missing, falls back to the corresponding field in fallbackConfig. * @@ -88,14 +98,19 @@ public static ChatCompletionRequest toChatCompletionRequest(final String request if (Objects.isNull(requestBody) || requestBody.isEmpty()) { throw new IllegalArgumentException("Request body must not be empty"); } - final JsonNode root = JsonUtils.toJsonNode(requestBody); + final JsonNode root = parseStrict(requestBody); if (Objects.isNull(root) || !root.isObject()) { throw new IllegalArgumentException("Invalid request body: expected a JSON object"); } final ObjectNode mutableRoot = (ObjectNode) root; if (root.hasNonNull("max_completion_tokens") && !root.hasNonNull("max_tokens")) { - mutableRoot.put("max_tokens", root.get("max_completion_tokens").asInt()); + final JsonNode tokenNode = root.get("max_completion_tokens"); + if (!tokenNode.isNumber()) { + throw new IllegalArgumentException( + "max_completion_tokens must be a number, got: " + tokenNode.getNodeType()); + } + mutableRoot.put("max_tokens", tokenNode.asInt()); mutableRoot.remove("max_completion_tokens"); } @@ -113,6 +128,18 @@ public static ChatCompletionRequest toChatCompletionRequest(final String request mutableRoot.put("stream", stream); - return JsonUtils.jsonToObject(mutableRoot.toString(), ChatCompletionRequest.class); + final ChatCompletionRequest result = JsonUtils.jsonToObject(mutableRoot.toString(), ChatCompletionRequest.class); + if (Objects.isNull(result)) { + throw new IllegalArgumentException("Failed to parse request body into ChatCompletionRequest"); + } + return result; + } + + private static JsonNode parseStrict(final String json) { + try { + return MAPPER.readTree(json); + } catch (JsonProcessingException e) { + return null; + } } } diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/test/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapterTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/test/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapterTest.java index 566824318014..adfbf4d4bf1f 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/test/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapterTest.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/test/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapterTest.java @@ -157,7 +157,7 @@ void testFallbackTemperatureWhenClientMissing() { } @Test - void testNoFallbackTemperatureWhenConfigNull() { + void testNoFallbackWhenConfigTemperatureIsNull() { final AiCommonConfig config = new AiCommonConfig(); config.setTemperature(null); @@ -235,4 +235,23 @@ void testPartialClientFieldsWithPartialFallback() { assertEquals(0.1, req.temperature(), 0.001); assertEquals(2048, req.maxTokens()); } + + @Test + void testMalformedJsonThrows() { + assertThrows(IllegalArgumentException.class, + () -> OpenAiProtocolAdapter.toChatCompletionRequest("not valid json{{{", false)); + } + + @Test + void testMaxCompletionTokensNonNumberThrows() { + final String body = "{\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"max_completion_tokens\":\"abc\"}"; + assertThrows(IllegalArgumentException.class, + () -> OpenAiProtocolAdapter.toChatCompletionRequest(body, false)); + } + + @Test + void testResolveStreamMalformedJsonFallsBack() { + assertFalse(OpenAiProtocolAdapter.resolveStream("not json{{{", false)); + assertTrue(OpenAiProtocolAdapter.resolveStream("not json{{{", true)); + } } diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java index b1650f37f3f3..517ff280bc4d 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java @@ -30,6 +30,7 @@ import org.apache.shenyu.plugin.ai.proxy.enhanced.handler.AiProxyPluginHandler; import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyConfigService; import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyExecutorService; +import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyExecutorService.FallbackContext; import org.apache.shenyu.plugin.ai.proxy.enhanced.service.UpstreamErrorLogger; import org.apache.shenyu.plugin.api.ShenyuPluginChain; import org.apache.shenyu.plugin.api.utils.WebFluxResultUtils; @@ -140,7 +141,8 @@ protected Mono doExecute( } } - if (OpenAiProtocolAdapter.resolveStream(requestBody, primaryConfig.getStream())) { + final boolean stream = OpenAiProtocolAdapter.resolveStream(requestBody, primaryConfig.getStream()); + if (stream) { return handleStreamRequest(exchange, selector, requestBody, primaryConfig, selectorHandle); } return handleNonStreamRequest(exchange, selector, requestBody, primaryConfig, selectorHandle); @@ -155,13 +157,13 @@ private Mono handleStreamRequest( final AiProxyHandle selectorHandle) { final OpenAiApi mainApi = getCachedOpenAiApi(selector.getId(), "main", primaryConfig); final ChatCompletionRequest request = OpenAiProtocolAdapter.toChatCompletionRequest(requestBody, true, primaryConfig); - final Optional fallbackApi = resolveFallbackOpenAiApi(selector.getId(), primaryConfig, selectorHandle, - requestBody); + final Optional fallbackCtx = resolveFallbackContext( + selector.getId(), primaryConfig, selectorHandle, requestBody); final ServerHttpResponse response = exchange.getResponse(); response.getHeaders().setContentType(MediaType.TEXT_EVENT_STREAM); final Flux chunkFlux = aiProxyExecutorService.executeDirectStream( - mainApi, fallbackApi, request); + mainApi, fallbackCtx, request, requestBody, true); final Flux sseFlux = chunkFlux.map( chunk -> { @@ -185,11 +187,11 @@ private Mono handleNonStreamRequest( final AiProxyHandle selectorHandle) { final OpenAiApi mainApi = getCachedOpenAiApi(selector.getId(), "main", primaryConfig); final ChatCompletionRequest request = OpenAiProtocolAdapter.toChatCompletionRequest(requestBody, false, primaryConfig); - final Optional fallbackApi = resolveFallbackOpenAiApi(selector.getId(), primaryConfig, selectorHandle, - requestBody); + final Optional fallbackCtx = resolveFallbackContext( + selector.getId(), primaryConfig, selectorHandle, requestBody); return aiProxyExecutorService - .executeDirectCall(mainApi, fallbackApi, request) + .executeDirectCall(mainApi, fallbackCtx, request, requestBody) .flatMap( responseEntity -> { final String responseJson = JsonUtils.toJson(responseEntity.getBody()); @@ -199,7 +201,7 @@ private Mono handleNonStreamRequest( .doOnError(e -> logUpstreamError(e, "non-stream")); } - private Optional resolveFallbackOpenAiApi( + private Optional resolveFallbackContext( final String selectorId, final AiCommonConfig primaryConfig, final AiProxyHandle selectorHandle, @@ -211,7 +213,7 @@ private Optional resolveFallbackOpenAiApi( if (LOG.isDebugEnabled()) { LOG.debug("[AiProxy] dynamic fallback config: {}", cfg); } - return getCachedOpenAiApi(selectorId, "dynamicFallback", cfg); + return new FallbackContext(createOpenAiApi(cfg), cfg); }) .or(() -> aiProxyConfigService .resolveAdminFallbackConfig(primaryConfig, selectorHandle) @@ -220,15 +222,26 @@ private Optional resolveFallbackOpenAiApi( if (LOG.isDebugEnabled()) { LOG.debug("[AiProxy] admin fallback config: {}", adminFallbackConfig); } - return getCachedOpenAiApi(selectorId, "adminFallback", adminFallbackConfig); + return new FallbackContext( + getCachedOpenAiApi(selectorId, "adminFallback", adminFallbackConfig), + adminFallbackConfig); })); } private OpenAiApi getCachedOpenAiApi(final String selectorId, final String type, final AiCommonConfig config) { - final String cacheKey = selectorId + "|" + type + "_" + config.hashCode(); + final String cacheKey = selectorId + "|" + type + "_" + generateConfigCacheKey(config); return OpenAiApiCache.getInstance().computeIfAbsent(cacheKey, () -> createOpenAiApi(config)); } + private int generateConfigCacheKey(final AiCommonConfig config) { + return Objects.hash( + config.getBaseUrl(), + config.getModel(), + config.getTemperature(), + config.getMaxTokens() + ); + } + private OpenAiApi createOpenAiApi(final AiCommonConfig config) { if (Objects.isNull(config.getBaseUrl()) || config.getBaseUrl().isEmpty()) { throw new IllegalArgumentException("baseUrl must not be empty"); diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCache.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCache.java index 886b746ad94b..fdbe22f6e429 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCache.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCache.java @@ -97,8 +97,10 @@ public OpenAiApi computeIfAbsent(final String key, final Supplier ope } /** - * Evict oldest entries when cache size exceeds limit. - * Removes approximately 25% of entries to avoid thundering herd problem. + * Evict arbitrary entries when cache size exceeds limit. + * Note: ConcurrentHashMap iteration order is undefined, so evicted entries are arbitrary, + * not guaranteed to be the oldest. Removes approximately 25% of entries to avoid + * thundering herd problem. */ private void evictOldestEntries() { final int currentSize = openAiApiMap.size(); @@ -107,7 +109,7 @@ private void evictOldestEntries() { } final int evictCount = Math.max(10, currentSize / 4); - LOG.warn("[OpenAiApiCache] Cache size {} exceeded limit {}, evicting {} oldest entries", + LOG.warn("[OpenAiApiCache] Cache size {} exceeded limit {}, evicting {} arbitrary entries", currentSize, MAX_CACHE_SIZE, evictCount); int removed = 0; diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java index df9dc1a8cb44..cc8a42bbd4b8 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java @@ -54,7 +54,7 @@ public void handlerSelector(final SelectorData selectorData) { OpenAiApiCache.getInstance().remove(selectorData.getId()); // Do NOT remove AiProxyApiKeyCache here. Admin will push updated AI_PROXY_API_KEY events // with refreshed realApiKey after selector changes. Removing here introduces a window of misses. - if (Objects.isNull(selectorData.getHandle())) { + if (Objects.isNull(selectorData.getHandle()) || selectorData.getHandle().isEmpty()) { return; } AiProxyHandle aiProxyHandle = GsonUtils.getInstance().fromJson(selectorData.getHandle(), AiProxyHandle.class); diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java index 8502a772f43e..8fbd37be2186 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java @@ -17,6 +17,8 @@ package org.apache.shenyu.plugin.ai.proxy.enhanced.service; +import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig; +import org.apache.shenyu.plugin.ai.common.protocol.OpenAiProtocolAdapter; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.openai.api.OpenAiApi; @@ -25,12 +27,14 @@ import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; import org.springframework.ai.retry.NonTransientAiException; import org.springframework.http.ResponseEntity; +import org.springframework.web.reactive.function.client.WebClientResponseException; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; import reactor.util.retry.Retry; import java.time.Duration; +import java.util.Objects; import java.util.Optional; /** @@ -44,16 +48,20 @@ public class AiProxyExecutorService { * Execute a streaming AI call directly via {@link OpenAiApi}, bypassing Spring AI's * {@code createRequest()} which loses fields like {@code reasoning_content}. * - * @param mainApi the main OpenAiApi - * @param fallbackApiOpt the optional fallback OpenAiApi - * @param request the ChatCompletionRequest with all fields preserved + * @param mainApi the main OpenAiApi + * @param fallbackCtxOpt the optional fallback context (api + config) + * @param request the ChatCompletionRequest with all fields preserved + * @param requestBody the original request body for rebuilding fallback request + * @param stream whether this is a streaming request * @return a Flux of ChatCompletionChunk */ public Flux executeDirectStream(final OpenAiApi mainApi, - final Optional fallbackApiOpt, final ChatCompletionRequest request) { + final Optional fallbackCtxOpt, final ChatCompletionRequest request, + final String requestBody, final boolean stream) { return mainApi.chatCompletionStream(request) .doOnError(e -> UpstreamErrorLogger.logUpstreamError(LOG, e, "direct stream")) .retryWhen(Retry.max(1) + .filter(AiProxyExecutorService::isRetryable) .onRetryExhaustedThrow((retryBackoffSpec, retrySignal) -> { LOG.warn("Direct stream retry exhausted. Triggering fallback.", retrySignal.failure()); @@ -62,24 +70,26 @@ public Flux executeDirectStream(final OpenAiApi mainApi, retrySignal.failure()); })) .onErrorResume(NonTransientAiException.class, - throwable -> handleDirectFallbackStream(throwable, fallbackApiOpt, request)); + throwable -> handleDirectFallbackStream(throwable, fallbackCtxOpt, requestBody, stream)); } /** * Execute a non-streaming AI call directly via {@link OpenAiApi}. * - * @param mainApi the main OpenAiApi - * @param fallbackApiOpt the optional fallback OpenAiApi - * @param request the ChatCompletionRequest with all fields preserved + * @param mainApi the main OpenAiApi + * @param fallbackCtxOpt the optional fallback context (api + config) + * @param request the ChatCompletionRequest with all fields preserved + * @param requestBody the original request body for rebuilding fallback request * @return a Mono of ResponseEntity containing ChatCompletion */ public Mono> executeDirectCall(final OpenAiApi mainApi, - final Optional fallbackApiOpt, final ChatCompletionRequest request) { + final Optional fallbackCtxOpt, final ChatCompletionRequest request, + final String requestBody) { return Mono.fromCallable(() -> mainApi.chatCompletionEntity(request)) .subscribeOn(Schedulers.boundedElastic()) .doOnError(e -> UpstreamErrorLogger.logUpstreamError(LOG, e, "direct call")) .retryWhen(Retry.backoff(3, Duration.ofSeconds(1)) - .filter(throwable -> !(throwable instanceof NonTransientAiException)) + .filter(AiProxyExecutorService::isRetryable) .onRetryExhaustedThrow((retryBackoffSpec, retrySignal) -> { LOG.warn("Direct call retries exhausted after {} attempts. Triggering fallback.", retrySignal.totalRetries(), retrySignal.failure()); @@ -87,31 +97,72 @@ public Mono> executeDirectCall(final OpenAiApi ma retrySignal.failure()); })) .onErrorResume(NonTransientAiException.class, - throwable -> handleDirectFallbackCall(throwable, fallbackApiOpt, request)); + throwable -> handleDirectFallbackCall(throwable, fallbackCtxOpt, requestBody)); } private Flux handleDirectFallbackStream(final Throwable throwable, - final Optional fallbackApiOpt, final ChatCompletionRequest request) { + final Optional fallbackCtxOpt, final String requestBody, final boolean stream) { LOG.warn("Main direct stream failed, attempting fallback...", throwable); - if (fallbackApiOpt.isEmpty()) { + if (fallbackCtxOpt.isEmpty()) { return Flux.error(throwable); } + final FallbackContext ctx = fallbackCtxOpt.get(); LOG.info("Using fallback OpenAiApi for direct stream"); - return fallbackApiOpt.get().chatCompletionStream(request); + final ChatCompletionRequest fallbackRequest = OpenAiProtocolAdapter.toChatCompletionRequest( + requestBody, stream, ctx.config()); + return ctx.api().chatCompletionStream(fallbackRequest); } private Mono> handleDirectFallbackCall(final Throwable throwable, - final Optional fallbackApiOpt, final ChatCompletionRequest request) { + final Optional fallbackCtxOpt, final String requestBody) { LOG.warn("Main direct call failed, attempting fallback...", throwable); - if (fallbackApiOpt.isEmpty()) { + if (fallbackCtxOpt.isEmpty()) { return Mono.error(throwable); } + final FallbackContext ctx = fallbackCtxOpt.get(); LOG.info("Using fallback OpenAiApi for direct call"); - return Mono.fromCallable(() -> fallbackApiOpt.get().chatCompletionEntity(request)) + final ChatCompletionRequest fallbackRequest = OpenAiProtocolAdapter.toChatCompletionRequest( + requestBody, false, ctx.config()); + return Mono.fromCallable(() -> ctx.api().chatCompletionEntity(fallbackRequest)) .subscribeOn(Schedulers.boundedElastic()); } + + /** + * Determine if the error is retryable. + * Retries transient network errors and retryable HTTP statuses (429, 5xx). + * Non-retryable: NonTransientAiException, client errors (400/401/403/404). + */ + private static boolean isRetryable(final Throwable throwable) { + if (throwable instanceof NonTransientAiException) { + return false; + } + final WebClientResponseException webClientEx = findWebClientResponseException(throwable); + if (Objects.nonNull(webClientEx)) { + final int status = webClientEx.getStatusCode().value(); + return status == 429 || status >= 500; + } + return true; + } + + private static WebClientResponseException findWebClientResponseException(final Throwable e) { + Throwable current = e; + while (Objects.nonNull(current)) { + if (current instanceof WebClientResponseException ex) { + return ex; + } + current = current.getCause(); + } + return null; + } + + /** + * Container for fallback context: the OpenAiApi instance and the fallback config + * used to rebuild the request with fallback-specific model/temperature/maxTokens. + */ + public record FallbackContext(OpenAiApi api, AiCommonConfig config) { + } } diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLogger.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLogger.java index bee49aab7d0c..bae715f35026 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLogger.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLogger.java @@ -27,19 +27,34 @@ */ public final class UpstreamErrorLogger { + private static final int MAX_BODY_LOG_LENGTH = 512; + private UpstreamErrorLogger() { } public static void logUpstreamError(final Logger log, final Throwable e, final String mode) { + if (Objects.isNull(e)) { + return; + } final WebClientResponseException webClientEx = findWebClientResponseException(e); if (Objects.nonNull(webClientEx)) { log.error("[AiProxy] {} failed, status={}, upstreamBody={}", - mode, webClientEx.getStatusCode(), webClientEx.getResponseBodyAsString(), e); + mode, webClientEx.getStatusCode(), truncateBody(webClientEx.getResponseBodyAsString()), e); } else { log.error("[AiProxy] {} failed", mode, e); } } + private static String truncateBody(final String body) { + if (Objects.isNull(body)) { + return "null"; + } + if (body.length() <= MAX_BODY_LOG_LENGTH) { + return body; + } + return body.substring(0, MAX_BODY_LOG_LENGTH) + "...(truncated, total " + body.length() + " chars)"; + } + private static WebClientResponseException findWebClientResponseException(final Throwable e) { Throwable current = e; while (Objects.nonNull(current)) { diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java index dfbf95a1b0c9..b3bb711af1d4 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java @@ -131,8 +131,8 @@ private void setupSuccessMocks(final AiProxyHandle handle, final AiCommonConfig when(configService.resolvePrimaryConfig(handle)).thenReturn(primaryConfig); when(configService.resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY)).thenReturn(fallbackConfig); when(configService.resolveAdminFallbackConfig(primaryConfig, handle)).thenReturn(fallbackConfig); - when(executorService.executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class))).thenReturn(Mono.just(responseEntity)); - when(executorService.executeDirectStream(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class))).thenReturn(Flux.empty()); + when(executorService.executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class))).thenReturn(Mono.just(responseEntity)); + when(executorService.executeDirectStream(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class), any(Boolean.class))).thenReturn(Flux.empty()); } @Test @@ -151,7 +151,7 @@ public void testExecuteSuccess() { verify(configService).resolvePrimaryConfig(handle); verify(configService).resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY); verify(configService).resolveAdminFallbackConfig(primaryConfig, handle); - verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class)); + verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class)); } @Test @@ -172,7 +172,7 @@ public void testExecuteWithDynamicFallback() { StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule)) .verifyComplete(); - verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class)); + verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class)); } @Test @@ -251,7 +251,7 @@ public void testExecuteWithAdminFallback() { StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule)) .verifyComplete(); - verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class)); + verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class)); } @Test @@ -275,7 +275,7 @@ public void testCacheIsUsedForAdminFallbackClient() { when(configService.resolvePrimaryConfig(handle)).thenReturn(primaryConfig); when(configService.resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY)).thenReturn(Optional.empty()); when(configService.resolveAdminFallbackConfig(primaryConfig, handle)).thenReturn(Optional.of(fallbackConfig)); - when(executorService.executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class))).thenReturn(Mono.just(responseEntity)); + when(executorService.executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class))).thenReturn(Mono.just(responseEntity)); // Execute the test - focus on successful execution rather than cache verification StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule)) @@ -284,7 +284,7 @@ public void testCacheIsUsedForAdminFallbackClient() { // Verify that the configuration methods were called correctly verify(configService).resolvePrimaryConfig(handle); verify(configService).resolveAdminFallbackConfig(primaryConfig, handle); - verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class)); + verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class)); } @Test @@ -315,7 +315,7 @@ public void testExecutorServiceError() { final RuntimeException exception = new RuntimeException("AI execution failed"); setupSuccessMocks(handle, primaryConfig, Optional.empty()); - when(executorService.executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class))).thenReturn(Mono.error(exception)); + when(executorService.executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class))).thenReturn(Mono.error(exception)); StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule)) .expectErrorMatches(exception::equals) @@ -324,6 +324,6 @@ public void testExecutorServiceError() { verify(configService).resolvePrimaryConfig(handle); verify(configService).resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY); verify(configService).resolveAdminFallbackConfig(primaryConfig, handle); - verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class)); + verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class)); } } \ No newline at end of file diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandlerTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandlerTest.java index 248ae16ac883..5aa2a7a5edf2 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandlerTest.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandlerTest.java @@ -78,6 +78,8 @@ void testHandlerSelectorWithEmptyHandle() { selector.setId("sel-3"); selector.setHandle(""); + handler.handlerSelector(selector); + String key = CacheKeyUtils.INST.getKey("sel-3", Constants.DEFAULT_RULE); assertNull(handler.getSelectorCachedHandle().obtainHandle(key)); } diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorServiceTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorServiceTest.java index 46d3719cab2a..38597d0dbedf 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorServiceTest.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorServiceTest.java @@ -17,6 +17,7 @@ package org.apache.shenyu.plugin.ai.proxy.enhanced.service; +import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -32,6 +33,7 @@ import java.util.Optional; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -40,6 +42,8 @@ @ExtendWith(MockitoExtension.class) public class AiProxyExecutorServiceTest { + private static final String REQUEST_BODY = "{\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}]}"; + private AiProxyExecutorService executorService; @BeforeEach @@ -54,7 +58,7 @@ void testExecuteDirectStreamSuccess() { final ChatCompletionRequest request = mock(ChatCompletionRequest.class); when(mainApi.chatCompletionStream(request)).thenReturn(Flux.just(chunk)); - StepVerifier.create(executorService.executeDirectStream(mainApi, Optional.empty(), request)) + StepVerifier.create(executorService.executeDirectStream(mainApi, Optional.empty(), request, REQUEST_BODY, true)) .expectNext(chunk) .verifyComplete(); @@ -67,7 +71,7 @@ void testExecuteDirectStreamErrorWithNoFallback() { final ChatCompletionRequest request = mock(ChatCompletionRequest.class); when(mainApi.chatCompletionStream(request)).thenAnswer(inv -> Flux.error(new RuntimeException("upstream error"))); - StepVerifier.create(executorService.executeDirectStream(mainApi, Optional.empty(), request)) + StepVerifier.create(executorService.executeDirectStream(mainApi, Optional.empty(), request, REQUEST_BODY, true)) .expectError(NonTransientAiException.class) .verify(); } @@ -80,13 +84,17 @@ void testExecuteDirectStreamFallbackSuccess() { final ChatCompletionChunk fallbackChunk = mock(ChatCompletionChunk.class); when(mainApi.chatCompletionStream(request)).thenAnswer(inv -> Flux.error(new RuntimeException("upstream error"))); - when(fallbackApi.chatCompletionStream(request)).thenReturn(Flux.just(fallbackChunk)); + when(fallbackApi.chatCompletionStream(any(ChatCompletionRequest.class))).thenReturn(Flux.just(fallbackChunk)); + + final AiCommonConfig fallbackConfig = new AiCommonConfig(); + fallbackConfig.setModel("fallback-model"); + final AiProxyExecutorService.FallbackContext ctx = new AiProxyExecutorService.FallbackContext(fallbackApi, fallbackConfig); - StepVerifier.create(executorService.executeDirectStream(mainApi, Optional.of(fallbackApi), request)) + StepVerifier.create(executorService.executeDirectStream(mainApi, Optional.of(ctx), request, REQUEST_BODY, true)) .expectNext(fallbackChunk) .verifyComplete(); - verify(fallbackApi, times(1)).chatCompletionStream(request); + verify(fallbackApi, times(1)).chatCompletionStream(any(ChatCompletionRequest.class)); } @Test @@ -97,7 +105,7 @@ void testExecuteDirectCallSuccess() { final ResponseEntity responseEntity = ResponseEntity.ok(completion); when(mainApi.chatCompletionEntity(request)).thenReturn(responseEntity); - StepVerifier.create(executorService.executeDirectCall(mainApi, Optional.empty(), request)) + StepVerifier.create(executorService.executeDirectCall(mainApi, Optional.empty(), request, REQUEST_BODY)) .expectNext(responseEntity) .verifyComplete(); @@ -110,7 +118,7 @@ void testExecuteDirectCallErrorWithNoFallback() { final ChatCompletionRequest request = mock(ChatCompletionRequest.class); when(mainApi.chatCompletionEntity(request)).thenThrow(new RuntimeException("upstream error")); - StepVerifier.create(executorService.executeDirectCall(mainApi, Optional.empty(), request)) + StepVerifier.create(executorService.executeDirectCall(mainApi, Optional.empty(), request, REQUEST_BODY)) .expectError(NonTransientAiException.class) .verify(); } @@ -124,12 +132,16 @@ void testExecuteDirectCallFallbackSuccess() { final ResponseEntity fallbackResponse = ResponseEntity.ok(fallbackCompletion); when(mainApi.chatCompletionEntity(request)).thenThrow(new RuntimeException("upstream error")); - when(fallbackApi.chatCompletionEntity(request)).thenReturn(fallbackResponse); + when(fallbackApi.chatCompletionEntity(any(ChatCompletionRequest.class))).thenReturn(fallbackResponse); + + final AiCommonConfig fallbackConfig = new AiCommonConfig(); + fallbackConfig.setModel("fallback-model"); + final AiProxyExecutorService.FallbackContext ctx = new AiProxyExecutorService.FallbackContext(fallbackApi, fallbackConfig); - StepVerifier.create(executorService.executeDirectCall(mainApi, Optional.of(fallbackApi), request)) + StepVerifier.create(executorService.executeDirectCall(mainApi, Optional.of(ctx), request, REQUEST_BODY)) .expectNext(fallbackResponse) .verifyComplete(); - verify(fallbackApi, times(1)).chatCompletionEntity(request); + verify(fallbackApi, times(1)).chatCompletionEntity(any(ChatCompletionRequest.class)); } } From c6a6dd6cfdd0968b2004127c05b4584dcaff2bfe Mon Sep 17 00:00:00 2001 From: eye-gu <734164350@qq.com> Date: Sun, 17 May 2026 19:00:25 +0800 Subject: [PATCH 6/6] fix review --- .../protocol/OpenAiProtocolAdapter.java | 76 +++++++++------- .../protocol/OpenAiProtocolAdapterTest.java | 67 +++++++++----- .../ai/proxy/enhanced/AiProxyPlugin.java | 1 + .../proxy/enhanced/cache/OpenAiApiCache.java | 13 ++- .../handler/AiProxyPluginHandler.java | 3 + .../service/AiProxyExecutorService.java | 19 +--- .../enhanced/service/UpstreamErrorLogger.java | 8 +- .../ai/proxy/enhanced/AiProxyPluginTest.java | 4 +- .../enhanced/cache/OpenAiApiCacheTest.java | 91 +++++++++++++++++++ .../service/UpstreamErrorLoggerTest.java | 15 +++ 10 files changed, 222 insertions(+), 75 deletions(-) create mode 100644 shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCacheTest.java diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapter.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapter.java index a23eb9cd2ad2..73244c838f82 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapter.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapter.java @@ -20,8 +20,10 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.node.ObjectNode; -import org.apache.shenyu.common.utils.JsonUtils; +import org.apache.shenyu.common.exception.ShenyuException; import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; import java.util.Objects; @@ -37,9 +39,24 @@ */ public final class OpenAiProtocolAdapter { + private static final Logger LOG = LoggerFactory.getLogger(OpenAiProtocolAdapter.class); + private static final com.fasterxml.jackson.databind.ObjectMapper MAPPER = new com.fasterxml.jackson.databind.ObjectMapper(); + /** + * OpenAI API field name constants. + */ + private static final String FIELD_MODEL = "model"; + + private static final String FIELD_TEMPERATURE = "temperature"; + + private static final String FIELD_STREAM = "stream"; + + private static final String FIELD_MAX_COMPLETION_TOKENS = "max_completion_tokens"; + + private static final String FIELD_MAX_TOKENS = "max_tokens"; + private OpenAiProtocolAdapter() { } @@ -59,8 +76,8 @@ public static boolean resolveStream(final String requestBody, final Boolean fall if (Objects.isNull(root)) { return Boolean.TRUE.equals(fallbackStream); } - if (root.hasNonNull("stream")) { - return root.get("stream").asBoolean(); + if (root.hasNonNull(FIELD_STREAM)) { + return root.get(FIELD_STREAM).asBoolean(); } return Boolean.TRUE.equals(fallbackStream); } @@ -73,11 +90,13 @@ public static boolean resolveStream(final String requestBody, final Boolean fall * when reconstructing ChatCompletionMessage from AssistantMessage. * This method avoids that loss by deserializing the raw JSON directly. * - *

Also converts max_completion_tokens to max_tokens for broader API compatibility. + *

For max_tokens and max_completion_tokens: client values are preserved as-is; + * only when the client omits both fields does the fallbackConfig value populate both, + * ensuring compatibility with providers that require either field. * * @param requestBody the raw JSON request body in OpenAI Chat Completions format * @param stream whether this is a streaming request (sets the stream field) - * @return a ChatCompletionRequest with all fields preserved from the original request + * @return a ChatCompletionRequest with modeled fields preserved from the original request */ public static ChatCompletionRequest toChatCompletionRequest(final String requestBody, final boolean stream) { return toChatCompletionRequest(requestBody, stream, null); @@ -85,54 +104,47 @@ public static ChatCompletionRequest toChatCompletionRequest(final String request /** * Parse raw request body directly into ChatCompletionRequest, preserving modeled fields. - * For model, temperature, max_tokens: client request takes priority; - * if missing, falls back to the corresponding field in fallbackConfig. + * When fallbackConfig is provided, its non-null fields (model, temperature, maxTokens) + * override the client request values, matching the original ChatModel-based fallback behavior + * where the fallback ChatModel's config takes precedence. * * @param requestBody the raw JSON request body in OpenAI Chat Completions format * @param stream whether this is a streaming request - * @param fallbackConfig the admin config used as fallback when client omits fields - * @return a ChatCompletionRequest with all fields preserved + * @param fallbackConfig the fallback config whose non-null fields override client values + * @return a ChatCompletionRequest with modeled fields preserved */ public static ChatCompletionRequest toChatCompletionRequest(final String requestBody, final boolean stream, final AiCommonConfig fallbackConfig) { if (Objects.isNull(requestBody) || requestBody.isEmpty()) { - throw new IllegalArgumentException("Request body must not be empty"); + throw new ShenyuException("Request body must not be empty"); } final JsonNode root = parseStrict(requestBody); if (Objects.isNull(root) || !root.isObject()) { - throw new IllegalArgumentException("Invalid request body: expected a JSON object"); + throw new ShenyuException("Invalid request body: expected a JSON object"); } final ObjectNode mutableRoot = (ObjectNode) root; - if (root.hasNonNull("max_completion_tokens") && !root.hasNonNull("max_tokens")) { - final JsonNode tokenNode = root.get("max_completion_tokens"); - if (!tokenNode.isNumber()) { - throw new IllegalArgumentException( - "max_completion_tokens must be a number, got: " + tokenNode.getNodeType()); - } - mutableRoot.put("max_tokens", tokenNode.asInt()); - mutableRoot.remove("max_completion_tokens"); - } - if (Objects.nonNull(fallbackConfig)) { - if (!root.hasNonNull("model") && Objects.nonNull(fallbackConfig.getModel()) && !fallbackConfig.getModel().isEmpty()) { - mutableRoot.put("model", fallbackConfig.getModel()); + if (Objects.nonNull(fallbackConfig.getModel()) && !fallbackConfig.getModel().isEmpty()) { + mutableRoot.put(FIELD_MODEL, fallbackConfig.getModel()); } - if (!root.hasNonNull("temperature") && Objects.nonNull(fallbackConfig.getTemperature())) { - mutableRoot.put("temperature", fallbackConfig.getTemperature()); + if (Objects.nonNull(fallbackConfig.getTemperature())) { + mutableRoot.put(FIELD_TEMPERATURE, fallbackConfig.getTemperature()); } - if (!root.hasNonNull("max_tokens") && Objects.nonNull(fallbackConfig.getMaxTokens())) { - mutableRoot.put("max_tokens", fallbackConfig.getMaxTokens()); + if (Objects.nonNull(fallbackConfig.getMaxTokens())) { + mutableRoot.put(FIELD_MAX_TOKENS, fallbackConfig.getMaxTokens()); + mutableRoot.put(FIELD_MAX_COMPLETION_TOKENS, fallbackConfig.getMaxTokens()); } } - mutableRoot.put("stream", stream); + mutableRoot.put(FIELD_STREAM, stream); - final ChatCompletionRequest result = JsonUtils.jsonToObject(mutableRoot.toString(), ChatCompletionRequest.class); - if (Objects.isNull(result)) { - throw new IllegalArgumentException("Failed to parse request body into ChatCompletionRequest"); + try { + return MAPPER.treeToValue(mutableRoot, ChatCompletionRequest.class); + } catch (Exception e) { + LOG.error("[AiProxy] Failed to parse request body into ChatCompletionRequest", e); + throw new ShenyuException("Failed to parse request body into ChatCompletionRequest", e); } - return result; } private static JsonNode parseStrict(final String json) { diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/test/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapterTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/test/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapterTest.java index adfbf4d4bf1f..e843d95f1d0b 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/test/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapterTest.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/test/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapterTest.java @@ -17,6 +17,7 @@ package org.apache.shenyu.plugin.ai.common.protocol; +import org.apache.shenyu.common.exception.ShenyuException; import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig; import org.junit.jupiter.api.Test; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; @@ -33,13 +34,13 @@ public final class OpenAiProtocolAdapterTest { @Test void testNullRequestBody() { - assertThrows(IllegalArgumentException.class, + assertThrows(ShenyuException.class, () -> OpenAiProtocolAdapter.toChatCompletionRequest(null, false)); } @Test void testEmptyRequestBody() { - assertThrows(IllegalArgumentException.class, + assertThrows(ShenyuException.class, () -> OpenAiProtocolAdapter.toChatCompletionRequest("", false)); } @@ -87,21 +88,21 @@ void testStreamFlagUnset() { } @Test - void testMaxCompletionTokensConvertedToMaxTokens() { + void testMaxCompletionTokensPreservedAsIs() { final String body = "{\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"max_completion_tokens\":100}"; final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false); assertNotNull(req); - assertEquals(100, req.maxTokens()); + assertEquals(100, req.maxCompletionTokens()); } @Test - void testClientModelTakesPriority() { + void testFallbackModelOverridesClientModel() { final String body = "{\"model\":\"client-model\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}]}"; final AiCommonConfig config = new AiCommonConfig(); - config.setModel("admin-model"); + config.setModel("fallback-model"); final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false, config); - assertEquals("client-model", req.model()); + assertEquals("fallback-model", req.model()); } @Test @@ -138,13 +139,13 @@ void testNoFallbackWhenConfigModelIsEmpty() { } @Test - void testClientTemperatureTakesPriority() { + void testFallbackTemperatureOverridesClient() { final String body = "{\"model\":\"m\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"temperature\":0.5}"; final AiCommonConfig config = new AiCommonConfig(); config.setTemperature(0.9); final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false, config); - assertEquals(0.5, req.temperature(), 0.001); + assertEquals(0.9, req.temperature(), 0.001); } @Test @@ -166,13 +167,13 @@ void testNoFallbackWhenConfigTemperatureIsNull() { } @Test - void testClientMaxTokensTakesPriority() { + void testFallbackMaxTokensOverridesClient() { final String body = "{\"model\":\"m\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"max_tokens\":200}"; final AiCommonConfig config = new AiCommonConfig(); config.setMaxTokens(500); final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false, config); - assertEquals(200, req.maxTokens()); + assertEquals(500, req.maxTokens()); } @Test @@ -208,44 +209,44 @@ void testAllFallbackFieldsAppliedWhenClientMissingAll() { } @Test - void testAllClientFieldsTakePriority() { + void testAllFallbackFieldsOverrideClient() { final String body = "{\"model\":\"client-model\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}]," + "\"temperature\":0.1,\"max_tokens\":50}"; final AiCommonConfig config = new AiCommonConfig(); - config.setModel("admin-model"); + config.setModel("fallback-model"); config.setTemperature(0.9); config.setMaxTokens(9999); final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false, config); - assertEquals("client-model", req.model()); - assertEquals(0.1, req.temperature(), 0.001); - assertEquals(50, req.maxTokens()); + assertEquals("fallback-model", req.model()); + assertEquals(0.9, req.temperature(), 0.001); + assertEquals(9999, req.maxTokens()); } @Test - void testPartialClientFieldsWithPartialFallback() { + void testPartialFallbackOverrideWithPartialClient() { final String body = "{\"model\":\"client-model\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"temperature\":0.1}"; final AiCommonConfig config = new AiCommonConfig(); - config.setModel("admin-model"); + config.setModel("fallback-model"); config.setTemperature(0.9); config.setMaxTokens(2048); final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false, config); - assertEquals("client-model", req.model()); - assertEquals(0.1, req.temperature(), 0.001); + assertEquals("fallback-model", req.model()); + assertEquals(0.9, req.temperature(), 0.001); assertEquals(2048, req.maxTokens()); } @Test void testMalformedJsonThrows() { - assertThrows(IllegalArgumentException.class, + assertThrows(ShenyuException.class, () -> OpenAiProtocolAdapter.toChatCompletionRequest("not valid json{{{", false)); } @Test void testMaxCompletionTokensNonNumberThrows() { final String body = "{\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"max_completion_tokens\":\"abc\"}"; - assertThrows(IllegalArgumentException.class, + assertThrows(ShenyuException.class, () -> OpenAiProtocolAdapter.toChatCompletionRequest(body, false)); } @@ -254,4 +255,26 @@ void testResolveStreamMalformedJsonFallsBack() { assertFalse(OpenAiProtocolAdapter.resolveStream("not json{{{", false)); assertTrue(OpenAiProtocolAdapter.resolveStream("not json{{{", true)); } + + @Test + void testAnnotationsRoundTrip() throws Exception { + final com.fasterxml.jackson.databind.ObjectMapper mapper = new com.fasterxml.jackson.databind.ObjectMapper(); + final String body = "{\"messages\":[" + + "{\"role\":\"user\",\"content\":\"hello\"}," + + "{\"role\":\"assistant\",\"content\":\"see link\"," + + "\"annotations\":[{\"type\":\"url_citation\"," + + "\"url_citation\":{\"url\":\"https://example.com\",\"title\":\"Example\",\"start_index\":0,\"end_index\":8}}]}" + + "]}"; + final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false); + assertNotNull(req); + + final String serialized = mapper.writeValueAsString(req); + final com.fasterxml.jackson.databind.JsonNode root = mapper.readTree(serialized); + final com.fasterxml.jackson.databind.JsonNode assistantMsg = root.get("messages").get(1); + assertTrue(assistantMsg.has("annotations")); + final com.fasterxml.jackson.databind.JsonNode annotations = assistantMsg.get("annotations"); + assertEquals(1, annotations.size()); + assertEquals("url_citation", annotations.get(0).get("type").asText()); + assertEquals("https://example.com", annotations.get(0).get("url_citation").get("url").asText()); + } } diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java index 517ff280bc4d..916431437acd 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java @@ -236,6 +236,7 @@ private OpenAiApi getCachedOpenAiApi(final String selectorId, final String type, private int generateConfigCacheKey(final AiCommonConfig config) { return Objects.hash( config.getBaseUrl(), + config.getApiKey(), config.getModel(), config.getTemperature(), config.getMaxTokens() diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCache.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCache.java index fdbe22f6e429..d75aab480f32 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCache.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCache.java @@ -85,7 +85,7 @@ public OpenAiApi computeIfAbsent(final String key, final Supplier ope try { synchronized (openAiApiMap) { if (openAiApiMap.size() > MAX_CACHE_SIZE) { - evictOldestEntries(); + evictEntries(); } } } finally { @@ -102,7 +102,7 @@ public OpenAiApi computeIfAbsent(final String key, final Supplier ope * not guaranteed to be the oldest. Removes approximately 25% of entries to avoid * thundering herd problem. */ - private void evictOldestEntries() { + private void evictEntries() { final int currentSize = openAiApiMap.size(); if (currentSize <= MAX_CACHE_SIZE) { return; @@ -146,4 +146,13 @@ public void clearAll() { openAiApiMap.clear(); LOG.info("[OpenAiApiCache] cleared all cached OpenAiApi instances"); } + + /** + * Gets the current cache size (for testing / monitoring). + * + * @return the number of cached entries + */ + public int size() { + return openAiApiMap.size(); + } } diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java index cc8a42bbd4b8..40dfc31fa77f 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java @@ -55,6 +55,9 @@ public void handlerSelector(final SelectorData selectorData) { // Do NOT remove AiProxyApiKeyCache here. Admin will push updated AI_PROXY_API_KEY events // with refreshed realApiKey after selector changes. Removing here introduces a window of misses. if (Objects.isNull(selectorData.getHandle()) || selectorData.getHandle().isEmpty()) { + // Clear stale cached handle when selector handle is removed/cleared + selectorCachedHandle + .removeHandle(CacheKeyUtils.INST.getKey(selectorData.getId(), Constants.DEFAULT_RULE)); return; } AiProxyHandle aiProxyHandle = GsonUtils.getInstance().fromJson(selectorData.getHandle(), AiProxyHandle.class); diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java index 8fbd37be2186..534f81f0c7f7 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java @@ -69,8 +69,7 @@ public Flux executeDirectStream(final OpenAiApi mainApi, "Direct stream failed after 1 retry. Triggering fallback.", retrySignal.failure()); })) - .onErrorResume(NonTransientAiException.class, - throwable -> handleDirectFallbackStream(throwable, fallbackCtxOpt, requestBody, stream)); + .onErrorResume(e -> handleDirectFallbackStream(e, fallbackCtxOpt, requestBody, stream)); } /** @@ -96,8 +95,7 @@ public Mono> executeDirectCall(final OpenAiApi ma return new NonTransientAiException("Direct call retries exhausted. Triggering fallback.", retrySignal.failure()); })) - .onErrorResume(NonTransientAiException.class, - throwable -> handleDirectFallbackCall(throwable, fallbackCtxOpt, requestBody)); + .onErrorResume(e -> handleDirectFallbackCall(e, fallbackCtxOpt, requestBody)); } private Flux handleDirectFallbackStream(final Throwable throwable, @@ -140,7 +138,7 @@ private static boolean isRetryable(final Throwable throwable) { if (throwable instanceof NonTransientAiException) { return false; } - final WebClientResponseException webClientEx = findWebClientResponseException(throwable); + final WebClientResponseException webClientEx = UpstreamErrorLogger.findWebClientResponseException(throwable); if (Objects.nonNull(webClientEx)) { final int status = webClientEx.getStatusCode().value(); return status == 429 || status >= 500; @@ -148,17 +146,6 @@ private static boolean isRetryable(final Throwable throwable) { return true; } - private static WebClientResponseException findWebClientResponseException(final Throwable e) { - Throwable current = e; - while (Objects.nonNull(current)) { - if (current instanceof WebClientResponseException ex) { - return ex; - } - current = current.getCause(); - } - return null; - } - /** * Container for fallback context: the OpenAiApi instance and the fallback config * used to rebuild the request with fallback-specific model/temperature/maxTokens. diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLogger.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLogger.java index bae715f35026..06a08e0c378a 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLogger.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLogger.java @@ -55,7 +55,13 @@ private static String truncateBody(final String body) { return body.substring(0, MAX_BODY_LOG_LENGTH) + "...(truncated, total " + body.length() + " chars)"; } - private static WebClientResponseException findWebClientResponseException(final Throwable e) { + /** + * Find the {@link WebClientResponseException} in the exception chain. + * + * @param e the root exception + * @return the WebClientResponseException if found, null otherwise + */ + public static WebClientResponseException findWebClientResponseException(final Throwable e) { Throwable current = e; while (Objects.nonNull(current)) { if (current instanceof WebClientResponseException ex) { diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java index b3bb711af1d4..e61f43e1689d 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java @@ -255,7 +255,7 @@ public void testExecuteWithAdminFallback() { } @Test - public void testCacheIsUsedForAdminFallbackClient() { + public void testAdminFallbackExecution() { final AiProxyHandle handle = new AiProxyHandle(); final AiCommonConfig primaryConfig = new AiCommonConfig(); primaryConfig.setProvider(AiModelProviderEnum.OPEN_AI.getName()); @@ -277,7 +277,7 @@ public void testCacheIsUsedForAdminFallbackClient() { when(configService.resolveAdminFallbackConfig(primaryConfig, handle)).thenReturn(Optional.of(fallbackConfig)); when(executorService.executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class))).thenReturn(Mono.just(responseEntity)); - // Execute the test - focus on successful execution rather than cache verification + // Execute the test - verify admin fallback path is exercised StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule)) .verifyComplete(); diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCacheTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCacheTest.java new file mode 100644 index 000000000000..11bed2897ddb --- /dev/null +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCacheTest.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.shenyu.plugin.ai.proxy.enhanced.cache; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.ai.openai.api.OpenAiApi; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; + +class OpenAiApiCacheTest { + + private OpenAiApiCache cache; + + @BeforeEach + void setUp() { + cache = new OpenAiApiCache(); + cache.clearAll(); + } + + @Test + void testComputeIfAbsentReusesExisting() { + final OpenAiApi api = createTestApi("https://api.openai.com", "key1"); + final OpenAiApi cached = cache.computeIfAbsent("key1", () -> api); + final OpenAiApi reused = cache.computeIfAbsent("key1", () -> createTestApi("https://api.openai.com", "key1-other")); + + assertSame(cached, reused); + assertEquals(1, cache.size()); + } + + @Test + void testComputeIfAbsentCreatesNewForDifferentKey() { + final OpenAiApi api1 = cache.computeIfAbsent("key1", () -> createTestApi("https://api.openai.com", "key1")); + final OpenAiApi api2 = cache.computeIfAbsent("key2", () -> createTestApi("https://api.deepseek.com", "key2")); + + // Different keys produce different instances + assertEquals(2, cache.size()); + } + + @Test + void testRemoveBySelectorId() { + cache.computeIfAbsent("sel1|main_123", () -> createTestApi("https://a.com", "k1")); + cache.computeIfAbsent("sel1|adminFallback_456", () -> createTestApi("https://b.com", "k2")); + cache.computeIfAbsent("sel2|main_789", () -> createTestApi("https://c.com", "k3")); + + assertEquals(3, cache.size()); + + cache.remove("sel1"); + + assertEquals(1, cache.size()); + } + + @Test + void testRemoveWithNullSelectorIdDoesNotThrow() { + cache.computeIfAbsent("key1", () -> createTestApi("https://a.com", "k1")); + cache.remove(null); + assertEquals(1, cache.size()); + } + + @Test + void testClearAll() { + cache.computeIfAbsent("key1", () -> createTestApi("https://a.com", "k1")); + cache.computeIfAbsent("key2", () -> createTestApi("https://b.com", "k2")); + + assertEquals(2, cache.size()); + + cache.clearAll(); + + assertEquals(0, cache.size()); + } + + private OpenAiApi createTestApi(final String baseUrl, final String apiKey) { + return OpenAiApi.builder().baseUrl(baseUrl).apiKey(apiKey).build(); + } +} diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLoggerTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLoggerTest.java index 46d21f43c9a1..85193687d1df 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLoggerTest.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLoggerTest.java @@ -70,4 +70,19 @@ void testLogWithNullExceptionDoesNotThrow() { Logger log = mock(Logger.class); assertDoesNotThrow(() -> UpstreamErrorLogger.logUpstreamError(log, null, "stream")); } + + @Test + void testLogWithLongBodyIsTruncated() { + Logger log = mock(Logger.class); + final String longBody = "x".repeat(1024); + WebClientResponseException ex = WebClientResponseException.create( + 500, "Internal Server Error", null, longBody.getBytes(StandardCharsets.UTF_8), StandardCharsets.UTF_8); + + UpstreamErrorLogger.logUpstreamError(log, ex, "stream"); + + final String expectedTruncated = longBody.substring(0, 512) + + "...(truncated, total " + longBody.length() + " chars)"; + verify(log).error("[AiProxy] {} failed, status={}, upstreamBody={}", + "stream", HttpStatus.INTERNAL_SERVER_ERROR, expectedTruncated, ex); + } }