From 03b40be500bcf07f72d5637f96165cf4d4b7c23d Mon Sep 17 00:00:00 2001
From: eye-gu <734164350@qq.com>
Date: Wed, 13 May 2026 22:29:31 +0800
Subject: [PATCH 1/6] refactor: use OpenAiApi directly to return OpenAI Chat
Completions format
---
.../ai/common/config/AiCommonConfig.java | 2 +-
.../protocol/OpenAiProtocolAdapter.java | 118 +++++++++
.../ai/common/strategy/FallbackStrategy.java | 49 ----
.../strategy/SimpleModelFallbackStrategy.java | 68 -----
.../protocol/OpenAiProtocolAdapterTest.java | 238 ++++++++++++++++++
.../ai/proxy/enhanced/AiProxyPlugin.java | 163 ++++--------
.../proxy/enhanced/cache/ChatClientCache.java | 143 -----------
.../handler/AiProxyPluginHandler.java | 13 -
.../service/AiProxyExecutorService.java | 111 ++++----
.../enhanced/service/UpstreamErrorLogger.java | 53 ++++
.../CommonAiProxyApiKeyDataSubscriber.java | 8 -
.../ai/proxy/enhanced/AiProxyPluginTest.java | 105 ++++----
.../service/AiProxyExecutorServiceTest.java | 134 +++++-----
...CommonAiProxyApiKeyDataSubscriberTest.java | 69 +----
.../ai/proxy/AiProxyPluginConfiguration.java | 22 +-
15 files changed, 642 insertions(+), 654 deletions(-)
create mode 100644 shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapter.java
delete mode 100644 shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/FallbackStrategy.java
delete mode 100644 shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/SimpleModelFallbackStrategy.java
create mode 100644 shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/test/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapterTest.java
delete mode 100644 shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/ChatClientCache.java
create mode 100644 shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLogger.java
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/config/AiCommonConfig.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/config/AiCommonConfig.java
index c3d30221efe0..71a62e15636f 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/config/AiCommonConfig.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/config/AiCommonConfig.java
@@ -50,7 +50,7 @@ public class AiCommonConfig {
/**
* temperature.
*/
- private Double temperature = 0.8;
+ private Double temperature;
/**
* max tokens.
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapter.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapter.java
new file mode 100644
index 000000000000..ce44a5d3178b
--- /dev/null
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapter.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shenyu.plugin.ai.common.protocol;
+
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.node.ObjectNode;
+import org.apache.shenyu.common.utils.JsonUtils;
+import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
+
+import java.util.Objects;
+
+/**
+ * Adapts between OpenAI Chat Completions wire format and internal representations.
+ */
+public final class OpenAiProtocolAdapter {
+
+ private OpenAiProtocolAdapter() {
+ }
+
+ /**
+ * Resolve the stream flag: client request body takes priority,
+ * falls back to the provided default value.
+ *
+ * @param requestBody the raw JSON request body
+ * @param fallbackStream the default stream value from admin config
+ * @return true if streaming, false otherwise
+ */
+ public static boolean resolveStream(final String requestBody, final Boolean fallbackStream) {
+ if (Objects.isNull(requestBody) || requestBody.isEmpty()) {
+ return Boolean.TRUE.equals(fallbackStream);
+ }
+ final JsonNode root = JsonUtils.toJsonNode(requestBody);
+ if (Objects.isNull(root)) {
+ return Boolean.TRUE.equals(fallbackStream);
+ }
+ if (root.hasNonNull("stream")) {
+ return root.get("stream").asBoolean();
+ }
+ return Boolean.TRUE.equals(fallbackStream);
+ }
+
+ /**
+ * Parse raw request body directly into ChatCompletionRequest, preserving ALL fields
+ * including reasoning_content in assistant messages.
+ *
+ *
Spring AI's createRequest() loses reasoning_content, refusal, and annotations
+ * when reconstructing ChatCompletionMessage from AssistantMessage.
+ * This method avoids that loss by deserializing the raw JSON directly.
+ *
+ *
Also converts max_completion_tokens to max_tokens for broader API compatibility.
+ *
+ * @param requestBody the raw JSON request body in OpenAI Chat Completions format
+ * @param stream whether this is a streaming request (sets the stream field)
+ * @return a ChatCompletionRequest with all fields preserved from the original request
+ */
+ public static ChatCompletionRequest toChatCompletionRequest(final String requestBody, final boolean stream) {
+ return toChatCompletionRequest(requestBody, stream, null);
+ }
+
+ /**
+ * Parse raw request body directly into ChatCompletionRequest, preserving ALL fields.
+ * For model, temperature, max_tokens: client request takes priority;
+ * if missing, falls back to the corresponding field in fallbackConfig.
+ *
+ * @param requestBody the raw JSON request body in OpenAI Chat Completions format
+ * @param stream whether this is a streaming request
+ * @param fallbackConfig the admin config used as fallback when client omits fields
+ * @return a ChatCompletionRequest with all fields preserved
+ */
+ public static ChatCompletionRequest toChatCompletionRequest(final String requestBody,
+ final boolean stream, final AiCommonConfig fallbackConfig) {
+ if (Objects.isNull(requestBody) || requestBody.isEmpty()) {
+ throw new IllegalArgumentException("Request body must not be empty");
+ }
+ final JsonNode root = JsonUtils.toJsonNode(requestBody);
+ if (Objects.isNull(root) || !root.isObject()) {
+ throw new IllegalArgumentException("Invalid request body: expected a JSON object");
+ }
+ final ObjectNode mutableRoot = (ObjectNode) root;
+
+ if (root.hasNonNull("max_completion_tokens") && !root.hasNonNull("max_tokens")) {
+ mutableRoot.put("max_tokens", root.get("max_completion_tokens").asInt());
+ mutableRoot.remove("max_completion_tokens");
+ }
+
+ if (Objects.nonNull(fallbackConfig)) {
+ if (!root.hasNonNull("model") && Objects.nonNull(fallbackConfig.getModel()) && !fallbackConfig.getModel().isEmpty()) {
+ mutableRoot.put("model", fallbackConfig.getModel());
+ }
+ if (!root.hasNonNull("temperature") && Objects.nonNull(fallbackConfig.getTemperature())) {
+ mutableRoot.put("temperature", fallbackConfig.getTemperature());
+ }
+ if (!root.hasNonNull("max_tokens") && Objects.nonNull(fallbackConfig.getMaxTokens())) {
+ mutableRoot.put("max_tokens", fallbackConfig.getMaxTokens());
+ }
+ }
+
+ mutableRoot.put("stream", stream);
+
+ return JsonUtils.jsonToObject(mutableRoot.toString(), ChatCompletionRequest.class);
+ }
+}
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/FallbackStrategy.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/FallbackStrategy.java
deleted file mode 100644
index 08c68c4f46ec..000000000000
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/FallbackStrategy.java
+++ /dev/null
@@ -1,49 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.shenyu.plugin.ai.common.strategy;
-
-import org.springframework.ai.chat.client.ChatClient;
-import org.springframework.ai.chat.model.ChatResponse;
-import reactor.core.publisher.Flux;
-import reactor.core.publisher.Mono;
-
-/**
- * Fallback strategy.
- */
-public interface FallbackStrategy {
-
- /**
- * Execute the fallback strategy.
- *
- * @param fallbackClient the pre-configured and cached chat client for fallback
- * @param requestBody the original request body as a string
- * @param originalError the original error that triggered the fallback
- * @return a Mono containing the fallback ChatResponse
- */
- Mono fallback(ChatClient fallbackClient, String requestBody, Throwable originalError);
-
- /**
- * Execute the fallback strategy for stream.
- *
- * @param fallbackClient the pre-configured and cached chat client for fallback
- * @param requestBody the original request body as a string
- * @param originalError the original error that triggered the fallback
- * @return a Flux containing the fallback ChatResponse
- */
- Flux fallbackStream(ChatClient fallbackClient, String requestBody, Throwable originalError);
-}
\ No newline at end of file
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/SimpleModelFallbackStrategy.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/SimpleModelFallbackStrategy.java
deleted file mode 100644
index 94b3653e654d..000000000000
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/SimpleModelFallbackStrategy.java
+++ /dev/null
@@ -1,68 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.shenyu.plugin.ai.common.strategy;
-
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import org.springframework.ai.chat.client.ChatClient;
-import org.springframework.ai.chat.model.ChatResponse;
-import reactor.core.publisher.Flux;
-import reactor.core.publisher.Mono;
-import reactor.core.scheduler.Schedulers;
-
-/**
- * A fallback strategy that simply executes a call with a pre-configured client.
- */
-public final class SimpleModelFallbackStrategy implements FallbackStrategy {
-
- /**
- * The constant INSTANCE.
- */
- public static final FallbackStrategy INSTANCE = new SimpleModelFallbackStrategy();
-
- private static final Logger LOG = LoggerFactory.getLogger(SimpleModelFallbackStrategy.class);
-
- private SimpleModelFallbackStrategy() {
- }
-
- @Override
- public Mono fallback(final ChatClient fallbackClient, final String requestBody, final Throwable originalError) {
- LOG.warn("Executing simple model fallback strategy due to error: {}", originalError.getMessage());
-
- return Mono.fromCallable(() -> fallbackClient.prompt()
- .user(requestBody)
- .call()
- .chatResponse())
- .subscribeOn(Schedulers.boundedElastic())
- .doOnSuccess(response -> LOG.info("Fallback call successful."))
- .doOnError(fallbackError -> LOG.error("Fallback call also failed.", fallbackError));
- }
-
- @Override
- public Flux fallbackStream(final ChatClient fallbackClient, final String requestBody, final Throwable originalError) {
- LOG.warn("Executing simple model fallback stream strategy due to error: {}", originalError.getMessage());
-
- return Flux.defer(() -> fallbackClient.prompt()
- .user(requestBody)
- .stream()
- .chatResponse())
- .subscribeOn(Schedulers.boundedElastic())
- .doOnComplete(() -> LOG.info("Fallback stream completed successfully."))
- .doOnError(fallbackError -> LOG.error("Fallback stream also failed.", fallbackError));
- }
-}
\ No newline at end of file
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/test/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapterTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/test/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapterTest.java
new file mode 100644
index 000000000000..566824318014
--- /dev/null
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/test/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapterTest.java
@@ -0,0 +1,238 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shenyu.plugin.ai.common.protocol;
+
+import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig;
+import org.junit.jupiter.api.Test;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public final class OpenAiProtocolAdapterTest {
+
+ private static final String BASE_BODY = "{\"messages\":[{\"role\":\"user\",\"content\":\"hello\"}]}";
+
+ @Test
+ void testNullRequestBody() {
+ assertThrows(IllegalArgumentException.class,
+ () -> OpenAiProtocolAdapter.toChatCompletionRequest(null, false));
+ }
+
+ @Test
+ void testEmptyRequestBody() {
+ assertThrows(IllegalArgumentException.class,
+ () -> OpenAiProtocolAdapter.toChatCompletionRequest("", false));
+ }
+
+ @Test
+ void testStreamFlagSet() {
+ final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, true);
+ assertNotNull(req);
+ assertTrue(req.stream());
+ }
+
+ @Test
+ void testResolveStreamClientTrueOverridesFallbackFalse() {
+ final String body = "{\"model\":\"m\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"stream\":true}";
+ assertTrue(OpenAiProtocolAdapter.resolveStream(body, false));
+ }
+
+ @Test
+ void testResolveStreamClientFalseOverridesFallbackTrue() {
+ final String body = "{\"model\":\"m\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"stream\":false}";
+ assertFalse(OpenAiProtocolAdapter.resolveStream(body, true));
+ }
+
+ @Test
+ void testResolveStreamFallbackWhenClientMissing() {
+ assertTrue(OpenAiProtocolAdapter.resolveStream(BASE_BODY, true));
+ assertFalse(OpenAiProtocolAdapter.resolveStream(BASE_BODY, false));
+ }
+
+ @Test
+ void testResolveStreamFallbackWhenClientMissingAndConfigNull() {
+ assertFalse(OpenAiProtocolAdapter.resolveStream(BASE_BODY, null));
+ }
+
+ @Test
+ void testResolveStreamFallbackWhenBodyEmpty() {
+ assertTrue(OpenAiProtocolAdapter.resolveStream("", true));
+ assertFalse(OpenAiProtocolAdapter.resolveStream("", false));
+ }
+
+ @Test
+ void testStreamFlagUnset() {
+ final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false);
+ assertNotNull(req);
+ assertFalse(req.stream());
+ }
+
+ @Test
+ void testMaxCompletionTokensConvertedToMaxTokens() {
+ final String body = "{\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"max_completion_tokens\":100}";
+ final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false);
+ assertNotNull(req);
+ assertEquals(100, req.maxTokens());
+ }
+
+ @Test
+ void testClientModelTakesPriority() {
+ final String body = "{\"model\":\"client-model\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}]}";
+ final AiCommonConfig config = new AiCommonConfig();
+ config.setModel("admin-model");
+
+ final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false, config);
+ assertEquals("client-model", req.model());
+ }
+
+ @Test
+ void testFallbackModelWhenClientMissing() {
+ final AiCommonConfig config = new AiCommonConfig();
+ config.setModel("admin-model");
+
+ final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false, config);
+ assertEquals("admin-model", req.model());
+ }
+
+ @Test
+ void testNoFallbackModelWhenConfigIsNull() {
+ final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false, null);
+ assertNotNull(req);
+ }
+
+ @Test
+ void testNoFallbackWhenConfigModelIsNull() {
+ final AiCommonConfig config = new AiCommonConfig();
+ config.setModel(null);
+
+ final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false, config);
+ assertNotNull(req);
+ }
+
+ @Test
+ void testNoFallbackWhenConfigModelIsEmpty() {
+ final AiCommonConfig config = new AiCommonConfig();
+ config.setModel("");
+
+ final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false, config);
+ assertNotNull(req);
+ }
+
+ @Test
+ void testClientTemperatureTakesPriority() {
+ final String body = "{\"model\":\"m\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"temperature\":0.5}";
+ final AiCommonConfig config = new AiCommonConfig();
+ config.setTemperature(0.9);
+
+ final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false, config);
+ assertEquals(0.5, req.temperature(), 0.001);
+ }
+
+ @Test
+ void testFallbackTemperatureWhenClientMissing() {
+ final AiCommonConfig config = new AiCommonConfig();
+ config.setTemperature(0.7);
+
+ final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false, config);
+ assertEquals(0.7, req.temperature(), 0.001);
+ }
+
+ @Test
+ void testNoFallbackTemperatureWhenConfigNull() {
+ final AiCommonConfig config = new AiCommonConfig();
+ config.setTemperature(null);
+
+ final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false, config);
+ assertNotNull(req);
+ }
+
+ @Test
+ void testClientMaxTokensTakesPriority() {
+ final String body = "{\"model\":\"m\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"max_tokens\":200}";
+ final AiCommonConfig config = new AiCommonConfig();
+ config.setMaxTokens(500);
+
+ final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false, config);
+ assertEquals(200, req.maxTokens());
+ }
+
+ @Test
+ void testFallbackMaxTokensWhenClientMissing() {
+ final AiCommonConfig config = new AiCommonConfig();
+ config.setMaxTokens(500);
+
+ final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false, config);
+ assertEquals(500, req.maxTokens());
+ }
+
+ @Test
+ void testNoFallbackMaxTokensWhenConfigNull() {
+ final AiCommonConfig config = new AiCommonConfig();
+ config.setMaxTokens(null);
+
+ final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false, config);
+ assertNotNull(req);
+ }
+
+ @Test
+ void testAllFallbackFieldsAppliedWhenClientMissingAll() {
+ final AiCommonConfig config = new AiCommonConfig();
+ config.setModel("admin-model");
+ config.setTemperature(0.3);
+ config.setMaxTokens(1024);
+
+ final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, true, config);
+ assertEquals("admin-model", req.model());
+ assertEquals(0.3, req.temperature(), 0.001);
+ assertEquals(1024, req.maxTokens());
+ assertTrue(req.stream());
+ }
+
+ @Test
+ void testAllClientFieldsTakePriority() {
+ final String body = "{\"model\":\"client-model\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],"
+ + "\"temperature\":0.1,\"max_tokens\":50}";
+ final AiCommonConfig config = new AiCommonConfig();
+ config.setModel("admin-model");
+ config.setTemperature(0.9);
+ config.setMaxTokens(9999);
+
+ final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false, config);
+ assertEquals("client-model", req.model());
+ assertEquals(0.1, req.temperature(), 0.001);
+ assertEquals(50, req.maxTokens());
+ }
+
+ @Test
+ void testPartialClientFieldsWithPartialFallback() {
+ final String body = "{\"model\":\"client-model\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"temperature\":0.1}";
+ final AiCommonConfig config = new AiCommonConfig();
+ config.setModel("admin-model");
+ config.setTemperature(0.9);
+ config.setMaxTokens(2048);
+
+ final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false, config);
+ assertEquals("client-model", req.model());
+ assertEquals(0.1, req.temperature(), 0.001);
+ assertEquals(2048, req.maxTokens());
+ }
+}
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java
index 687c51233f95..d9e72c2b67f5 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java
@@ -21,25 +21,24 @@
import org.apache.shenyu.common.dto.RuleData;
import org.apache.shenyu.common.dto.SelectorData;
import org.apache.shenyu.common.dto.convert.rule.AiProxyHandle;
-import org.apache.shenyu.common.enums.AiModelProviderEnum;
import org.apache.shenyu.common.enums.PluginEnum;
import org.apache.shenyu.common.utils.JsonUtils;
import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig;
-import org.apache.shenyu.plugin.ai.common.spring.ai.registry.AiModelFactoryRegistry;
+import org.apache.shenyu.plugin.ai.common.protocol.OpenAiProtocolAdapter;
import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.AiProxyApiKeyCache;
-import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.ChatClientCache;
import org.apache.shenyu.plugin.ai.proxy.enhanced.handler.AiProxyPluginHandler;
import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyConfigService;
import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyExecutorService;
+import org.apache.shenyu.plugin.ai.proxy.enhanced.service.UpstreamErrorLogger;
import org.apache.shenyu.plugin.api.ShenyuPluginChain;
import org.apache.shenyu.plugin.api.utils.WebFluxResultUtils;
import org.apache.shenyu.plugin.base.AbstractShenyuPlugin;
import org.apache.shenyu.plugin.base.utils.CacheKeyUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import org.springframework.ai.chat.client.ChatClient;
-import org.springframework.ai.chat.model.ChatModel;
-import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.openai.api.OpenAiApi;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.http.HttpHeaders;
@@ -65,26 +64,18 @@ public class AiProxyPlugin extends AbstractShenyuPlugin {
*/
private static final long MAX_REQUEST_BODY_SIZE_BYTES = 5 * 1024 * 1024L;
- private final AiModelFactoryRegistry aiModelFactoryRegistry;
-
private final AiProxyConfigService aiProxyConfigService;
private final AiProxyExecutorService aiProxyExecutorService;
- private final ChatClientCache chatClientCache;
-
private final AiProxyPluginHandler aiProxyPluginHandler;
public AiProxyPlugin(
- final AiModelFactoryRegistry aiModelFactoryRegistry,
final AiProxyConfigService aiProxyConfigService,
final AiProxyExecutorService aiProxyExecutorService,
- final ChatClientCache chatClientCache,
final AiProxyPluginHandler aiProxyPluginHandler) {
- this.aiModelFactoryRegistry = aiModelFactoryRegistry;
this.aiProxyConfigService = aiProxyConfigService;
this.aiProxyExecutorService = aiProxyExecutorService;
- this.chatClientCache = chatClientCache;
this.aiProxyPluginHandler = aiProxyPluginHandler;
}
@@ -148,7 +139,7 @@ protected Mono doExecute(
}
}
- if (Boolean.TRUE.equals(primaryConfig.getStream())) {
+ if (OpenAiProtocolAdapter.resolveStream(requestBody, primaryConfig.getStream())) {
return handleStreamRequest(exchange, selector, requestBody, primaryConfig, selectorHandle);
}
return handleNonStreamRequest(exchange, selector, requestBody, primaryConfig, selectorHandle);
@@ -161,23 +152,26 @@ private Mono handleStreamRequest(
final String requestBody,
final AiCommonConfig primaryConfig,
final AiProxyHandle selectorHandle) {
- final ChatClient mainClient = createMainChatClient(selector.getId(), primaryConfig);
- final String prompt = aiProxyConfigService.extractPrompt(requestBody);
- final Optional fallbackClient = resolveFallbackClient(primaryConfig, selectorHandle,
- selector.getId(), requestBody);
+ final OpenAiApi mainApi = createOpenAiApi(primaryConfig);
+ final ChatCompletionRequest request = OpenAiProtocolAdapter.toChatCompletionRequest(requestBody, true, primaryConfig);
+ final Optional fallbackApi = resolveFallbackOpenAiApi(primaryConfig, selectorHandle,
+ requestBody);
final ServerHttpResponse response = exchange.getResponse();
response.getHeaders().setContentType(MediaType.TEXT_EVENT_STREAM);
- final Flux chatResponseFlux = aiProxyExecutorService.executeStream(mainClient, fallbackClient,
- prompt);
+ final Flux chunkFlux = aiProxyExecutorService.executeDirectStream(
+ mainApi, fallbackApi, request);
- final Flux sseFlux = chatResponseFlux.map(
- chatResponse -> {
- final String json = JsonUtils.toJson(chatResponse);
+ final Flux sseFlux = chunkFlux.map(
+ chunk -> {
+ final String json = JsonUtils.toJson(chunk);
final String sseData = "data: " + json + "\n\n";
return response.bufferFactory()
.wrap(sseData.getBytes(StandardCharsets.UTF_8));
- });
+ }).concatWith(Mono.fromSupplier(() ->
+ response.bufferFactory()
+ .wrap("data: [DONE]\n\n".getBytes(StandardCharsets.UTF_8))))
+ .doOnError(e -> logUpstreamError(e, "stream"));
return response.writeWith(sseFlux);
}
@@ -188,24 +182,25 @@ private Mono handleNonStreamRequest(
final String requestBody,
final AiCommonConfig primaryConfig,
final AiProxyHandle selectorHandle) {
- final ChatClient mainClient = createMainChatClient(selector.getId(), primaryConfig);
- final String prompt = aiProxyConfigService.extractPrompt(requestBody);
- final Optional fallbackClient = resolveFallbackClient(primaryConfig, selectorHandle,
- selector.getId(), requestBody);
+ final OpenAiApi mainApi = createOpenAiApi(primaryConfig);
+ final ChatCompletionRequest request = OpenAiProtocolAdapter.toChatCompletionRequest(requestBody, false, primaryConfig);
+ final Optional fallbackApi = resolveFallbackOpenAiApi(primaryConfig, selectorHandle,
+ requestBody);
return aiProxyExecutorService
- .execute(mainClient, fallbackClient, prompt)
+ .executeDirectCall(mainApi, fallbackApi, request)
.flatMap(
- response -> {
- byte[] jsonBytes = JsonUtils.toJson(response).getBytes(StandardCharsets.UTF_8);
+ responseEntity -> {
+ final String responseJson = JsonUtils.toJson(responseEntity.getBody());
+ byte[] jsonBytes = responseJson.getBytes(StandardCharsets.UTF_8);
return WebFluxResultUtils.result(exchange, jsonBytes);
- });
+ })
+ .doOnError(e -> logUpstreamError(e, "non-stream"));
}
- private Optional resolveFallbackClient(
+ private Optional resolveFallbackOpenAiApi(
final AiCommonConfig primaryConfig,
final AiProxyHandle selectorHandle,
- final String selectorId,
final String requestBody) {
return aiProxyConfigService
.resolveDynamicFallbackConfig(primaryConfig, requestBody)
@@ -214,86 +209,30 @@ private Optional resolveFallbackClient(
if (LOG.isDebugEnabled()) {
LOG.debug("[AiProxy] dynamic fallback config: {}", cfg);
}
- return createDynamicFallbackClient(cfg);
+ return createOpenAiApi(cfg);
})
- .or(
- () -> aiProxyConfigService
- .resolveAdminFallbackConfig(primaryConfig, selectorHandle)
- .map(adminFallbackConfig -> {
- LOG.info("[AiProxy] use admin fallback");
- if (LOG.isDebugEnabled()) {
- LOG.debug("[AiProxy] admin fallback config: {}", adminFallbackConfig);
- }
- return createAdminFallbackClient(selectorId, adminFallbackConfig);
- }));
- }
-
- /**
- * Generate cache key based on config fields excluding apiKey.
- * This ensures cache consistency even when apiKey is updated at runtime.
- *
- * @param config the config
- * @return cache key hash
- */
- private int generateConfigCacheKey(final AiCommonConfig config) {
- return Objects.hash(
- config.getProvider(),
- config.getBaseUrl(),
- config.getModel(),
- config.getTemperature(),
- config.getMaxTokens(),
- config.getStream()
- // Explicitly exclude apiKey to avoid cache misses when apiKey changes
- );
- }
-
- private ChatClient createMainChatClient(final String selectorId, final AiCommonConfig config) {
- final int configHash = generateConfigCacheKey(config);
- final String cacheKey = selectorId + "|main_" + configHash;
- return chatClientCache.computeIfAbsent(
- cacheKey,
- () -> {
- LOG.info("Creating and caching main model for selector: {}, key: {}", selectorId, cacheKey);
- return createChatModel(config);
- });
- }
-
- private ChatClient createAdminFallbackClient(
- final String selectorId, final AiCommonConfig fallbackConfig) {
- final int configHash = generateConfigCacheKey(fallbackConfig);
- final String fallbackCacheKey = selectorId + "|adminFallback_" + configHash;
- return chatClientCache.computeIfAbsent(
- fallbackCacheKey,
- () -> {
- LOG.info(
- "Creating and caching admin fallback model for selector: {}, key: {}",
- selectorId, fallbackCacheKey);
- return createChatModel(fallbackConfig);
- });
- }
-
- private ChatClient createDynamicFallbackClient(final AiCommonConfig fallbackConfig) {
- LOG.info("Creating non-cached dynamic fallback model.");
- return ChatClient.builder(createChatModel(fallbackConfig)).build();
+ .or(() -> aiProxyConfigService
+ .resolveAdminFallbackConfig(primaryConfig, selectorHandle)
+ .map(adminFallbackConfig -> {
+ LOG.info("[AiProxy] use admin fallback");
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("[AiProxy] admin fallback config: {}", adminFallbackConfig);
+ }
+ return createOpenAiApi(adminFallbackConfig);
+ }));
}
- private ChatModel createChatModel(final AiCommonConfig config) {
- if (LOG.isDebugEnabled()) {
- LOG.debug("Creating chat model with config: {}", config);
- }
- final AiModelProviderEnum provider = AiModelProviderEnum.getByName(config.getProvider());
- if (Objects.isNull(provider)) {
- throw new IllegalArgumentException(
- "Invalid AI model provider in config: " + config.getProvider());
+ private OpenAiApi createOpenAiApi(final AiCommonConfig config) {
+ if (Objects.isNull(config.getBaseUrl()) || config.getBaseUrl().isEmpty()) {
+ throw new IllegalArgumentException("baseUrl must not be empty");
}
- final var factory = aiModelFactoryRegistry.getFactory(provider);
- if (Objects.isNull(factory)) {
- throw new IllegalArgumentException(
- "AI model factory not found for provider: " + provider.getName());
+ if (Objects.isNull(config.getApiKey()) || config.getApiKey().isEmpty()) {
+ throw new IllegalArgumentException("apiKey must not be empty");
}
- return Objects.requireNonNull(
- factory.createAiModel(config),
- "The AI model created by the factory must not be null");
+ return OpenAiApi.builder()
+ .baseUrl(config.getBaseUrl())
+ .apiKey(config.getApiKey())
+ .build();
}
@Override
@@ -305,4 +244,8 @@ public int getOrder() {
public String named() {
return PluginEnum.AI_PROXY.getName();
}
+
+ private void logUpstreamError(final Throwable e, final String mode) {
+ UpstreamErrorLogger.logUpstreamError(LOG, e, mode);
+ }
}
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/ChatClientCache.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/ChatClientCache.java
deleted file mode 100644
index 473ace98fa91..000000000000
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/ChatClientCache.java
+++ /dev/null
@@ -1,143 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.shenyu.plugin.ai.proxy.enhanced.cache;
-
-import org.springframework.ai.chat.client.ChatClient;
-import org.springframework.ai.chat.model.ChatModel;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.util.Map;
-import java.util.Objects;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.function.Supplier;
-
-/**
- * This is ChatClient cache.
- */
-public final class ChatClientCache {
-
- private static final Logger LOG = LoggerFactory.getLogger(ChatClientCache.class);
-
- private static final int MAX_CACHE_SIZE = getCacheSize();
-
- private final Map chatClientMap = new ConcurrentHashMap<>();
-
- private final AtomicBoolean evictionInProgress = new AtomicBoolean(false);
-
- /**
- * Instantiates a new Chat client cache.
- */
- public ChatClientCache() {
- }
-
- private static int getCacheSize() {
- String value = System.getProperty("shenyu.plugin.ai.proxy.enhanced.cache.maxSize",
- System.getenv("SHENYU_PLUGIN_AI_PROXY_ENHANCED_CACHE_MAXSIZE"));
- if (Objects.nonNull(value)) {
- try {
- return Integer.parseInt(value);
- } catch (NumberFormatException e) {
- LoggerFactory.getLogger(ChatClientCache.class)
- .warn("[ChatClientCache] Invalid cache size '{}', using default 500.", value);
- }
- }
- return 500;
- }
-
- /**
- * Gets client or compute if absent.
- *
- * @param key the key
- * @param chatModelSupplier the chat model supplier
- * @return the chat client
- */
- public ChatClient computeIfAbsent(final String key, final Supplier chatModelSupplier) {
- // Check size before computing, but use synchronized block to prevent race conditions
- final int currentSize = chatClientMap.size();
- if (currentSize > MAX_CACHE_SIZE) {
- // Use atomic flag to ensure only one thread performs eviction
- if (evictionInProgress.compareAndSet(false, true)) {
- try {
- synchronized (chatClientMap) {
- // Double-check after acquiring lock
- if (chatClientMap.size() > MAX_CACHE_SIZE) {
- evictOldestEntries();
- }
- }
- } finally {
- evictionInProgress.set(false);
- }
- }
- }
- return chatClientMap.computeIfAbsent(key, k -> ChatClient.builder(chatModelSupplier.get()).build());
- }
-
- /**
- * Evict oldest entries when cache size exceeds limit.
- * Removes approximately 25% of entries to avoid thundering herd problem.
- */
- private void evictOldestEntries() {
- final int currentSize = chatClientMap.size();
- if (currentSize <= MAX_CACHE_SIZE) {
- return;
- }
-
- // Evict 25% of entries, but at least 10 entries
- final int evictCount = Math.max(10, currentSize / 4);
- LOG.warn("[ChatClientCache] Cache size {} exceeded limit {}, evicting {} oldest entries",
- currentSize, MAX_CACHE_SIZE, evictCount);
-
- // Since ConcurrentHashMap doesn't maintain insertion order,
- // we evict entries based on iteration order (which is somewhat arbitrary but better than clearing all)
- int removed = 0;
- for (final String key : chatClientMap.keySet()) {
- if (removed >= evictCount) {
- break;
- }
- chatClientMap.remove(key);
- removed++;
- }
-
- LOG.info("[ChatClientCache] Evicted {} entries, cache size now: {}", removed, chatClientMap.size());
- }
-
- /**
- * Removes all cached clients associated with a selector ID (by prefix matching
- * "selectorId|").
- *
- * @param selectorId the selector id
- */
- public void remove(final String selectorId) {
- if (java.util.Objects.isNull(selectorId)) {
- return;
- }
- final String prefix = selectorId + "|";
- chatClientMap.keySet().removeIf(k -> k.equals(selectorId) || k.startsWith(prefix));
- LOG.info("[ChatClientCache] invalidate selectorId={} (by prefix)", selectorId);
- }
-
- /**
- * Clear all cached clients.
- */
- public void clearAll() {
- chatClientMap.clear();
- LOG.info("[ChatClientCache] cleared all cached clients");
- }
-}
\ No newline at end of file
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java
index e4e3f470ed7c..8b8868a04560 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java
@@ -23,7 +23,6 @@
import org.apache.shenyu.common.dto.convert.rule.AiProxyHandle;
import org.apache.shenyu.common.enums.PluginEnum;
import org.apache.shenyu.common.utils.GsonUtils;
-import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.ChatClientCache;
import org.apache.shenyu.plugin.base.cache.CommonHandleCache;
import org.apache.shenyu.plugin.base.handler.PluginDataHandler;
import org.apache.shenyu.plugin.base.utils.CacheKeyUtils;
@@ -37,12 +36,6 @@ public class AiProxyPluginHandler implements PluginDataHandler {
private final CommonHandleCache selectorCachedHandle = new CommonHandleCache<>();
- private final ChatClientCache chatClientCache;
-
- public AiProxyPluginHandler(final ChatClientCache chatClientCache) {
- this.chatClientCache = chatClientCache;
- }
-
@Override
public void handlerPlugin(final PluginData pluginData) {
// Note: The logic for handling global plugin configuration with Singleton has been removed
@@ -52,10 +45,6 @@ public void handlerPlugin(final PluginData pluginData) {
@Override
public void handlerSelector(final SelectorData selectorData) {
- // Invalidate the cache first when the selector is updated.
- chatClientCache.remove(selectorData.getId());
- // Do NOT remove AiProxyApiKeyCache here. Admin will push updated AI_PROXY_API_KEY events
- // with refreshed realApiKey after selector changes. Removing here introduces a window of misses.
if (Objects.isNull(selectorData.getHandle())) {
return;
}
@@ -67,8 +56,6 @@ public void handlerSelector(final SelectorData selectorData) {
@Override
public void removeSelector(final SelectorData selectorData) {
- // Invalidate the cache when the selector is removed.
- chatClientCache.remove(selectorData.getId());
selectorCachedHandle
.removeHandle(CacheKeyUtils.INST.getKey(selectorData.getId(), Constants.DEFAULT_RULE));
}
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java
index 71ab0c4c49d1..8502a772f43e 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java
@@ -17,12 +17,14 @@
package org.apache.shenyu.plugin.ai.proxy.enhanced.service;
-import org.apache.shenyu.plugin.ai.common.strategy.SimpleModelFallbackStrategy;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import org.springframework.ai.chat.client.ChatClient;
-import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.openai.api.OpenAiApi;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
import org.springframework.ai.retry.NonTransientAiException;
+import org.springframework.http.ResponseEntity;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
@@ -39,78 +41,77 @@ public class AiProxyExecutorService {
private static final Logger LOG = LoggerFactory.getLogger(AiProxyExecutorService.class);
/**
- * Execute the AI call with retry and fallback.
+ * Execute a streaming AI call directly via {@link OpenAiApi}, bypassing Spring AI's
+ * {@code createRequest()} which loses fields like {@code reasoning_content}.
*
- * @param mainClient the main chat client
- * @param fallbackClientOpt the optional fallback chat client
- * @param requestBody the request body
- * @return a Mono containing the ChatResponse
+ * @param mainApi the main OpenAiApi
+ * @param fallbackApiOpt the optional fallback OpenAiApi
+ * @param request the ChatCompletionRequest with all fields preserved
+ * @return a Flux of ChatCompletionChunk
*/
- public Mono execute(final ChatClient mainClient, final Optional fallbackClientOpt, final String requestBody) {
- final Mono mainCall = doChatCall(mainClient, requestBody);
-
- return mainCall
- .retryWhen(Retry.backoff(3, Duration.ofSeconds(1))
- .filter(throwable -> !(throwable instanceof NonTransientAiException))
+ public Flux executeDirectStream(final OpenAiApi mainApi,
+ final Optional fallbackApiOpt, final ChatCompletionRequest request) {
+ return mainApi.chatCompletionStream(request)
+ .doOnError(e -> UpstreamErrorLogger.logUpstreamError(LOG, e, "direct stream"))
+ .retryWhen(Retry.max(1)
.onRetryExhaustedThrow((retryBackoffSpec, retrySignal) -> {
- LOG.warn("Retries exhausted for AI call after {} attempts.",
- retrySignal.totalRetries(), retrySignal.failure());
- return new NonTransientAiException("Retries exhausted. Triggering fallback.",
+ LOG.warn("Direct stream retry exhausted. Triggering fallback.",
+ retrySignal.failure());
+ return new NonTransientAiException(
+ "Direct stream failed after 1 retry. Triggering fallback.",
retrySignal.failure());
}))
.onErrorResume(NonTransientAiException.class,
- throwable -> handleFallback(throwable, fallbackClientOpt, requestBody));
- }
-
- protected Mono doChatCall(final ChatClient client, final String requestBody) {
- return Mono.fromCallable(() -> client.prompt().user(requestBody).call().chatResponse())
- .subscribeOn(Schedulers.boundedElastic());
- }
-
- private Mono handleFallback(final Throwable throwable, final Optional fallbackClientOpt, final String requestBody) {
- LOG.warn("AI main call failed or retries exhausted, attempting to fallback...", throwable);
-
- if (fallbackClientOpt.isEmpty()) {
- return Mono.error(throwable);
- }
-
- return SimpleModelFallbackStrategy.INSTANCE.fallback(fallbackClientOpt.get(), requestBody, throwable);
+ throwable -> handleDirectFallbackStream(throwable, fallbackApiOpt, request));
}
/**
- * Execute the AI call with retry and fallback.
+ * Execute a non-streaming AI call directly via {@link OpenAiApi}.
*
- * @param mainClient the main chat client
- * @param fallbackClientOpt the optional fallback chat client
- * @param requestBody the request body
- * @return a Flux containing the ChatResponse
+ * @param mainApi the main OpenAiApi
+ * @param fallbackApiOpt the optional fallback OpenAiApi
+ * @param request the ChatCompletionRequest with all fields preserved
+ * @return a Mono of ResponseEntity containing ChatCompletion
*/
- public Flux executeStream(final ChatClient mainClient, final Optional fallbackClientOpt, final String requestBody) {
- final Flux mainStream = doChatStream(mainClient, requestBody);
-
- return mainStream
- .retryWhen(Retry.max(1)
+ public Mono> executeDirectCall(final OpenAiApi mainApi,
+ final Optional fallbackApiOpt, final ChatCompletionRequest request) {
+ return Mono.fromCallable(() -> mainApi.chatCompletionEntity(request))
+ .subscribeOn(Schedulers.boundedElastic())
+ .doOnError(e -> UpstreamErrorLogger.logUpstreamError(LOG, e, "direct call"))
+ .retryWhen(Retry.backoff(3, Duration.ofSeconds(1))
+ .filter(throwable -> !(throwable instanceof NonTransientAiException))
.onRetryExhaustedThrow((retryBackoffSpec, retrySignal) -> {
- LOG.warn("Retrying stream once failed. Attempts: {}. Triggering fallback.",
+ LOG.warn("Direct call retries exhausted after {} attempts. Triggering fallback.",
retrySignal.totalRetries(), retrySignal.failure());
- return new NonTransientAiException("Stream failed after 1 retry. Triggering fallback.", retrySignal.failure());
+ return new NonTransientAiException("Direct call retries exhausted. Triggering fallback.",
+ retrySignal.failure());
}))
.onErrorResume(NonTransientAiException.class,
- throwable -> handleFallbackStream(throwable, fallbackClientOpt, requestBody));
+ throwable -> handleDirectFallbackCall(throwable, fallbackApiOpt, request));
}
- protected Flux doChatStream(final ChatClient client, final String requestBody) {
- return Flux.defer(() -> client.prompt().user(requestBody).stream().chatResponse())
- .subscribeOn(Schedulers.boundedElastic());
+ private Flux handleDirectFallbackStream(final Throwable throwable,
+ final Optional fallbackApiOpt, final ChatCompletionRequest request) {
+ LOG.warn("Main direct stream failed, attempting fallback...", throwable);
+
+ if (fallbackApiOpt.isEmpty()) {
+ return Flux.error(throwable);
+ }
+
+ LOG.info("Using fallback OpenAiApi for direct stream");
+ return fallbackApiOpt.get().chatCompletionStream(request);
}
- private Flux handleFallbackStream(final Throwable throwable, final Optional fallbackClientOpt, final String requestBody) {
- LOG.warn("AI main stream failed or retries exhausted, attempting to fallback...", throwable);
+ private Mono> handleDirectFallbackCall(final Throwable throwable,
+ final Optional fallbackApiOpt, final ChatCompletionRequest request) {
+ LOG.warn("Main direct call failed, attempting fallback...", throwable);
- if (fallbackClientOpt.isEmpty()) {
- return Flux.error(throwable);
+ if (fallbackApiOpt.isEmpty()) {
+ return Mono.error(throwable);
}
- return SimpleModelFallbackStrategy.INSTANCE.fallbackStream(fallbackClientOpt.get(), requestBody, throwable);
+ LOG.info("Using fallback OpenAiApi for direct call");
+ return Mono.fromCallable(() -> fallbackApiOpt.get().chatCompletionEntity(request))
+ .subscribeOn(Schedulers.boundedElastic());
}
-}
\ No newline at end of file
+}
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLogger.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLogger.java
new file mode 100644
index 000000000000..bee49aab7d0c
--- /dev/null
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLogger.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shenyu.plugin.ai.proxy.enhanced.service;
+
+import org.slf4j.Logger;
+import org.springframework.web.reactive.function.client.WebClientResponseException;
+
+import java.util.Objects;
+
+/**
+ * Shared utility for logging upstream AI service errors with WebClientResponseException details.
+ */
+public final class UpstreamErrorLogger {
+
+ private UpstreamErrorLogger() {
+ }
+
+ public static void logUpstreamError(final Logger log, final Throwable e, final String mode) {
+ final WebClientResponseException webClientEx = findWebClientResponseException(e);
+ if (Objects.nonNull(webClientEx)) {
+ log.error("[AiProxy] {} failed, status={}, upstreamBody={}",
+ mode, webClientEx.getStatusCode(), webClientEx.getResponseBodyAsString(), e);
+ } else {
+ log.error("[AiProxy] {} failed", mode, e);
+ }
+ }
+
+ private static WebClientResponseException findWebClientResponseException(final Throwable e) {
+ Throwable current = e;
+ while (Objects.nonNull(current)) {
+ if (current instanceof WebClientResponseException ex) {
+ return ex;
+ }
+ current = current.getCause();
+ }
+ return null;
+ }
+}
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriber.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriber.java
index 4f63f5b53bb8..d25434881658 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriber.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriber.java
@@ -19,7 +19,6 @@
import org.apache.shenyu.common.dto.ProxyApiKeyData;
import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.AiProxyApiKeyCache;
-import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.ChatClientCache;
import org.apache.shenyu.sync.data.api.AiProxyApiKeyDataSubscriber;
import java.util.Objects;
@@ -29,12 +28,6 @@
*/
public final class CommonAiProxyApiKeyDataSubscriber implements AiProxyApiKeyDataSubscriber {
- private final ChatClientCache chatClientCache;
-
- public CommonAiProxyApiKeyDataSubscriber(final ChatClientCache chatClientCache) {
- this.chatClientCache = chatClientCache;
- }
-
@Override
public void onSubscribe(final ProxyApiKeyData data) {
if (Objects.isNull(data) || Objects.isNull(data.getProxyApiKey())) {
@@ -54,6 +47,5 @@ public void unSubscribe(final ProxyApiKeyData data) {
@Override
public void refresh() {
AiProxyApiKeyCache.getInstance().refresh();
- chatClientCache.clearAll();
}
}
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java
index 30e84abc0809..dfbf95a1b0c9 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java
@@ -23,10 +23,7 @@
import org.apache.shenyu.common.dto.convert.rule.AiProxyHandle;
import org.apache.shenyu.common.enums.AiModelProviderEnum;
import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig;
-import org.apache.shenyu.plugin.ai.common.spring.ai.AiModelFactory;
-import org.apache.shenyu.plugin.ai.common.spring.ai.registry.AiModelFactoryRegistry;
import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.AiProxyApiKeyCache;
-import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.ChatClientCache;
import org.apache.shenyu.plugin.ai.proxy.enhanced.handler.AiProxyPluginHandler;
import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyConfigService;
import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyExecutorService;
@@ -43,21 +40,22 @@
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness;
-import org.springframework.ai.chat.client.ChatClient;
-import org.springframework.ai.chat.model.ChatModel;
-import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.openai.api.OpenAiApi;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
import org.springframework.context.ApplicationContext;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
+import org.springframework.http.ResponseEntity;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.web.server.ServerWebExchange;
+import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import java.util.Optional;
import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@@ -75,27 +73,12 @@ public class AiProxyPluginTest {
private static final String REQUEST_BODY = "{\"messages\":[{\"role\":\"user\",\"content\":\"Hello\"}]}";
- @Mock
- private AiModelFactoryRegistry registry;
-
@Mock
private AiProxyConfigService configService;
@Mock
private AiProxyExecutorService executorService;
- @Mock
- private ChatClientCache chatClientCache;
-
- @Mock
- private AiModelFactory modelFactory;
-
- @Mock
- private ChatModel chatModel;
-
- @Mock
- private ChatClient chatClient;
-
private AiProxyPluginHandler aiProxyPluginHandler;
private AiProxyPlugin plugin;
@@ -110,8 +93,8 @@ public class AiProxyPluginTest {
@BeforeEach
public void setUp() {
- aiProxyPluginHandler = new AiProxyPluginHandler(chatClientCache);
- plugin = new AiProxyPlugin(registry, configService, executorService, chatClientCache, aiProxyPluginHandler);
+ aiProxyPluginHandler = new AiProxyPluginHandler();
+ plugin = new AiProxyPlugin(configService, executorService, aiProxyPluginHandler);
selector = new SelectorData();
selector.setId(SELECTOR_ID);
@@ -129,14 +112,6 @@ public void setUp() {
when(applicationContext.getBean(ShenyuResult.class)).thenReturn(resultMock);
SpringBeanUtils.getInstance().setApplicationContext(applicationContext);
- // Common mock behavior for model creation and cache
- when(registry.getFactory(any(AiModelProviderEnum.class))).thenReturn(modelFactory);
- when(modelFactory.createAiModel(any(AiCommonConfig.class))).thenReturn(chatModel);
- when(chatClientCache.computeIfAbsent(anyString(), any())).thenAnswer(invocation -> {
- // Always return the mock ChatClient to avoid any UnsupportedOperationException
- return chatClient;
- });
-
// mock static
apiKeyCacheMockedStatic = mockStatic(AiProxyApiKeyCache.class);
}
@@ -150,13 +125,14 @@ public void tearDown() {
private void setupSuccessMocks(final AiProxyHandle handle, final AiCommonConfig primaryConfig, final Optional fallbackConfig) {
aiProxyPluginHandler.getSelectorCachedHandle().cachedHandle(CacheKeyUtils.INST.getKey(SELECTOR_ID, Constants.DEFAULT_RULE), handle);
- final ChatResponse chatResponse = mock(ChatResponse.class);
+ final ChatCompletion chatCompletion = mock(ChatCompletion.class);
+ final ResponseEntity responseEntity = ResponseEntity.ok(chatCompletion);
when(configService.resolvePrimaryConfig(handle)).thenReturn(primaryConfig);
when(configService.resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY)).thenReturn(fallbackConfig);
when(configService.resolveAdminFallbackConfig(primaryConfig, handle)).thenReturn(fallbackConfig);
- when(configService.extractPrompt(anyString())).thenAnswer(invocation -> invocation.getArgument(0));
- when(executorService.execute(any(), any(), any())).thenReturn(Mono.just(chatResponse));
+ when(executorService.executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class))).thenReturn(Mono.just(responseEntity));
+ when(executorService.executeDirectStream(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class))).thenReturn(Flux.empty());
}
@Test
@@ -164,6 +140,8 @@ public void testExecuteSuccess() {
final AiProxyHandle handle = new AiProxyHandle();
final AiCommonConfig primaryConfig = new AiCommonConfig();
primaryConfig.setProvider(AiModelProviderEnum.OPEN_AI.getName());
+ primaryConfig.setBaseUrl("https://api.openai.com");
+ primaryConfig.setApiKey("test-key");
setupSuccessMocks(handle, primaryConfig, Optional.empty());
StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule))
@@ -173,7 +151,7 @@ public void testExecuteSuccess() {
verify(configService).resolvePrimaryConfig(handle);
verify(configService).resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY);
verify(configService).resolveAdminFallbackConfig(primaryConfig, handle);
- verify(executorService).execute(any(), any(), any());
+ verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class));
}
@Test
@@ -181,8 +159,12 @@ public void testExecuteWithDynamicFallback() {
final AiProxyHandle handle = new AiProxyHandle();
final AiCommonConfig primaryConfig = new AiCommonConfig();
primaryConfig.setProvider(AiModelProviderEnum.OPEN_AI.getName());
+ primaryConfig.setBaseUrl("https://api.openai.com");
+ primaryConfig.setApiKey("test-key");
final AiCommonConfig fallbackConfig = new AiCommonConfig();
fallbackConfig.setProvider(AiModelProviderEnum.DEEP_SEEK.getName());
+ fallbackConfig.setBaseUrl("https://api.deepseek.com");
+ fallbackConfig.setApiKey("fallback-key");
setupSuccessMocks(handle, primaryConfig, Optional.of(fallbackConfig));
when(configService.resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY)).thenReturn(Optional.of(fallbackConfig));
@@ -190,7 +172,7 @@ public void testExecuteWithDynamicFallback() {
StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule))
.verifyComplete();
- verify(executorService).execute(any(ChatClient.class), any(Optional.class), any());
+ verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class));
}
@Test
@@ -199,6 +181,7 @@ public void testExecuteWithValidProxyApiKey() {
handle.setProxyEnabled("true");
final AiCommonConfig primaryConfig = new AiCommonConfig();
primaryConfig.setProvider(AiModelProviderEnum.OPEN_AI.getName());
+ primaryConfig.setBaseUrl("https://api.openai.com");
primaryConfig.setApiKey("original-key");
// setup request with proxy key
@@ -225,6 +208,8 @@ public void testExecuteWithInvalidProxyApiKey() {
handle.setProxyEnabled("true");
final AiCommonConfig primaryConfig = new AiCommonConfig();
primaryConfig.setProvider(AiModelProviderEnum.OPEN_AI.getName());
+ primaryConfig.setBaseUrl("https://api.openai.com");
+ primaryConfig.setApiKey("test-key");
// setup request with proxy key
exchange = MockServerWebExchange.from(MockServerHttpRequest.post("/test").header(Constants.X_API_KEY, "proxy-key-invalid").body(REQUEST_BODY));
@@ -252,8 +237,12 @@ public void testExecuteWithAdminFallback() {
final AiProxyHandle handle = new AiProxyHandle();
final AiCommonConfig primaryConfig = new AiCommonConfig();
primaryConfig.setProvider(AiModelProviderEnum.OPEN_AI.getName());
+ primaryConfig.setBaseUrl("https://api.openai.com");
+ primaryConfig.setApiKey("test-key");
final AiCommonConfig fallbackConfig = new AiCommonConfig();
fallbackConfig.setProvider(AiModelProviderEnum.DEEP_SEEK.getName());
+ fallbackConfig.setBaseUrl("https://api.deepseek.com");
+ fallbackConfig.setApiKey("fallback-key");
setupSuccessMocks(handle, primaryConfig, Optional.empty());
when(configService.resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY)).thenReturn(Optional.empty());
@@ -262,7 +251,7 @@ public void testExecuteWithAdminFallback() {
StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule))
.verifyComplete();
- verify(executorService).execute(any(ChatClient.class), any(Optional.class), any());
+ verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class));
}
@Test
@@ -270,53 +259,49 @@ public void testCacheIsUsedForAdminFallbackClient() {
final AiProxyHandle handle = new AiProxyHandle();
final AiCommonConfig primaryConfig = new AiCommonConfig();
primaryConfig.setProvider(AiModelProviderEnum.OPEN_AI.getName());
+ primaryConfig.setBaseUrl("https://api.openai.com");
+ primaryConfig.setApiKey("test-key");
final AiCommonConfig fallbackConfig = new AiCommonConfig();
fallbackConfig.setProvider(AiModelProviderEnum.DEEP_SEEK.getName());
+ fallbackConfig.setBaseUrl("https://api.deepseek.com");
+ fallbackConfig.setApiKey("fallback-key");
// Cache the handle for the test
aiProxyPluginHandler.getSelectorCachedHandle().cachedHandle(CacheKeyUtils.INST.getKey(SELECTOR_ID, Constants.DEFAULT_RULE), handle);
- final ChatResponse chatResponse = mock(ChatResponse.class);
+ final ChatCompletion chatCompletion = mock(ChatCompletion.class);
+ final ResponseEntity responseEntity = ResponseEntity.ok(chatCompletion);
// Setup all necessary mocks
when(configService.resolvePrimaryConfig(handle)).thenReturn(primaryConfig);
when(configService.resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY)).thenReturn(Optional.empty());
when(configService.resolveAdminFallbackConfig(primaryConfig, handle)).thenReturn(Optional.of(fallbackConfig));
- when(configService.extractPrompt(anyString())).thenAnswer(invocation -> invocation.getArgument(0));
- when(executorService.execute(any(), any(), any())).thenReturn(Mono.just(chatResponse));
+ when(executorService.executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class))).thenReturn(Mono.just(responseEntity));
// Execute the test - focus on successful execution rather than cache verification
StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule))
.verifyComplete();
-
+
// Verify that the configuration methods were called correctly
verify(configService).resolvePrimaryConfig(handle);
verify(configService).resolveAdminFallbackConfig(primaryConfig, handle);
- verify(executorService).execute(any(), any(), any());
+ verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class));
}
@Test
- public void testCreateChatModelThrowsException() {
+ public void testCreateOpenAiApiThrowsException() {
final AiProxyHandle handle = new AiProxyHandle();
final AiCommonConfig primaryConfig = new AiCommonConfig();
- primaryConfig.setProvider("InvalidProvider");
-
- // Cache the handle for the test
+ primaryConfig.setBaseUrl(null);
+
aiProxyPluginHandler.getSelectorCachedHandle().cachedHandle(CacheKeyUtils.INST.getKey(SELECTOR_ID, Constants.DEFAULT_RULE), handle);
-
- // Mock config service to return the invalid config
+
when(configService.resolvePrimaryConfig(handle)).thenReturn(primaryConfig);
when(configService.resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY)).thenReturn(Optional.empty());
when(configService.resolveAdminFallbackConfig(primaryConfig, handle)).thenReturn(Optional.empty());
- when(configService.extractPrompt(anyString())).thenAnswer(invocation -> invocation.getArgument(0));
-
- // Mock registry to return null factory for invalid provider - this should cause IllegalArgumentException
- when(registry.getFactory(any())).thenReturn(null);
-
- // Mock executorService to return a proper Mono to avoid NullPointerException
- when(executorService.execute(any(), any(), any())).thenReturn(Mono.error(new IllegalArgumentException("AI model factory not found")));
StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule))
- .expectError(IllegalArgumentException.class)
+ .expectErrorMatches(e -> e instanceof IllegalArgumentException
+ && e.getMessage().contains("baseUrl must not be empty"))
.verify();
}
@@ -325,10 +310,12 @@ public void testExecutorServiceError() {
final AiProxyHandle handle = new AiProxyHandle();
final AiCommonConfig primaryConfig = new AiCommonConfig();
primaryConfig.setProvider(AiModelProviderEnum.OPEN_AI.getName());
+ primaryConfig.setBaseUrl("https://api.openai.com");
+ primaryConfig.setApiKey("test-key");
final RuntimeException exception = new RuntimeException("AI execution failed");
setupSuccessMocks(handle, primaryConfig, Optional.empty());
- when(executorService.execute(any(), any(), any())).thenReturn(Mono.error(exception));
+ when(executorService.executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class))).thenReturn(Mono.error(exception));
StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule))
.expectErrorMatches(exception::equals)
@@ -337,6 +324,6 @@ public void testExecutorServiceError() {
verify(configService).resolvePrimaryConfig(handle);
verify(configService).resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY);
verify(configService).resolveAdminFallbackConfig(primaryConfig, handle);
- verify(executorService).execute(any(), any(), any());
+ verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class));
}
}
\ No newline at end of file
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorServiceTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorServiceTest.java
index 7bd2bbbdac0d..46d3719cab2a 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorServiceTest.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorServiceTest.java
@@ -20,21 +20,19 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
-import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
-import org.springframework.ai.chat.client.ChatClient;
-import org.springframework.ai.chat.model.ChatModel;
-import org.springframework.ai.chat.model.ChatResponse;
-import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.openai.api.OpenAiApi;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
import org.springframework.ai.retry.NonTransientAiException;
-import org.springframework.web.client.RestClientException;
+import org.springframework.http.ResponseEntity;
+import reactor.core.publisher.Flux;
import reactor.test.StepVerifier;
import java.util.Optional;
-import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@@ -42,102 +40,96 @@
@ExtendWith(MockitoExtension.class)
public class AiProxyExecutorServiceTest {
- @Mock
- private ChatModel mainChatModel;
-
- @Mock
- private ChatModel fallbackChatModel;
-
- private ChatClient mainClient;
-
- private ChatClient fallbackClient;
-
private AiProxyExecutorService executorService;
@BeforeEach
void setUp() {
- executorService = spy(new AiProxyExecutorService());
- mainClient = ChatClient.create(mainChatModel);
- fallbackClient = ChatClient.create(fallbackChatModel);
+ executorService = new AiProxyExecutorService();
}
@Test
- void testExecuteSuccessOnFirstAttempt() {
- final ChatResponse successResponse = mock(ChatResponse.class);
- when(mainChatModel.call(any(Prompt.class))).thenReturn(successResponse);
-
- StepVerifier.create(executorService.execute(mainClient, Optional.empty(), "request"))
- .expectNext(successResponse)
+ void testExecuteDirectStreamSuccess() {
+ final OpenAiApi mainApi = mock(OpenAiApi.class);
+ final ChatCompletionChunk chunk = mock(ChatCompletionChunk.class);
+ final ChatCompletionRequest request = mock(ChatCompletionRequest.class);
+ when(mainApi.chatCompletionStream(request)).thenReturn(Flux.just(chunk));
+
+ StepVerifier.create(executorService.executeDirectStream(mainApi, Optional.empty(), request))
+ .expectNext(chunk)
.verifyComplete();
- verify(mainChatModel, times(1)).call(any(Prompt.class));
+ verify(mainApi, times(1)).chatCompletionStream(request);
}
@Test
- void testExecuteRetryOnRestClientExceptionAndSucceed() {
- final ChatResponse successResponse = mock(ChatResponse.class);
- when(mainChatModel.call(any(Prompt.class)))
- .thenThrow(new RestClientException("transient error"))
- .thenReturn(successResponse);
-
- StepVerifier.create(executorService.execute(mainClient, Optional.empty(), "request"))
- .expectNext(successResponse)
- .verifyComplete();
+ void testExecuteDirectStreamErrorWithNoFallback() {
+ final OpenAiApi mainApi = mock(OpenAiApi.class);
+ final ChatCompletionRequest request = mock(ChatCompletionRequest.class);
+ when(mainApi.chatCompletionStream(request)).thenAnswer(inv -> Flux.error(new RuntimeException("upstream error")));
- verify(mainChatModel, times(2)).call(any(Prompt.class));
+ StepVerifier.create(executorService.executeDirectStream(mainApi, Optional.empty(), request))
+ .expectError(NonTransientAiException.class)
+ .verify();
}
@Test
- void testExecuteRetryExhaustedThenFallback() {
- final ChatResponse fallbackResponse = mock(ChatResponse.class);
- when(mainChatModel.call(any(Prompt.class))).thenThrow(new RestClientException("transient error"));
- when(fallbackChatModel.call(any(Prompt.class))).thenReturn(fallbackResponse);
+ void testExecuteDirectStreamFallbackSuccess() {
+ final OpenAiApi mainApi = mock(OpenAiApi.class);
+ final OpenAiApi fallbackApi = mock(OpenAiApi.class);
+ final ChatCompletionRequest request = mock(ChatCompletionRequest.class);
+ final ChatCompletionChunk fallbackChunk = mock(ChatCompletionChunk.class);
- StepVerifier.create(executorService.execute(mainClient, Optional.of(fallbackClient), "request"))
- .expectNext(fallbackResponse)
+ when(mainApi.chatCompletionStream(request)).thenAnswer(inv -> Flux.error(new RuntimeException("upstream error")));
+ when(fallbackApi.chatCompletionStream(request)).thenReturn(Flux.just(fallbackChunk));
+
+ StepVerifier.create(executorService.executeDirectStream(mainApi, Optional.of(fallbackApi), request))
+ .expectNext(fallbackChunk)
.verifyComplete();
- verify(mainChatModel, times(4)).call(any(Prompt.class));
- verify(fallbackChatModel, times(1)).call(any(Prompt.class));
+ verify(fallbackApi, times(1)).chatCompletionStream(request);
}
@Test
- void testExecuteNonTransientExceptionTriggersFallbackDirectly() {
- final ChatResponse fallbackResponse = mock(ChatResponse.class);
- when(mainChatModel.call(any(Prompt.class))).thenThrow(new NonTransientAiException("non-transient error"));
- when(fallbackChatModel.call(any(Prompt.class))).thenReturn(fallbackResponse);
-
- StepVerifier.create(executorService.execute(mainClient, Optional.of(fallbackClient), "request"))
- .expectNext(fallbackResponse)
+ void testExecuteDirectCallSuccess() {
+ final OpenAiApi mainApi = mock(OpenAiApi.class);
+ final ChatCompletionRequest request = mock(ChatCompletionRequest.class);
+ final ChatCompletion completion = mock(ChatCompletion.class);
+ final ResponseEntity responseEntity = ResponseEntity.ok(completion);
+ when(mainApi.chatCompletionEntity(request)).thenReturn(responseEntity);
+
+ StepVerifier.create(executorService.executeDirectCall(mainApi, Optional.empty(), request))
+ .expectNext(responseEntity)
.verifyComplete();
- verify(mainChatModel, times(1)).call(any(Prompt.class));
- verify(fallbackChatModel, times(1)).call(any(Prompt.class));
+ verify(mainApi, times(1)).chatCompletionEntity(request);
}
@Test
- void testExecuteFallbackFails() {
- final RestClientException fallbackException = new RestClientException("fallback failed");
- when(mainChatModel.call(any(Prompt.class))).thenThrow(new NonTransientAiException("non-transient error"));
- when(fallbackChatModel.call(any(Prompt.class))).thenThrow(fallbackException);
+ void testExecuteDirectCallErrorWithNoFallback() {
+ final OpenAiApi mainApi = mock(OpenAiApi.class);
+ final ChatCompletionRequest request = mock(ChatCompletionRequest.class);
+ when(mainApi.chatCompletionEntity(request)).thenThrow(new RuntimeException("upstream error"));
- StepVerifier.create(executorService.execute(mainClient, Optional.of(fallbackClient), "request"))
- .expectErrorMatches(e -> e == fallbackException)
+ StepVerifier.create(executorService.executeDirectCall(mainApi, Optional.empty(), request))
+ .expectError(NonTransientAiException.class)
.verify();
-
- verify(mainChatModel, times(1)).call(any(Prompt.class));
- verify(fallbackChatModel, times(1)).call(any(Prompt.class));
}
@Test
- void testExecuteNoFallbackProvided() {
- final NonTransientAiException mainException = new NonTransientAiException("non-transient error");
- when(mainChatModel.call(any(Prompt.class))).thenThrow(mainException);
+ void testExecuteDirectCallFallbackSuccess() {
+ final OpenAiApi mainApi = mock(OpenAiApi.class);
+ final OpenAiApi fallbackApi = mock(OpenAiApi.class);
+ final ChatCompletionRequest request = mock(ChatCompletionRequest.class);
+ final ChatCompletion fallbackCompletion = mock(ChatCompletion.class);
+ final ResponseEntity fallbackResponse = ResponseEntity.ok(fallbackCompletion);
- StepVerifier.create(executorService.execute(mainClient, Optional.empty(), "request"))
- .expectErrorMatches(e -> e == mainException)
- .verify();
+ when(mainApi.chatCompletionEntity(request)).thenThrow(new RuntimeException("upstream error"));
+ when(fallbackApi.chatCompletionEntity(request)).thenReturn(fallbackResponse);
+
+ StepVerifier.create(executorService.executeDirectCall(mainApi, Optional.of(fallbackApi), request))
+ .expectNext(fallbackResponse)
+ .verifyComplete();
- verify(mainChatModel, times(1)).call(any(Prompt.class));
+ verify(fallbackApi, times(1)).chatCompletionEntity(request);
}
}
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriberTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriberTest.java
index c863ba20f48a..4827c5337240 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriberTest.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriberTest.java
@@ -19,17 +19,12 @@
import org.apache.shenyu.common.dto.ProxyApiKeyData;
import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.AiProxyApiKeyCache;
-import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.ChatClientCache;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
-import static org.mockito.Mockito.mock;
-
class CommonAiProxyApiKeyDataSubscriberTest {
- private final ChatClientCache chatClientCache = mock(ChatClientCache.class);
-
@AfterEach
void cleanup() {
AiProxyApiKeyCache.getInstance().refresh();
@@ -37,7 +32,7 @@ void cleanup() {
@Test
void testOnSubscribeCachesEnabled() {
- CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache);
+ CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber();
ProxyApiKeyData data = ProxyApiKeyData.builder()
.selectorId("test-selector")
.proxyApiKey("proxy-sub")
@@ -51,7 +46,7 @@ void testOnSubscribeCachesEnabled() {
@Test
void testOnSubscribeIgnoresDisabled() {
- CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache);
+ CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber();
ProxyApiKeyData data = ProxyApiKeyData.builder()
.selectorId("test-selector")
.proxyApiKey("proxy-disabled")
@@ -65,7 +60,7 @@ void testOnSubscribeIgnoresDisabled() {
@Test
void testUnSubscribeRemoves() {
- CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache);
+ CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber();
ProxyApiKeyData data = ProxyApiKeyData.builder()
.selectorId("test-selector")
.proxyApiKey("proxy-unsub")
@@ -81,7 +76,7 @@ void testUnSubscribeRemoves() {
@Test
void testRefreshClearsCache() {
- CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache);
+ CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber();
ProxyApiKeyData data = ProxyApiKeyData.builder()
.selectorId("test-selector")
.proxyApiKey("proxy-refresh")
@@ -98,14 +93,14 @@ void testRefreshClearsCache() {
@Test
void testOnSubscribeNullDataNoThrow() {
- CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache);
+ CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber();
subscriber.onSubscribe(null);
Assertions.assertEquals(0, AiProxyApiKeyCache.getInstance().size());
}
@Test
void testOnSubscribeNullKeyIgnored() {
- CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache);
+ CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber();
ProxyApiKeyData data = ProxyApiKeyData.builder()
.selectorId("test-selector")
.proxyApiKey(null)
@@ -119,14 +114,14 @@ void testOnSubscribeNullKeyIgnored() {
@Test
void testUnSubscribeNullDataNoThrow() {
- CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache);
+ CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber();
subscriber.unSubscribe(null);
Assertions.assertEquals(0, AiProxyApiKeyCache.getInstance().size());
}
@Test
void testUnSubscribeNullKeyNoThrow() {
- CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache);
+ CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber();
ProxyApiKeyData data = ProxyApiKeyData.builder()
.selectorId("test-selector")
.proxyApiKey(null)
@@ -140,7 +135,7 @@ void testUnSubscribeNullKeyNoThrow() {
@Test
void testDuplicateSubscribeOverrides() {
- CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache);
+ CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber();
ProxyApiKeyData v1 = ProxyApiKeyData.builder()
.selectorId("test-selector")
.proxyApiKey("dup-key")
@@ -161,48 +156,4 @@ void testDuplicateSubscribeOverrides() {
subscriber.onSubscribe(v2);
Assertions.assertEquals("real-2", AiProxyApiKeyCache.getInstance().getRealApiKey("test-selector", "dup-key"));
}
-
- @Test
- void testSubscribeDisabledDoesNotOverrideExisting() {
- CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache);
- ProxyApiKeyData enabled = ProxyApiKeyData.builder()
- .selectorId("test-selector")
- .proxyApiKey("keep-key")
- .realApiKey("real-keep")
- .enabled(Boolean.TRUE)
- .namespaceId("default")
- .build();
- subscriber.onSubscribe(enabled);
- Assertions.assertEquals("real-keep", AiProxyApiKeyCache.getInstance().getRealApiKey("test-selector", "keep-key"));
-
- ProxyApiKeyData disabled = ProxyApiKeyData.builder()
- .selectorId("test-selector")
- .proxyApiKey("keep-key")
- .realApiKey("real-new")
- .enabled(Boolean.FALSE)
- .namespaceId("default")
- .build();
- subscriber.onSubscribe(disabled);
- // disabled subscribe should not override existing cached mapping
- Assertions.assertEquals("real-keep", AiProxyApiKeyCache.getInstance().getRealApiKey("test-selector", "keep-key"));
- }
-
- @Test
- void testVeryLongProxyKey() {
- CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache);
- StringBuilder sb = new StringBuilder();
- for (int i = 0; i < 300; i++) {
- sb.append('A' + (i % 26));
- }
- String longKey = sb.toString();
- ProxyApiKeyData data = ProxyApiKeyData.builder()
- .selectorId("test-selector")
- .proxyApiKey(longKey)
- .realApiKey("real-long")
- .enabled(Boolean.TRUE)
- .namespaceId("default")
- .build();
- subscriber.onSubscribe(data);
- Assertions.assertEquals("real-long", AiProxyApiKeyCache.getInstance().getRealApiKey("test-selector", longKey));
- }
-}
\ No newline at end of file
+}
diff --git a/shenyu-spring-boot-starter/shenyu-spring-boot-starter-plugin/shenyu-spring-boot-starter-plugin-ai-proxy/src/main/java/org/apache/shenyu/springboot/starter/plugin/ai/proxy/AiProxyPluginConfiguration.java b/shenyu-spring-boot-starter/shenyu-spring-boot-starter-plugin/shenyu-spring-boot-starter-plugin-ai-proxy/src/main/java/org/apache/shenyu/springboot/starter/plugin/ai/proxy/AiProxyPluginConfiguration.java
index d017cfabd8fc..2bcb70aa1dae 100644
--- a/shenyu-spring-boot-starter/shenyu-spring-boot-starter-plugin/shenyu-spring-boot-starter-plugin-ai-proxy/src/main/java/org/apache/shenyu/springboot/starter/plugin/ai/proxy/AiProxyPluginConfiguration.java
+++ b/shenyu-spring-boot-starter/shenyu-spring-boot-starter-plugin/shenyu-spring-boot-starter-plugin-ai-proxy/src/main/java/org/apache/shenyu/springboot/starter/plugin/ai/proxy/AiProxyPluginConfiguration.java
@@ -22,7 +22,6 @@
import org.apache.shenyu.plugin.ai.common.spring.ai.factory.OpenAiModelFactory;
import org.apache.shenyu.plugin.ai.common.spring.ai.registry.AiModelFactoryRegistry;
import org.apache.shenyu.plugin.ai.proxy.enhanced.AiProxyPlugin;
-import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.ChatClientCache;
import org.apache.shenyu.plugin.ai.proxy.enhanced.handler.AiProxyPluginHandler;
import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyConfigService;
import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyExecutorService;
@@ -48,42 +47,30 @@ public class AiProxyPluginConfiguration {
/**
* Ai proxy plugin.
*
- * @param aiModelFactoryRegistry the aiModelFactoryRegistry
* @param aiProxyConfigService the aiProxyConfigService
* @param aiProxyExecutorService the aiProxyExecutorService
- * @param chatClientCache the chatClientCache
* @param aiProxyPluginHandler the aiProxyPluginHandler
* @return the shenyu plugin
*/
@Bean
public ShenyuPlugin aiProxyPlugin(
- final AiModelFactoryRegistry aiModelFactoryRegistry,
final AiProxyConfigService aiProxyConfigService,
final AiProxyExecutorService aiProxyExecutorService,
- final ChatClientCache chatClientCache,
final AiProxyPluginHandler aiProxyPluginHandler) {
return new AiProxyPlugin(
- aiModelFactoryRegistry,
aiProxyConfigService,
aiProxyExecutorService,
- chatClientCache,
aiProxyPluginHandler);
}
/**
* Ai proxy plugin handler.
*
- * @param chatClientCache the chatClientCache
* @return the shenyu plugin handler
*/
@Bean
- public AiProxyPluginHandler aiProxyPluginHandler(final ChatClientCache chatClientCache) {
- return new AiProxyPluginHandler(chatClientCache);
- }
-
- @Bean
- public ChatClientCache chatClientCache() {
- return new ChatClientCache();
+ public AiProxyPluginHandler aiProxyPluginHandler() {
+ return new AiProxyPluginHandler();
}
@Bean
@@ -131,11 +118,10 @@ public DeepSeekModelFactory deepSeekModelFactory() {
/**
* Ai proxy api key data subscriber.
*
- * @param chatClientCache the chatClientCache
* @return the subscriber
*/
@Bean
- public AiProxyApiKeyDataSubscriber aiProxyApiKeyDataSubscriber(final ChatClientCache chatClientCache) {
- return new CommonAiProxyApiKeyDataSubscriber(chatClientCache);
+ public AiProxyApiKeyDataSubscriber aiProxyApiKeyDataSubscriber() {
+ return new CommonAiProxyApiKeyDataSubscriber();
}
}
From 69d8ff6b394b102c0264c9eb45245789df9936f9 Mon Sep 17 00:00:00 2001
From: eye-gu <734164350@qq.com>
Date: Wed, 13 May 2026 23:28:18 +0800
Subject: [PATCH 2/6] add test
---
.../handler/AiProxyPluginHandlerTest.java | 119 ++++++++++++++++++
.../service/UpstreamErrorLoggerTest.java | 73 +++++++++++
2 files changed, 192 insertions(+)
create mode 100644 shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandlerTest.java
create mode 100644 shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLoggerTest.java
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandlerTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandlerTest.java
new file mode 100644
index 000000000000..248ae16ac883
--- /dev/null
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandlerTest.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shenyu.plugin.ai.proxy.enhanced.handler;
+
+import org.apache.shenyu.common.constant.Constants;
+import org.apache.shenyu.common.dto.SelectorData;
+import org.apache.shenyu.common.dto.convert.rule.AiProxyHandle;
+import org.apache.shenyu.common.enums.PluginEnum;
+import org.apache.shenyu.plugin.base.utils.CacheKeyUtils;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNull;
+
+class AiProxyPluginHandlerTest {
+
+ private AiProxyPluginHandler handler;
+
+ @BeforeEach
+ void setUp() {
+ handler = new AiProxyPluginHandler();
+ }
+
+ @Test
+ void testPluginNamed() {
+ assertEquals(PluginEnum.AI_PROXY.getName(), handler.pluginNamed());
+ }
+
+ @Test
+ void testHandlerSelectorCachesHandle() {
+ SelectorData selector = new SelectorData();
+ selector.setId("sel-1");
+ selector.setHandle("{\"provider\":\"open_ai\",\"baseUrl\":\"https://api.openai.com\",\"apiKey\":\"sk-test\",\"model\":\"gpt-4\"}");
+
+ handler.handlerSelector(selector);
+
+ String key = CacheKeyUtils.INST.getKey("sel-1", Constants.DEFAULT_RULE);
+ AiProxyHandle cached = handler.getSelectorCachedHandle().obtainHandle(key);
+ assertNotNull(cached);
+ assertEquals("open_ai", cached.getProvider());
+ assertEquals("https://api.openai.com", cached.getBaseUrl());
+ assertEquals("sk-test", cached.getApiKey());
+ assertEquals("gpt-4", cached.getModel());
+ }
+
+ @Test
+ void testHandlerSelectorWithNullHandle() {
+ SelectorData selector = new SelectorData();
+ selector.setId("sel-2");
+ selector.setHandle(null);
+
+ handler.handlerSelector(selector);
+
+ String key = CacheKeyUtils.INST.getKey("sel-2", Constants.DEFAULT_RULE);
+ assertNull(handler.getSelectorCachedHandle().obtainHandle(key));
+ }
+
+ @Test
+ void testHandlerSelectorWithEmptyHandle() {
+ SelectorData selector = new SelectorData();
+ selector.setId("sel-3");
+ selector.setHandle("");
+
+ String key = CacheKeyUtils.INST.getKey("sel-3", Constants.DEFAULT_RULE);
+ assertNull(handler.getSelectorCachedHandle().obtainHandle(key));
+ }
+
+ @Test
+ void testHandlerSelectorWithFallbackNormalize() {
+ SelectorData selector = new SelectorData();
+ selector.setId("sel-4");
+ selector.setHandle("{\"provider\":\"open_ai\",\"fallbackEnabled\":\"true\",\"fallbackModel\":\"fallback-model\"}");
+
+ handler.handlerSelector(selector);
+
+ String key = CacheKeyUtils.INST.getKey("sel-4", Constants.DEFAULT_RULE);
+ AiProxyHandle cached = handler.getSelectorCachedHandle().obtainHandle(key);
+ assertNotNull(cached);
+ assertNotNull(cached.getFallbackConfig());
+ assertEquals("fallback-model", cached.getFallbackConfig().getModel());
+ }
+
+ @Test
+ void testRemoveSelector() {
+ SelectorData selector = new SelectorData();
+ selector.setId("sel-5");
+ selector.setHandle("{\"provider\":\"open_ai\",\"model\":\"gpt-4\"}");
+
+ handler.handlerSelector(selector);
+
+ String key = CacheKeyUtils.INST.getKey("sel-5", Constants.DEFAULT_RULE);
+ assertNotNull(handler.getSelectorCachedHandle().obtainHandle(key));
+
+ handler.removeSelector(selector);
+ assertNull(handler.getSelectorCachedHandle().obtainHandle(key));
+ }
+
+ @Test
+ void testHandlerPluginDoesNothing() {
+ handler.handlerPlugin(null);
+ }
+}
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLoggerTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLoggerTest.java
new file mode 100644
index 000000000000..46d21f43c9a1
--- /dev/null
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLoggerTest.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shenyu.plugin.ai.proxy.enhanced.service;
+
+import org.junit.jupiter.api.Test;
+import org.slf4j.Logger;
+import org.springframework.http.HttpStatus;
+import org.springframework.web.reactive.function.client.WebClientResponseException;
+
+import java.nio.charset.StandardCharsets;
+
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+
+class UpstreamErrorLoggerTest {
+
+ @Test
+ void testLogWithWebClientResponseException() {
+ Logger log = mock(Logger.class);
+ WebClientResponseException ex = WebClientResponseException.create(
+ 429, "Too Many Requests", null, "rate limited".getBytes(StandardCharsets.UTF_8), StandardCharsets.UTF_8);
+
+ UpstreamErrorLogger.logUpstreamError(log, ex, "stream");
+
+ verify(log).error("[AiProxy] {} failed, status={}, upstreamBody={}",
+ "stream", HttpStatus.TOO_MANY_REQUESTS, "rate limited", ex);
+ }
+
+ @Test
+ void testLogWithWrappedWebClientResponseException() {
+ Logger log = mock(Logger.class);
+ WebClientResponseException inner = WebClientResponseException.create(
+ 500, "Internal Server Error", null, "server error".getBytes(StandardCharsets.UTF_8), StandardCharsets.UTF_8);
+ RuntimeException wrapped = new RuntimeException("call failed", inner);
+
+ UpstreamErrorLogger.logUpstreamError(log, wrapped, "non-stream");
+
+ verify(log).error("[AiProxy] {} failed, status={}, upstreamBody={}",
+ "non-stream", HttpStatus.INTERNAL_SERVER_ERROR, "server error", wrapped);
+ }
+
+ @Test
+ void testLogWithGenericException() {
+ Logger log = mock(Logger.class);
+ RuntimeException ex = new RuntimeException("generic error");
+
+ UpstreamErrorLogger.logUpstreamError(log, ex, "stream");
+
+ verify(log).error("[AiProxy] {} failed", "stream", ex);
+ }
+
+ @Test
+ void testLogWithNullExceptionDoesNotThrow() {
+ Logger log = mock(Logger.class);
+ assertDoesNotThrow(() -> UpstreamErrorLogger.logUpstreamError(log, null, "stream"));
+ }
+}
From b3aa313d61745b53a9da31cff31c69838191a130 Mon Sep 17 00:00:00 2001
From: eye-gu <734164350@qq.com>
Date: Thu, 14 May 2026 18:19:51 +0800
Subject: [PATCH 3/6] use openai cache
---
.../src/main/resources/application.yml | 2 +-
.../ai/proxy/enhanced/AiProxyPlugin.java | 19 ++-
.../proxy/enhanced/cache/OpenAiApiCache.java | 147 ++++++++++++++++++
.../handler/AiProxyPluginHandler.java | 12 ++
.../CommonAiProxyApiKeyDataSubscriber.java | 2 +
5 files changed, 175 insertions(+), 7 deletions(-)
create mode 100644 shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCache.java
diff --git a/shenyu-admin/src/main/resources/application.yml b/shenyu-admin/src/main/resources/application.yml
index 6d2c3af3bfe2..47b4ab7cf3e2 100755
--- a/shenyu-admin/src/main/resources/application.yml
+++ b/shenyu-admin/src/main/resources/application.yml
@@ -21,7 +21,7 @@ spring:
application:
name: shenyu-admin
profiles:
- active: h2
+ active: mysql
thymeleaf:
cache: true
encoding: utf-8
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java
index d9e72c2b67f5..b1650f37f3f3 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java
@@ -26,6 +26,7 @@
import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig;
import org.apache.shenyu.plugin.ai.common.protocol.OpenAiProtocolAdapter;
import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.AiProxyApiKeyCache;
+import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.OpenAiApiCache;
import org.apache.shenyu.plugin.ai.proxy.enhanced.handler.AiProxyPluginHandler;
import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyConfigService;
import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyExecutorService;
@@ -152,9 +153,9 @@ private Mono handleStreamRequest(
final String requestBody,
final AiCommonConfig primaryConfig,
final AiProxyHandle selectorHandle) {
- final OpenAiApi mainApi = createOpenAiApi(primaryConfig);
+ final OpenAiApi mainApi = getCachedOpenAiApi(selector.getId(), "main", primaryConfig);
final ChatCompletionRequest request = OpenAiProtocolAdapter.toChatCompletionRequest(requestBody, true, primaryConfig);
- final Optional fallbackApi = resolveFallbackOpenAiApi(primaryConfig, selectorHandle,
+ final Optional fallbackApi = resolveFallbackOpenAiApi(selector.getId(), primaryConfig, selectorHandle,
requestBody);
final ServerHttpResponse response = exchange.getResponse();
response.getHeaders().setContentType(MediaType.TEXT_EVENT_STREAM);
@@ -182,9 +183,9 @@ private Mono handleNonStreamRequest(
final String requestBody,
final AiCommonConfig primaryConfig,
final AiProxyHandle selectorHandle) {
- final OpenAiApi mainApi = createOpenAiApi(primaryConfig);
+ final OpenAiApi mainApi = getCachedOpenAiApi(selector.getId(), "main", primaryConfig);
final ChatCompletionRequest request = OpenAiProtocolAdapter.toChatCompletionRequest(requestBody, false, primaryConfig);
- final Optional fallbackApi = resolveFallbackOpenAiApi(primaryConfig, selectorHandle,
+ final Optional fallbackApi = resolveFallbackOpenAiApi(selector.getId(), primaryConfig, selectorHandle,
requestBody);
return aiProxyExecutorService
@@ -199,6 +200,7 @@ private Mono handleNonStreamRequest(
}
private Optional resolveFallbackOpenAiApi(
+ final String selectorId,
final AiCommonConfig primaryConfig,
final AiProxyHandle selectorHandle,
final String requestBody) {
@@ -209,7 +211,7 @@ private Optional resolveFallbackOpenAiApi(
if (LOG.isDebugEnabled()) {
LOG.debug("[AiProxy] dynamic fallback config: {}", cfg);
}
- return createOpenAiApi(cfg);
+ return getCachedOpenAiApi(selectorId, "dynamicFallback", cfg);
})
.or(() -> aiProxyConfigService
.resolveAdminFallbackConfig(primaryConfig, selectorHandle)
@@ -218,10 +220,15 @@ private Optional resolveFallbackOpenAiApi(
if (LOG.isDebugEnabled()) {
LOG.debug("[AiProxy] admin fallback config: {}", adminFallbackConfig);
}
- return createOpenAiApi(adminFallbackConfig);
+ return getCachedOpenAiApi(selectorId, "adminFallback", adminFallbackConfig);
}));
}
+ private OpenAiApi getCachedOpenAiApi(final String selectorId, final String type, final AiCommonConfig config) {
+ final String cacheKey = selectorId + "|" + type + "_" + config.hashCode();
+ return OpenAiApiCache.getInstance().computeIfAbsent(cacheKey, () -> createOpenAiApi(config));
+ }
+
private OpenAiApi createOpenAiApi(final AiCommonConfig config) {
if (Objects.isNull(config.getBaseUrl()) || config.getBaseUrl().isEmpty()) {
throw new IllegalArgumentException("baseUrl must not be empty");
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCache.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCache.java
new file mode 100644
index 000000000000..886b746ad94b
--- /dev/null
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCache.java
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shenyu.plugin.ai.proxy.enhanced.cache;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.ai.openai.api.OpenAiApi;
+
+import java.util.Map;
+import java.util.Objects;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.function.Supplier;
+
+/**
+ * This is OpenAiApi cache.
+ * Caches OpenAiApi instances by config key to avoid recreating on every request.
+ */
+public final class OpenAiApiCache {
+
+ private static final Logger LOG = LoggerFactory.getLogger(OpenAiApiCache.class);
+
+ private static final OpenAiApiCache INSTANCE = new OpenAiApiCache();
+
+ private static final int MAX_CACHE_SIZE = getCacheSize();
+
+ private final Map openAiApiMap = new ConcurrentHashMap<>();
+
+ private final AtomicBoolean evictionInProgress = new AtomicBoolean(false);
+
+ /**
+ * Instantiates a new OpenAiApi cache.
+ */
+ public OpenAiApiCache() {
+ }
+
+ /**
+ * Gets instance.
+ *
+ * @return singleton
+ */
+ public static OpenAiApiCache getInstance() {
+ return INSTANCE;
+ }
+
+ private static int getCacheSize() {
+ String value = System.getProperty("shenyu.plugin.ai.proxy.enhanced.cache.maxSize",
+ System.getenv("SHENYU_PLUGIN_AI_PROXY_ENHANCED_CACHE_MAXSIZE"));
+ if (Objects.nonNull(value)) {
+ try {
+ return Integer.parseInt(value);
+ } catch (NumberFormatException e) {
+ LOG.warn("[OpenAiApiCache] Invalid cache size '{}', using default 500.", value);
+ }
+ }
+ return 500;
+ }
+
+ /**
+ * Gets OpenAiApi or compute if absent.
+ *
+ * @param key the cache key (typically selectorId|configHash)
+ * @param openAiApiSupplier the supplier to create OpenAiApi if absent
+ * @return the cached or newly created OpenAiApi
+ */
+ public OpenAiApi computeIfAbsent(final String key, final Supplier openAiApiSupplier) {
+ final int currentSize = openAiApiMap.size();
+ if (currentSize > MAX_CACHE_SIZE) {
+ if (evictionInProgress.compareAndSet(false, true)) {
+ try {
+ synchronized (openAiApiMap) {
+ if (openAiApiMap.size() > MAX_CACHE_SIZE) {
+ evictOldestEntries();
+ }
+ }
+ } finally {
+ evictionInProgress.set(false);
+ }
+ }
+ }
+ return openAiApiMap.computeIfAbsent(key, k -> openAiApiSupplier.get());
+ }
+
+ /**
+ * Evict oldest entries when cache size exceeds limit.
+ * Removes approximately 25% of entries to avoid thundering herd problem.
+ */
+ private void evictOldestEntries() {
+ final int currentSize = openAiApiMap.size();
+ if (currentSize <= MAX_CACHE_SIZE) {
+ return;
+ }
+
+ final int evictCount = Math.max(10, currentSize / 4);
+ LOG.warn("[OpenAiApiCache] Cache size {} exceeded limit {}, evicting {} oldest entries",
+ currentSize, MAX_CACHE_SIZE, evictCount);
+
+ int removed = 0;
+ for (final String key : openAiApiMap.keySet()) {
+ if (removed >= evictCount) {
+ break;
+ }
+ openAiApiMap.remove(key);
+ removed++;
+ }
+
+ LOG.info("[OpenAiApiCache] Evicted {} entries, cache size now: {}", removed, openAiApiMap.size());
+ }
+
+ /**
+ * Removes all cached OpenAiApi instances associated with a selector ID
+ * (by prefix matching "selectorId|").
+ *
+ * @param selectorId the selector id
+ */
+ public void remove(final String selectorId) {
+ if (Objects.isNull(selectorId)) {
+ return;
+ }
+ final String prefix = selectorId + "|";
+ openAiApiMap.keySet().removeIf(k -> k.equals(selectorId) || k.startsWith(prefix));
+ LOG.info("[OpenAiApiCache] invalidate selectorId={} (by prefix)", selectorId);
+ }
+
+ /**
+ * Clear all cached OpenAiApi instances.
+ */
+ public void clearAll() {
+ openAiApiMap.clear();
+ LOG.info("[OpenAiApiCache] cleared all cached OpenAiApi instances");
+ }
+}
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java
index 8b8868a04560..df9dc1a8cb44 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java
@@ -23,6 +23,8 @@
import org.apache.shenyu.common.dto.convert.rule.AiProxyHandle;
import org.apache.shenyu.common.enums.PluginEnum;
import org.apache.shenyu.common.utils.GsonUtils;
+import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.AiProxyApiKeyCache;
+import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.OpenAiApiCache;
import org.apache.shenyu.plugin.base.cache.CommonHandleCache;
import org.apache.shenyu.plugin.base.handler.PluginDataHandler;
import org.apache.shenyu.plugin.base.utils.CacheKeyUtils;
@@ -36,6 +38,9 @@ public class AiProxyPluginHandler implements PluginDataHandler {
private final CommonHandleCache selectorCachedHandle = new CommonHandleCache<>();
+ public AiProxyPluginHandler() {
+ }
+
@Override
public void handlerPlugin(final PluginData pluginData) {
// Note: The logic for handling global plugin configuration with Singleton has been removed
@@ -45,6 +50,10 @@ public void handlerPlugin(final PluginData pluginData) {
@Override
public void handlerSelector(final SelectorData selectorData) {
+ // Invalidate the cache first when the selector is updated.
+ OpenAiApiCache.getInstance().remove(selectorData.getId());
+ // Do NOT remove AiProxyApiKeyCache here. Admin will push updated AI_PROXY_API_KEY events
+ // with refreshed realApiKey after selector changes. Removing here introduces a window of misses.
if (Objects.isNull(selectorData.getHandle())) {
return;
}
@@ -56,6 +65,9 @@ public void handlerSelector(final SelectorData selectorData) {
@Override
public void removeSelector(final SelectorData selectorData) {
+ // Invalidate the cache when the selector is removed.
+ OpenAiApiCache.getInstance().remove(selectorData.getId());
+ AiProxyApiKeyCache.getInstance().removeBySelectorId(selectorData.getId());
selectorCachedHandle
.removeHandle(CacheKeyUtils.INST.getKey(selectorData.getId(), Constants.DEFAULT_RULE));
}
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriber.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriber.java
index d25434881658..bde74f48668e 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriber.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriber.java
@@ -19,6 +19,7 @@
import org.apache.shenyu.common.dto.ProxyApiKeyData;
import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.AiProxyApiKeyCache;
+import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.OpenAiApiCache;
import org.apache.shenyu.sync.data.api.AiProxyApiKeyDataSubscriber;
import java.util.Objects;
@@ -47,5 +48,6 @@ public void unSubscribe(final ProxyApiKeyData data) {
@Override
public void refresh() {
AiProxyApiKeyCache.getInstance().refresh();
+ OpenAiApiCache.getInstance().clearAll();
}
}
From 225cf0aa7a16d416d6da2534e2503545087e0d61 Mon Sep 17 00:00:00 2001
From: eye-gu <734164350@qq.com>
Date: Thu, 14 May 2026 18:52:41 +0800
Subject: [PATCH 4/6] fix
---
shenyu-admin/src/main/resources/application.yml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/shenyu-admin/src/main/resources/application.yml b/shenyu-admin/src/main/resources/application.yml
index 47b4ab7cf3e2..6d2c3af3bfe2 100755
--- a/shenyu-admin/src/main/resources/application.yml
+++ b/shenyu-admin/src/main/resources/application.yml
@@ -21,7 +21,7 @@ spring:
application:
name: shenyu-admin
profiles:
- active: mysql
+ active: h2
thymeleaf:
cache: true
encoding: utf-8
From 37de2d9e0f57844353713886530be90e0683b122 Mon Sep 17 00:00:00 2001
From: eye-gu <734164350@qq.com>
Date: Fri, 15 May 2026 13:25:56 +0800
Subject: [PATCH 5/6] fix review
---
.../protocol/OpenAiProtocolAdapter.java | 41 +++++++--
.../protocol/OpenAiProtocolAdapterTest.java | 21 ++++-
.../ai/proxy/enhanced/AiProxyPlugin.java | 35 +++++---
.../proxy/enhanced/cache/OpenAiApiCache.java | 8 +-
.../handler/AiProxyPluginHandler.java | 2 +-
.../service/AiProxyExecutorService.java | 85 +++++++++++++++----
.../enhanced/service/UpstreamErrorLogger.java | 17 +++-
.../ai/proxy/enhanced/AiProxyPluginTest.java | 18 ++--
.../handler/AiProxyPluginHandlerTest.java | 2 +
.../service/AiProxyExecutorServiceTest.java | 32 ++++---
10 files changed, 201 insertions(+), 60 deletions(-)
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapter.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapter.java
index ce44a5d3178b..a23eb9cd2ad2 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapter.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapter.java
@@ -17,6 +17,7 @@
package org.apache.shenyu.plugin.ai.common.protocol;
+import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import org.apache.shenyu.common.utils.JsonUtils;
@@ -27,9 +28,18 @@
/**
* Adapts between OpenAI Chat Completions wire format and internal representations.
+ *
+ * Note: deserialization into Spring AI's {@link ChatCompletionRequest} preserves fields
+ * modeled by that class (messages, model, temperature, maxTokens, stream, tools, etc.).
+ * Provider-specific extension fields not modeled by {@code ChatCompletionRequest} are dropped
+ * during deserialization. For the OpenAI-compatible providers this plugin currently supports,
+ * all required fields are covered.
*/
public final class OpenAiProtocolAdapter {
+ private static final com.fasterxml.jackson.databind.ObjectMapper MAPPER =
+ new com.fasterxml.jackson.databind.ObjectMapper();
+
private OpenAiProtocolAdapter() {
}
@@ -45,7 +55,7 @@ public static boolean resolveStream(final String requestBody, final Boolean fall
if (Objects.isNull(requestBody) || requestBody.isEmpty()) {
return Boolean.TRUE.equals(fallbackStream);
}
- final JsonNode root = JsonUtils.toJsonNode(requestBody);
+ final JsonNode root = parseStrict(requestBody);
if (Objects.isNull(root)) {
return Boolean.TRUE.equals(fallbackStream);
}
@@ -56,8 +66,8 @@ public static boolean resolveStream(final String requestBody, final Boolean fall
}
/**
- * Parse raw request body directly into ChatCompletionRequest, preserving ALL fields
- * including reasoning_content in assistant messages.
+ * Parse raw request body directly into ChatCompletionRequest, preserving fields
+ * modeled by Spring AI (including reasoning_content in assistant messages).
*
*
Spring AI's createRequest() loses reasoning_content, refusal, and annotations
* when reconstructing ChatCompletionMessage from AssistantMessage.
@@ -74,7 +84,7 @@ public static ChatCompletionRequest toChatCompletionRequest(final String request
}
/**
- * Parse raw request body directly into ChatCompletionRequest, preserving ALL fields.
+ * Parse raw request body directly into ChatCompletionRequest, preserving modeled fields.
* For model, temperature, max_tokens: client request takes priority;
* if missing, falls back to the corresponding field in fallbackConfig.
*
@@ -88,14 +98,19 @@ public static ChatCompletionRequest toChatCompletionRequest(final String request
if (Objects.isNull(requestBody) || requestBody.isEmpty()) {
throw new IllegalArgumentException("Request body must not be empty");
}
- final JsonNode root = JsonUtils.toJsonNode(requestBody);
+ final JsonNode root = parseStrict(requestBody);
if (Objects.isNull(root) || !root.isObject()) {
throw new IllegalArgumentException("Invalid request body: expected a JSON object");
}
final ObjectNode mutableRoot = (ObjectNode) root;
if (root.hasNonNull("max_completion_tokens") && !root.hasNonNull("max_tokens")) {
- mutableRoot.put("max_tokens", root.get("max_completion_tokens").asInt());
+ final JsonNode tokenNode = root.get("max_completion_tokens");
+ if (!tokenNode.isNumber()) {
+ throw new IllegalArgumentException(
+ "max_completion_tokens must be a number, got: " + tokenNode.getNodeType());
+ }
+ mutableRoot.put("max_tokens", tokenNode.asInt());
mutableRoot.remove("max_completion_tokens");
}
@@ -113,6 +128,18 @@ public static ChatCompletionRequest toChatCompletionRequest(final String request
mutableRoot.put("stream", stream);
- return JsonUtils.jsonToObject(mutableRoot.toString(), ChatCompletionRequest.class);
+ final ChatCompletionRequest result = JsonUtils.jsonToObject(mutableRoot.toString(), ChatCompletionRequest.class);
+ if (Objects.isNull(result)) {
+ throw new IllegalArgumentException("Failed to parse request body into ChatCompletionRequest");
+ }
+ return result;
+ }
+
+ private static JsonNode parseStrict(final String json) {
+ try {
+ return MAPPER.readTree(json);
+ } catch (JsonProcessingException e) {
+ return null;
+ }
}
}
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/test/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapterTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/test/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapterTest.java
index 566824318014..adfbf4d4bf1f 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/test/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapterTest.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/test/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapterTest.java
@@ -157,7 +157,7 @@ void testFallbackTemperatureWhenClientMissing() {
}
@Test
- void testNoFallbackTemperatureWhenConfigNull() {
+ void testNoFallbackWhenConfigTemperatureIsNull() {
final AiCommonConfig config = new AiCommonConfig();
config.setTemperature(null);
@@ -235,4 +235,23 @@ void testPartialClientFieldsWithPartialFallback() {
assertEquals(0.1, req.temperature(), 0.001);
assertEquals(2048, req.maxTokens());
}
+
+ @Test
+ void testMalformedJsonThrows() {
+ assertThrows(IllegalArgumentException.class,
+ () -> OpenAiProtocolAdapter.toChatCompletionRequest("not valid json{{{", false));
+ }
+
+ @Test
+ void testMaxCompletionTokensNonNumberThrows() {
+ final String body = "{\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"max_completion_tokens\":\"abc\"}";
+ assertThrows(IllegalArgumentException.class,
+ () -> OpenAiProtocolAdapter.toChatCompletionRequest(body, false));
+ }
+
+ @Test
+ void testResolveStreamMalformedJsonFallsBack() {
+ assertFalse(OpenAiProtocolAdapter.resolveStream("not json{{{", false));
+ assertTrue(OpenAiProtocolAdapter.resolveStream("not json{{{", true));
+ }
}
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java
index b1650f37f3f3..517ff280bc4d 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java
@@ -30,6 +30,7 @@
import org.apache.shenyu.plugin.ai.proxy.enhanced.handler.AiProxyPluginHandler;
import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyConfigService;
import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyExecutorService;
+import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyExecutorService.FallbackContext;
import org.apache.shenyu.plugin.ai.proxy.enhanced.service.UpstreamErrorLogger;
import org.apache.shenyu.plugin.api.ShenyuPluginChain;
import org.apache.shenyu.plugin.api.utils.WebFluxResultUtils;
@@ -140,7 +141,8 @@ protected Mono doExecute(
}
}
- if (OpenAiProtocolAdapter.resolveStream(requestBody, primaryConfig.getStream())) {
+ final boolean stream = OpenAiProtocolAdapter.resolveStream(requestBody, primaryConfig.getStream());
+ if (stream) {
return handleStreamRequest(exchange, selector, requestBody, primaryConfig, selectorHandle);
}
return handleNonStreamRequest(exchange, selector, requestBody, primaryConfig, selectorHandle);
@@ -155,13 +157,13 @@ private Mono handleStreamRequest(
final AiProxyHandle selectorHandle) {
final OpenAiApi mainApi = getCachedOpenAiApi(selector.getId(), "main", primaryConfig);
final ChatCompletionRequest request = OpenAiProtocolAdapter.toChatCompletionRequest(requestBody, true, primaryConfig);
- final Optional fallbackApi = resolveFallbackOpenAiApi(selector.getId(), primaryConfig, selectorHandle,
- requestBody);
+ final Optional fallbackCtx = resolveFallbackContext(
+ selector.getId(), primaryConfig, selectorHandle, requestBody);
final ServerHttpResponse response = exchange.getResponse();
response.getHeaders().setContentType(MediaType.TEXT_EVENT_STREAM);
final Flux chunkFlux = aiProxyExecutorService.executeDirectStream(
- mainApi, fallbackApi, request);
+ mainApi, fallbackCtx, request, requestBody, true);
final Flux sseFlux = chunkFlux.map(
chunk -> {
@@ -185,11 +187,11 @@ private Mono handleNonStreamRequest(
final AiProxyHandle selectorHandle) {
final OpenAiApi mainApi = getCachedOpenAiApi(selector.getId(), "main", primaryConfig);
final ChatCompletionRequest request = OpenAiProtocolAdapter.toChatCompletionRequest(requestBody, false, primaryConfig);
- final Optional fallbackApi = resolveFallbackOpenAiApi(selector.getId(), primaryConfig, selectorHandle,
- requestBody);
+ final Optional fallbackCtx = resolveFallbackContext(
+ selector.getId(), primaryConfig, selectorHandle, requestBody);
return aiProxyExecutorService
- .executeDirectCall(mainApi, fallbackApi, request)
+ .executeDirectCall(mainApi, fallbackCtx, request, requestBody)
.flatMap(
responseEntity -> {
final String responseJson = JsonUtils.toJson(responseEntity.getBody());
@@ -199,7 +201,7 @@ private Mono handleNonStreamRequest(
.doOnError(e -> logUpstreamError(e, "non-stream"));
}
- private Optional resolveFallbackOpenAiApi(
+ private Optional resolveFallbackContext(
final String selectorId,
final AiCommonConfig primaryConfig,
final AiProxyHandle selectorHandle,
@@ -211,7 +213,7 @@ private Optional resolveFallbackOpenAiApi(
if (LOG.isDebugEnabled()) {
LOG.debug("[AiProxy] dynamic fallback config: {}", cfg);
}
- return getCachedOpenAiApi(selectorId, "dynamicFallback", cfg);
+ return new FallbackContext(createOpenAiApi(cfg), cfg);
})
.or(() -> aiProxyConfigService
.resolveAdminFallbackConfig(primaryConfig, selectorHandle)
@@ -220,15 +222,26 @@ private Optional resolveFallbackOpenAiApi(
if (LOG.isDebugEnabled()) {
LOG.debug("[AiProxy] admin fallback config: {}", adminFallbackConfig);
}
- return getCachedOpenAiApi(selectorId, "adminFallback", adminFallbackConfig);
+ return new FallbackContext(
+ getCachedOpenAiApi(selectorId, "adminFallback", adminFallbackConfig),
+ adminFallbackConfig);
}));
}
private OpenAiApi getCachedOpenAiApi(final String selectorId, final String type, final AiCommonConfig config) {
- final String cacheKey = selectorId + "|" + type + "_" + config.hashCode();
+ final String cacheKey = selectorId + "|" + type + "_" + generateConfigCacheKey(config);
return OpenAiApiCache.getInstance().computeIfAbsent(cacheKey, () -> createOpenAiApi(config));
}
+ private int generateConfigCacheKey(final AiCommonConfig config) {
+ return Objects.hash(
+ config.getBaseUrl(),
+ config.getModel(),
+ config.getTemperature(),
+ config.getMaxTokens()
+ );
+ }
+
private OpenAiApi createOpenAiApi(final AiCommonConfig config) {
if (Objects.isNull(config.getBaseUrl()) || config.getBaseUrl().isEmpty()) {
throw new IllegalArgumentException("baseUrl must not be empty");
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCache.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCache.java
index 886b746ad94b..fdbe22f6e429 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCache.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCache.java
@@ -97,8 +97,10 @@ public OpenAiApi computeIfAbsent(final String key, final Supplier ope
}
/**
- * Evict oldest entries when cache size exceeds limit.
- * Removes approximately 25% of entries to avoid thundering herd problem.
+ * Evict arbitrary entries when cache size exceeds limit.
+ * Note: ConcurrentHashMap iteration order is undefined, so evicted entries are arbitrary,
+ * not guaranteed to be the oldest. Removes approximately 25% of entries to avoid
+ * thundering herd problem.
*/
private void evictOldestEntries() {
final int currentSize = openAiApiMap.size();
@@ -107,7 +109,7 @@ private void evictOldestEntries() {
}
final int evictCount = Math.max(10, currentSize / 4);
- LOG.warn("[OpenAiApiCache] Cache size {} exceeded limit {}, evicting {} oldest entries",
+ LOG.warn("[OpenAiApiCache] Cache size {} exceeded limit {}, evicting {} arbitrary entries",
currentSize, MAX_CACHE_SIZE, evictCount);
int removed = 0;
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java
index df9dc1a8cb44..cc8a42bbd4b8 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java
@@ -54,7 +54,7 @@ public void handlerSelector(final SelectorData selectorData) {
OpenAiApiCache.getInstance().remove(selectorData.getId());
// Do NOT remove AiProxyApiKeyCache here. Admin will push updated AI_PROXY_API_KEY events
// with refreshed realApiKey after selector changes. Removing here introduces a window of misses.
- if (Objects.isNull(selectorData.getHandle())) {
+ if (Objects.isNull(selectorData.getHandle()) || selectorData.getHandle().isEmpty()) {
return;
}
AiProxyHandle aiProxyHandle = GsonUtils.getInstance().fromJson(selectorData.getHandle(), AiProxyHandle.class);
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java
index 8502a772f43e..8fbd37be2186 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java
@@ -17,6 +17,8 @@
package org.apache.shenyu.plugin.ai.proxy.enhanced.service;
+import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig;
+import org.apache.shenyu.plugin.ai.common.protocol.OpenAiProtocolAdapter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.openai.api.OpenAiApi;
@@ -25,12 +27,14 @@
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
import org.springframework.ai.retry.NonTransientAiException;
import org.springframework.http.ResponseEntity;
+import org.springframework.web.reactive.function.client.WebClientResponseException;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
import reactor.util.retry.Retry;
import java.time.Duration;
+import java.util.Objects;
import java.util.Optional;
/**
@@ -44,16 +48,20 @@ public class AiProxyExecutorService {
* Execute a streaming AI call directly via {@link OpenAiApi}, bypassing Spring AI's
* {@code createRequest()} which loses fields like {@code reasoning_content}.
*
- * @param mainApi the main OpenAiApi
- * @param fallbackApiOpt the optional fallback OpenAiApi
- * @param request the ChatCompletionRequest with all fields preserved
+ * @param mainApi the main OpenAiApi
+ * @param fallbackCtxOpt the optional fallback context (api + config)
+ * @param request the ChatCompletionRequest with all fields preserved
+ * @param requestBody the original request body for rebuilding fallback request
+ * @param stream whether this is a streaming request
* @return a Flux of ChatCompletionChunk
*/
public Flux executeDirectStream(final OpenAiApi mainApi,
- final Optional fallbackApiOpt, final ChatCompletionRequest request) {
+ final Optional fallbackCtxOpt, final ChatCompletionRequest request,
+ final String requestBody, final boolean stream) {
return mainApi.chatCompletionStream(request)
.doOnError(e -> UpstreamErrorLogger.logUpstreamError(LOG, e, "direct stream"))
.retryWhen(Retry.max(1)
+ .filter(AiProxyExecutorService::isRetryable)
.onRetryExhaustedThrow((retryBackoffSpec, retrySignal) -> {
LOG.warn("Direct stream retry exhausted. Triggering fallback.",
retrySignal.failure());
@@ -62,24 +70,26 @@ public Flux executeDirectStream(final OpenAiApi mainApi,
retrySignal.failure());
}))
.onErrorResume(NonTransientAiException.class,
- throwable -> handleDirectFallbackStream(throwable, fallbackApiOpt, request));
+ throwable -> handleDirectFallbackStream(throwable, fallbackCtxOpt, requestBody, stream));
}
/**
* Execute a non-streaming AI call directly via {@link OpenAiApi}.
*
- * @param mainApi the main OpenAiApi
- * @param fallbackApiOpt the optional fallback OpenAiApi
- * @param request the ChatCompletionRequest with all fields preserved
+ * @param mainApi the main OpenAiApi
+ * @param fallbackCtxOpt the optional fallback context (api + config)
+ * @param request the ChatCompletionRequest with all fields preserved
+ * @param requestBody the original request body for rebuilding fallback request
* @return a Mono of ResponseEntity containing ChatCompletion
*/
public Mono> executeDirectCall(final OpenAiApi mainApi,
- final Optional fallbackApiOpt, final ChatCompletionRequest request) {
+ final Optional fallbackCtxOpt, final ChatCompletionRequest request,
+ final String requestBody) {
return Mono.fromCallable(() -> mainApi.chatCompletionEntity(request))
.subscribeOn(Schedulers.boundedElastic())
.doOnError(e -> UpstreamErrorLogger.logUpstreamError(LOG, e, "direct call"))
.retryWhen(Retry.backoff(3, Duration.ofSeconds(1))
- .filter(throwable -> !(throwable instanceof NonTransientAiException))
+ .filter(AiProxyExecutorService::isRetryable)
.onRetryExhaustedThrow((retryBackoffSpec, retrySignal) -> {
LOG.warn("Direct call retries exhausted after {} attempts. Triggering fallback.",
retrySignal.totalRetries(), retrySignal.failure());
@@ -87,31 +97,72 @@ public Mono> executeDirectCall(final OpenAiApi ma
retrySignal.failure());
}))
.onErrorResume(NonTransientAiException.class,
- throwable -> handleDirectFallbackCall(throwable, fallbackApiOpt, request));
+ throwable -> handleDirectFallbackCall(throwable, fallbackCtxOpt, requestBody));
}
private Flux handleDirectFallbackStream(final Throwable throwable,
- final Optional fallbackApiOpt, final ChatCompletionRequest request) {
+ final Optional fallbackCtxOpt, final String requestBody, final boolean stream) {
LOG.warn("Main direct stream failed, attempting fallback...", throwable);
- if (fallbackApiOpt.isEmpty()) {
+ if (fallbackCtxOpt.isEmpty()) {
return Flux.error(throwable);
}
+ final FallbackContext ctx = fallbackCtxOpt.get();
LOG.info("Using fallback OpenAiApi for direct stream");
- return fallbackApiOpt.get().chatCompletionStream(request);
+ final ChatCompletionRequest fallbackRequest = OpenAiProtocolAdapter.toChatCompletionRequest(
+ requestBody, stream, ctx.config());
+ return ctx.api().chatCompletionStream(fallbackRequest);
}
private Mono> handleDirectFallbackCall(final Throwable throwable,
- final Optional fallbackApiOpt, final ChatCompletionRequest request) {
+ final Optional fallbackCtxOpt, final String requestBody) {
LOG.warn("Main direct call failed, attempting fallback...", throwable);
- if (fallbackApiOpt.isEmpty()) {
+ if (fallbackCtxOpt.isEmpty()) {
return Mono.error(throwable);
}
+ final FallbackContext ctx = fallbackCtxOpt.get();
LOG.info("Using fallback OpenAiApi for direct call");
- return Mono.fromCallable(() -> fallbackApiOpt.get().chatCompletionEntity(request))
+ final ChatCompletionRequest fallbackRequest = OpenAiProtocolAdapter.toChatCompletionRequest(
+ requestBody, false, ctx.config());
+ return Mono.fromCallable(() -> ctx.api().chatCompletionEntity(fallbackRequest))
.subscribeOn(Schedulers.boundedElastic());
}
+
+ /**
+ * Determine if the error is retryable.
+ * Retries transient network errors and retryable HTTP statuses (429, 5xx).
+ * Non-retryable: NonTransientAiException, client errors (400/401/403/404).
+ */
+ private static boolean isRetryable(final Throwable throwable) {
+ if (throwable instanceof NonTransientAiException) {
+ return false;
+ }
+ final WebClientResponseException webClientEx = findWebClientResponseException(throwable);
+ if (Objects.nonNull(webClientEx)) {
+ final int status = webClientEx.getStatusCode().value();
+ return status == 429 || status >= 500;
+ }
+ return true;
+ }
+
+ private static WebClientResponseException findWebClientResponseException(final Throwable e) {
+ Throwable current = e;
+ while (Objects.nonNull(current)) {
+ if (current instanceof WebClientResponseException ex) {
+ return ex;
+ }
+ current = current.getCause();
+ }
+ return null;
+ }
+
+ /**
+ * Container for fallback context: the OpenAiApi instance and the fallback config
+ * used to rebuild the request with fallback-specific model/temperature/maxTokens.
+ */
+ public record FallbackContext(OpenAiApi api, AiCommonConfig config) {
+ }
}
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLogger.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLogger.java
index bee49aab7d0c..bae715f35026 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLogger.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLogger.java
@@ -27,19 +27,34 @@
*/
public final class UpstreamErrorLogger {
+ private static final int MAX_BODY_LOG_LENGTH = 512;
+
private UpstreamErrorLogger() {
}
public static void logUpstreamError(final Logger log, final Throwable e, final String mode) {
+ if (Objects.isNull(e)) {
+ return;
+ }
final WebClientResponseException webClientEx = findWebClientResponseException(e);
if (Objects.nonNull(webClientEx)) {
log.error("[AiProxy] {} failed, status={}, upstreamBody={}",
- mode, webClientEx.getStatusCode(), webClientEx.getResponseBodyAsString(), e);
+ mode, webClientEx.getStatusCode(), truncateBody(webClientEx.getResponseBodyAsString()), e);
} else {
log.error("[AiProxy] {} failed", mode, e);
}
}
+ private static String truncateBody(final String body) {
+ if (Objects.isNull(body)) {
+ return "null";
+ }
+ if (body.length() <= MAX_BODY_LOG_LENGTH) {
+ return body;
+ }
+ return body.substring(0, MAX_BODY_LOG_LENGTH) + "...(truncated, total " + body.length() + " chars)";
+ }
+
private static WebClientResponseException findWebClientResponseException(final Throwable e) {
Throwable current = e;
while (Objects.nonNull(current)) {
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java
index dfbf95a1b0c9..b3bb711af1d4 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java
@@ -131,8 +131,8 @@ private void setupSuccessMocks(final AiProxyHandle handle, final AiCommonConfig
when(configService.resolvePrimaryConfig(handle)).thenReturn(primaryConfig);
when(configService.resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY)).thenReturn(fallbackConfig);
when(configService.resolveAdminFallbackConfig(primaryConfig, handle)).thenReturn(fallbackConfig);
- when(executorService.executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class))).thenReturn(Mono.just(responseEntity));
- when(executorService.executeDirectStream(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class))).thenReturn(Flux.empty());
+ when(executorService.executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class))).thenReturn(Mono.just(responseEntity));
+ when(executorService.executeDirectStream(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class), any(Boolean.class))).thenReturn(Flux.empty());
}
@Test
@@ -151,7 +151,7 @@ public void testExecuteSuccess() {
verify(configService).resolvePrimaryConfig(handle);
verify(configService).resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY);
verify(configService).resolveAdminFallbackConfig(primaryConfig, handle);
- verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class));
+ verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class));
}
@Test
@@ -172,7 +172,7 @@ public void testExecuteWithDynamicFallback() {
StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule))
.verifyComplete();
- verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class));
+ verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class));
}
@Test
@@ -251,7 +251,7 @@ public void testExecuteWithAdminFallback() {
StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule))
.verifyComplete();
- verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class));
+ verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class));
}
@Test
@@ -275,7 +275,7 @@ public void testCacheIsUsedForAdminFallbackClient() {
when(configService.resolvePrimaryConfig(handle)).thenReturn(primaryConfig);
when(configService.resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY)).thenReturn(Optional.empty());
when(configService.resolveAdminFallbackConfig(primaryConfig, handle)).thenReturn(Optional.of(fallbackConfig));
- when(executorService.executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class))).thenReturn(Mono.just(responseEntity));
+ when(executorService.executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class))).thenReturn(Mono.just(responseEntity));
// Execute the test - focus on successful execution rather than cache verification
StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule))
@@ -284,7 +284,7 @@ public void testCacheIsUsedForAdminFallbackClient() {
// Verify that the configuration methods were called correctly
verify(configService).resolvePrimaryConfig(handle);
verify(configService).resolveAdminFallbackConfig(primaryConfig, handle);
- verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class));
+ verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class));
}
@Test
@@ -315,7 +315,7 @@ public void testExecutorServiceError() {
final RuntimeException exception = new RuntimeException("AI execution failed");
setupSuccessMocks(handle, primaryConfig, Optional.empty());
- when(executorService.executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class))).thenReturn(Mono.error(exception));
+ when(executorService.executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class))).thenReturn(Mono.error(exception));
StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule))
.expectErrorMatches(exception::equals)
@@ -324,6 +324,6 @@ public void testExecutorServiceError() {
verify(configService).resolvePrimaryConfig(handle);
verify(configService).resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY);
verify(configService).resolveAdminFallbackConfig(primaryConfig, handle);
- verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class));
+ verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class));
}
}
\ No newline at end of file
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandlerTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandlerTest.java
index 248ae16ac883..5aa2a7a5edf2 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandlerTest.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandlerTest.java
@@ -78,6 +78,8 @@ void testHandlerSelectorWithEmptyHandle() {
selector.setId("sel-3");
selector.setHandle("");
+ handler.handlerSelector(selector);
+
String key = CacheKeyUtils.INST.getKey("sel-3", Constants.DEFAULT_RULE);
assertNull(handler.getSelectorCachedHandle().obtainHandle(key));
}
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorServiceTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorServiceTest.java
index 46d3719cab2a..38597d0dbedf 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorServiceTest.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorServiceTest.java
@@ -17,6 +17,7 @@
package org.apache.shenyu.plugin.ai.proxy.enhanced.service;
+import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
@@ -32,6 +33,7 @@
import java.util.Optional;
+import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@@ -40,6 +42,8 @@
@ExtendWith(MockitoExtension.class)
public class AiProxyExecutorServiceTest {
+ private static final String REQUEST_BODY = "{\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}]}";
+
private AiProxyExecutorService executorService;
@BeforeEach
@@ -54,7 +58,7 @@ void testExecuteDirectStreamSuccess() {
final ChatCompletionRequest request = mock(ChatCompletionRequest.class);
when(mainApi.chatCompletionStream(request)).thenReturn(Flux.just(chunk));
- StepVerifier.create(executorService.executeDirectStream(mainApi, Optional.empty(), request))
+ StepVerifier.create(executorService.executeDirectStream(mainApi, Optional.empty(), request, REQUEST_BODY, true))
.expectNext(chunk)
.verifyComplete();
@@ -67,7 +71,7 @@ void testExecuteDirectStreamErrorWithNoFallback() {
final ChatCompletionRequest request = mock(ChatCompletionRequest.class);
when(mainApi.chatCompletionStream(request)).thenAnswer(inv -> Flux.error(new RuntimeException("upstream error")));
- StepVerifier.create(executorService.executeDirectStream(mainApi, Optional.empty(), request))
+ StepVerifier.create(executorService.executeDirectStream(mainApi, Optional.empty(), request, REQUEST_BODY, true))
.expectError(NonTransientAiException.class)
.verify();
}
@@ -80,13 +84,17 @@ void testExecuteDirectStreamFallbackSuccess() {
final ChatCompletionChunk fallbackChunk = mock(ChatCompletionChunk.class);
when(mainApi.chatCompletionStream(request)).thenAnswer(inv -> Flux.error(new RuntimeException("upstream error")));
- when(fallbackApi.chatCompletionStream(request)).thenReturn(Flux.just(fallbackChunk));
+ when(fallbackApi.chatCompletionStream(any(ChatCompletionRequest.class))).thenReturn(Flux.just(fallbackChunk));
+
+ final AiCommonConfig fallbackConfig = new AiCommonConfig();
+ fallbackConfig.setModel("fallback-model");
+ final AiProxyExecutorService.FallbackContext ctx = new AiProxyExecutorService.FallbackContext(fallbackApi, fallbackConfig);
- StepVerifier.create(executorService.executeDirectStream(mainApi, Optional.of(fallbackApi), request))
+ StepVerifier.create(executorService.executeDirectStream(mainApi, Optional.of(ctx), request, REQUEST_BODY, true))
.expectNext(fallbackChunk)
.verifyComplete();
- verify(fallbackApi, times(1)).chatCompletionStream(request);
+ verify(fallbackApi, times(1)).chatCompletionStream(any(ChatCompletionRequest.class));
}
@Test
@@ -97,7 +105,7 @@ void testExecuteDirectCallSuccess() {
final ResponseEntity responseEntity = ResponseEntity.ok(completion);
when(mainApi.chatCompletionEntity(request)).thenReturn(responseEntity);
- StepVerifier.create(executorService.executeDirectCall(mainApi, Optional.empty(), request))
+ StepVerifier.create(executorService.executeDirectCall(mainApi, Optional.empty(), request, REQUEST_BODY))
.expectNext(responseEntity)
.verifyComplete();
@@ -110,7 +118,7 @@ void testExecuteDirectCallErrorWithNoFallback() {
final ChatCompletionRequest request = mock(ChatCompletionRequest.class);
when(mainApi.chatCompletionEntity(request)).thenThrow(new RuntimeException("upstream error"));
- StepVerifier.create(executorService.executeDirectCall(mainApi, Optional.empty(), request))
+ StepVerifier.create(executorService.executeDirectCall(mainApi, Optional.empty(), request, REQUEST_BODY))
.expectError(NonTransientAiException.class)
.verify();
}
@@ -124,12 +132,16 @@ void testExecuteDirectCallFallbackSuccess() {
final ResponseEntity fallbackResponse = ResponseEntity.ok(fallbackCompletion);
when(mainApi.chatCompletionEntity(request)).thenThrow(new RuntimeException("upstream error"));
- when(fallbackApi.chatCompletionEntity(request)).thenReturn(fallbackResponse);
+ when(fallbackApi.chatCompletionEntity(any(ChatCompletionRequest.class))).thenReturn(fallbackResponse);
+
+ final AiCommonConfig fallbackConfig = new AiCommonConfig();
+ fallbackConfig.setModel("fallback-model");
+ final AiProxyExecutorService.FallbackContext ctx = new AiProxyExecutorService.FallbackContext(fallbackApi, fallbackConfig);
- StepVerifier.create(executorService.executeDirectCall(mainApi, Optional.of(fallbackApi), request))
+ StepVerifier.create(executorService.executeDirectCall(mainApi, Optional.of(ctx), request, REQUEST_BODY))
.expectNext(fallbackResponse)
.verifyComplete();
- verify(fallbackApi, times(1)).chatCompletionEntity(request);
+ verify(fallbackApi, times(1)).chatCompletionEntity(any(ChatCompletionRequest.class));
}
}
From c6a6dd6cfdd0968b2004127c05b4584dcaff2bfe Mon Sep 17 00:00:00 2001
From: eye-gu <734164350@qq.com>
Date: Sun, 17 May 2026 19:00:25 +0800
Subject: [PATCH 6/6] fix review
---
.../protocol/OpenAiProtocolAdapter.java | 76 +++++++++-------
.../protocol/OpenAiProtocolAdapterTest.java | 67 +++++++++-----
.../ai/proxy/enhanced/AiProxyPlugin.java | 1 +
.../proxy/enhanced/cache/OpenAiApiCache.java | 13 ++-
.../handler/AiProxyPluginHandler.java | 3 +
.../service/AiProxyExecutorService.java | 19 +---
.../enhanced/service/UpstreamErrorLogger.java | 8 +-
.../ai/proxy/enhanced/AiProxyPluginTest.java | 4 +-
.../enhanced/cache/OpenAiApiCacheTest.java | 91 +++++++++++++++++++
.../service/UpstreamErrorLoggerTest.java | 15 +++
10 files changed, 222 insertions(+), 75 deletions(-)
create mode 100644 shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCacheTest.java
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapter.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapter.java
index a23eb9cd2ad2..73244c838f82 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapter.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapter.java
@@ -20,8 +20,10 @@
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
-import org.apache.shenyu.common.utils.JsonUtils;
+import org.apache.shenyu.common.exception.ShenyuException;
import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
import java.util.Objects;
@@ -37,9 +39,24 @@
*/
public final class OpenAiProtocolAdapter {
+ private static final Logger LOG = LoggerFactory.getLogger(OpenAiProtocolAdapter.class);
+
private static final com.fasterxml.jackson.databind.ObjectMapper MAPPER =
new com.fasterxml.jackson.databind.ObjectMapper();
+ /**
+ * OpenAI API field name constants.
+ */
+ private static final String FIELD_MODEL = "model";
+
+ private static final String FIELD_TEMPERATURE = "temperature";
+
+ private static final String FIELD_STREAM = "stream";
+
+ private static final String FIELD_MAX_COMPLETION_TOKENS = "max_completion_tokens";
+
+ private static final String FIELD_MAX_TOKENS = "max_tokens";
+
private OpenAiProtocolAdapter() {
}
@@ -59,8 +76,8 @@ public static boolean resolveStream(final String requestBody, final Boolean fall
if (Objects.isNull(root)) {
return Boolean.TRUE.equals(fallbackStream);
}
- if (root.hasNonNull("stream")) {
- return root.get("stream").asBoolean();
+ if (root.hasNonNull(FIELD_STREAM)) {
+ return root.get(FIELD_STREAM).asBoolean();
}
return Boolean.TRUE.equals(fallbackStream);
}
@@ -73,11 +90,13 @@ public static boolean resolveStream(final String requestBody, final Boolean fall
* when reconstructing ChatCompletionMessage from AssistantMessage.
* This method avoids that loss by deserializing the raw JSON directly.
*
- * Also converts max_completion_tokens to max_tokens for broader API compatibility.
+ *
For max_tokens and max_completion_tokens: client values are preserved as-is;
+ * only when the client omits both fields does the fallbackConfig value populate both,
+ * ensuring compatibility with providers that require either field.
*
* @param requestBody the raw JSON request body in OpenAI Chat Completions format
* @param stream whether this is a streaming request (sets the stream field)
- * @return a ChatCompletionRequest with all fields preserved from the original request
+ * @return a ChatCompletionRequest with modeled fields preserved from the original request
*/
public static ChatCompletionRequest toChatCompletionRequest(final String requestBody, final boolean stream) {
return toChatCompletionRequest(requestBody, stream, null);
@@ -85,54 +104,47 @@ public static ChatCompletionRequest toChatCompletionRequest(final String request
/**
* Parse raw request body directly into ChatCompletionRequest, preserving modeled fields.
- * For model, temperature, max_tokens: client request takes priority;
- * if missing, falls back to the corresponding field in fallbackConfig.
+ * When fallbackConfig is provided, its non-null fields (model, temperature, maxTokens)
+ * override the client request values, matching the original ChatModel-based fallback behavior
+ * where the fallback ChatModel's config takes precedence.
*
* @param requestBody the raw JSON request body in OpenAI Chat Completions format
* @param stream whether this is a streaming request
- * @param fallbackConfig the admin config used as fallback when client omits fields
- * @return a ChatCompletionRequest with all fields preserved
+ * @param fallbackConfig the fallback config whose non-null fields override client values
+ * @return a ChatCompletionRequest with modeled fields preserved
*/
public static ChatCompletionRequest toChatCompletionRequest(final String requestBody,
final boolean stream, final AiCommonConfig fallbackConfig) {
if (Objects.isNull(requestBody) || requestBody.isEmpty()) {
- throw new IllegalArgumentException("Request body must not be empty");
+ throw new ShenyuException("Request body must not be empty");
}
final JsonNode root = parseStrict(requestBody);
if (Objects.isNull(root) || !root.isObject()) {
- throw new IllegalArgumentException("Invalid request body: expected a JSON object");
+ throw new ShenyuException("Invalid request body: expected a JSON object");
}
final ObjectNode mutableRoot = (ObjectNode) root;
- if (root.hasNonNull("max_completion_tokens") && !root.hasNonNull("max_tokens")) {
- final JsonNode tokenNode = root.get("max_completion_tokens");
- if (!tokenNode.isNumber()) {
- throw new IllegalArgumentException(
- "max_completion_tokens must be a number, got: " + tokenNode.getNodeType());
- }
- mutableRoot.put("max_tokens", tokenNode.asInt());
- mutableRoot.remove("max_completion_tokens");
- }
-
if (Objects.nonNull(fallbackConfig)) {
- if (!root.hasNonNull("model") && Objects.nonNull(fallbackConfig.getModel()) && !fallbackConfig.getModel().isEmpty()) {
- mutableRoot.put("model", fallbackConfig.getModel());
+ if (Objects.nonNull(fallbackConfig.getModel()) && !fallbackConfig.getModel().isEmpty()) {
+ mutableRoot.put(FIELD_MODEL, fallbackConfig.getModel());
}
- if (!root.hasNonNull("temperature") && Objects.nonNull(fallbackConfig.getTemperature())) {
- mutableRoot.put("temperature", fallbackConfig.getTemperature());
+ if (Objects.nonNull(fallbackConfig.getTemperature())) {
+ mutableRoot.put(FIELD_TEMPERATURE, fallbackConfig.getTemperature());
}
- if (!root.hasNonNull("max_tokens") && Objects.nonNull(fallbackConfig.getMaxTokens())) {
- mutableRoot.put("max_tokens", fallbackConfig.getMaxTokens());
+ if (Objects.nonNull(fallbackConfig.getMaxTokens())) {
+ mutableRoot.put(FIELD_MAX_TOKENS, fallbackConfig.getMaxTokens());
+ mutableRoot.put(FIELD_MAX_COMPLETION_TOKENS, fallbackConfig.getMaxTokens());
}
}
- mutableRoot.put("stream", stream);
+ mutableRoot.put(FIELD_STREAM, stream);
- final ChatCompletionRequest result = JsonUtils.jsonToObject(mutableRoot.toString(), ChatCompletionRequest.class);
- if (Objects.isNull(result)) {
- throw new IllegalArgumentException("Failed to parse request body into ChatCompletionRequest");
+ try {
+ return MAPPER.treeToValue(mutableRoot, ChatCompletionRequest.class);
+ } catch (Exception e) {
+ LOG.error("[AiProxy] Failed to parse request body into ChatCompletionRequest", e);
+ throw new ShenyuException("Failed to parse request body into ChatCompletionRequest", e);
}
- return result;
}
private static JsonNode parseStrict(final String json) {
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/test/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapterTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/test/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapterTest.java
index adfbf4d4bf1f..e843d95f1d0b 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/test/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapterTest.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/test/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapterTest.java
@@ -17,6 +17,7 @@
package org.apache.shenyu.plugin.ai.common.protocol;
+import org.apache.shenyu.common.exception.ShenyuException;
import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig;
import org.junit.jupiter.api.Test;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
@@ -33,13 +34,13 @@ public final class OpenAiProtocolAdapterTest {
@Test
void testNullRequestBody() {
- assertThrows(IllegalArgumentException.class,
+ assertThrows(ShenyuException.class,
() -> OpenAiProtocolAdapter.toChatCompletionRequest(null, false));
}
@Test
void testEmptyRequestBody() {
- assertThrows(IllegalArgumentException.class,
+ assertThrows(ShenyuException.class,
() -> OpenAiProtocolAdapter.toChatCompletionRequest("", false));
}
@@ -87,21 +88,21 @@ void testStreamFlagUnset() {
}
@Test
- void testMaxCompletionTokensConvertedToMaxTokens() {
+ void testMaxCompletionTokensPreservedAsIs() {
final String body = "{\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"max_completion_tokens\":100}";
final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false);
assertNotNull(req);
- assertEquals(100, req.maxTokens());
+ assertEquals(100, req.maxCompletionTokens());
}
@Test
- void testClientModelTakesPriority() {
+ void testFallbackModelOverridesClientModel() {
final String body = "{\"model\":\"client-model\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}]}";
final AiCommonConfig config = new AiCommonConfig();
- config.setModel("admin-model");
+ config.setModel("fallback-model");
final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false, config);
- assertEquals("client-model", req.model());
+ assertEquals("fallback-model", req.model());
}
@Test
@@ -138,13 +139,13 @@ void testNoFallbackWhenConfigModelIsEmpty() {
}
@Test
- void testClientTemperatureTakesPriority() {
+ void testFallbackTemperatureOverridesClient() {
final String body = "{\"model\":\"m\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"temperature\":0.5}";
final AiCommonConfig config = new AiCommonConfig();
config.setTemperature(0.9);
final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false, config);
- assertEquals(0.5, req.temperature(), 0.001);
+ assertEquals(0.9, req.temperature(), 0.001);
}
@Test
@@ -166,13 +167,13 @@ void testNoFallbackWhenConfigTemperatureIsNull() {
}
@Test
- void testClientMaxTokensTakesPriority() {
+ void testFallbackMaxTokensOverridesClient() {
final String body = "{\"model\":\"m\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"max_tokens\":200}";
final AiCommonConfig config = new AiCommonConfig();
config.setMaxTokens(500);
final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false, config);
- assertEquals(200, req.maxTokens());
+ assertEquals(500, req.maxTokens());
}
@Test
@@ -208,44 +209,44 @@ void testAllFallbackFieldsAppliedWhenClientMissingAll() {
}
@Test
- void testAllClientFieldsTakePriority() {
+ void testAllFallbackFieldsOverrideClient() {
final String body = "{\"model\":\"client-model\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],"
+ "\"temperature\":0.1,\"max_tokens\":50}";
final AiCommonConfig config = new AiCommonConfig();
- config.setModel("admin-model");
+ config.setModel("fallback-model");
config.setTemperature(0.9);
config.setMaxTokens(9999);
final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false, config);
- assertEquals("client-model", req.model());
- assertEquals(0.1, req.temperature(), 0.001);
- assertEquals(50, req.maxTokens());
+ assertEquals("fallback-model", req.model());
+ assertEquals(0.9, req.temperature(), 0.001);
+ assertEquals(9999, req.maxTokens());
}
@Test
- void testPartialClientFieldsWithPartialFallback() {
+ void testPartialFallbackOverrideWithPartialClient() {
final String body = "{\"model\":\"client-model\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"temperature\":0.1}";
final AiCommonConfig config = new AiCommonConfig();
- config.setModel("admin-model");
+ config.setModel("fallback-model");
config.setTemperature(0.9);
config.setMaxTokens(2048);
final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false, config);
- assertEquals("client-model", req.model());
- assertEquals(0.1, req.temperature(), 0.001);
+ assertEquals("fallback-model", req.model());
+ assertEquals(0.9, req.temperature(), 0.001);
assertEquals(2048, req.maxTokens());
}
@Test
void testMalformedJsonThrows() {
- assertThrows(IllegalArgumentException.class,
+ assertThrows(ShenyuException.class,
() -> OpenAiProtocolAdapter.toChatCompletionRequest("not valid json{{{", false));
}
@Test
void testMaxCompletionTokensNonNumberThrows() {
final String body = "{\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"max_completion_tokens\":\"abc\"}";
- assertThrows(IllegalArgumentException.class,
+ assertThrows(ShenyuException.class,
() -> OpenAiProtocolAdapter.toChatCompletionRequest(body, false));
}
@@ -254,4 +255,26 @@ void testResolveStreamMalformedJsonFallsBack() {
assertFalse(OpenAiProtocolAdapter.resolveStream("not json{{{", false));
assertTrue(OpenAiProtocolAdapter.resolveStream("not json{{{", true));
}
+
+ @Test
+ void testAnnotationsRoundTrip() throws Exception {
+ final com.fasterxml.jackson.databind.ObjectMapper mapper = new com.fasterxml.jackson.databind.ObjectMapper();
+ final String body = "{\"messages\":["
+ + "{\"role\":\"user\",\"content\":\"hello\"},"
+ + "{\"role\":\"assistant\",\"content\":\"see link\","
+ + "\"annotations\":[{\"type\":\"url_citation\","
+ + "\"url_citation\":{\"url\":\"https://example.com\",\"title\":\"Example\",\"start_index\":0,\"end_index\":8}}]}"
+ + "]}";
+ final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false);
+ assertNotNull(req);
+
+ final String serialized = mapper.writeValueAsString(req);
+ final com.fasterxml.jackson.databind.JsonNode root = mapper.readTree(serialized);
+ final com.fasterxml.jackson.databind.JsonNode assistantMsg = root.get("messages").get(1);
+ assertTrue(assistantMsg.has("annotations"));
+ final com.fasterxml.jackson.databind.JsonNode annotations = assistantMsg.get("annotations");
+ assertEquals(1, annotations.size());
+ assertEquals("url_citation", annotations.get(0).get("type").asText());
+ assertEquals("https://example.com", annotations.get(0).get("url_citation").get("url").asText());
+ }
}
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java
index 517ff280bc4d..916431437acd 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java
@@ -236,6 +236,7 @@ private OpenAiApi getCachedOpenAiApi(final String selectorId, final String type,
private int generateConfigCacheKey(final AiCommonConfig config) {
return Objects.hash(
config.getBaseUrl(),
+ config.getApiKey(),
config.getModel(),
config.getTemperature(),
config.getMaxTokens()
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCache.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCache.java
index fdbe22f6e429..d75aab480f32 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCache.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCache.java
@@ -85,7 +85,7 @@ public OpenAiApi computeIfAbsent(final String key, final Supplier ope
try {
synchronized (openAiApiMap) {
if (openAiApiMap.size() > MAX_CACHE_SIZE) {
- evictOldestEntries();
+ evictEntries();
}
}
} finally {
@@ -102,7 +102,7 @@ public OpenAiApi computeIfAbsent(final String key, final Supplier ope
* not guaranteed to be the oldest. Removes approximately 25% of entries to avoid
* thundering herd problem.
*/
- private void evictOldestEntries() {
+ private void evictEntries() {
final int currentSize = openAiApiMap.size();
if (currentSize <= MAX_CACHE_SIZE) {
return;
@@ -146,4 +146,13 @@ public void clearAll() {
openAiApiMap.clear();
LOG.info("[OpenAiApiCache] cleared all cached OpenAiApi instances");
}
+
+ /**
+ * Gets the current cache size (for testing / monitoring).
+ *
+ * @return the number of cached entries
+ */
+ public int size() {
+ return openAiApiMap.size();
+ }
}
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java
index cc8a42bbd4b8..40dfc31fa77f 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java
@@ -55,6 +55,9 @@ public void handlerSelector(final SelectorData selectorData) {
// Do NOT remove AiProxyApiKeyCache here. Admin will push updated AI_PROXY_API_KEY events
// with refreshed realApiKey after selector changes. Removing here introduces a window of misses.
if (Objects.isNull(selectorData.getHandle()) || selectorData.getHandle().isEmpty()) {
+ // Clear stale cached handle when selector handle is removed/cleared
+ selectorCachedHandle
+ .removeHandle(CacheKeyUtils.INST.getKey(selectorData.getId(), Constants.DEFAULT_RULE));
return;
}
AiProxyHandle aiProxyHandle = GsonUtils.getInstance().fromJson(selectorData.getHandle(), AiProxyHandle.class);
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java
index 8fbd37be2186..534f81f0c7f7 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java
@@ -69,8 +69,7 @@ public Flux executeDirectStream(final OpenAiApi mainApi,
"Direct stream failed after 1 retry. Triggering fallback.",
retrySignal.failure());
}))
- .onErrorResume(NonTransientAiException.class,
- throwable -> handleDirectFallbackStream(throwable, fallbackCtxOpt, requestBody, stream));
+ .onErrorResume(e -> handleDirectFallbackStream(e, fallbackCtxOpt, requestBody, stream));
}
/**
@@ -96,8 +95,7 @@ public Mono> executeDirectCall(final OpenAiApi ma
return new NonTransientAiException("Direct call retries exhausted. Triggering fallback.",
retrySignal.failure());
}))
- .onErrorResume(NonTransientAiException.class,
- throwable -> handleDirectFallbackCall(throwable, fallbackCtxOpt, requestBody));
+ .onErrorResume(e -> handleDirectFallbackCall(e, fallbackCtxOpt, requestBody));
}
private Flux handleDirectFallbackStream(final Throwable throwable,
@@ -140,7 +138,7 @@ private static boolean isRetryable(final Throwable throwable) {
if (throwable instanceof NonTransientAiException) {
return false;
}
- final WebClientResponseException webClientEx = findWebClientResponseException(throwable);
+ final WebClientResponseException webClientEx = UpstreamErrorLogger.findWebClientResponseException(throwable);
if (Objects.nonNull(webClientEx)) {
final int status = webClientEx.getStatusCode().value();
return status == 429 || status >= 500;
@@ -148,17 +146,6 @@ private static boolean isRetryable(final Throwable throwable) {
return true;
}
- private static WebClientResponseException findWebClientResponseException(final Throwable e) {
- Throwable current = e;
- while (Objects.nonNull(current)) {
- if (current instanceof WebClientResponseException ex) {
- return ex;
- }
- current = current.getCause();
- }
- return null;
- }
-
/**
* Container for fallback context: the OpenAiApi instance and the fallback config
* used to rebuild the request with fallback-specific model/temperature/maxTokens.
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLogger.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLogger.java
index bae715f35026..06a08e0c378a 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLogger.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLogger.java
@@ -55,7 +55,13 @@ private static String truncateBody(final String body) {
return body.substring(0, MAX_BODY_LOG_LENGTH) + "...(truncated, total " + body.length() + " chars)";
}
- private static WebClientResponseException findWebClientResponseException(final Throwable e) {
+ /**
+ * Find the {@link WebClientResponseException} in the exception chain.
+ *
+ * @param e the root exception
+ * @return the WebClientResponseException if found, null otherwise
+ */
+ public static WebClientResponseException findWebClientResponseException(final Throwable e) {
Throwable current = e;
while (Objects.nonNull(current)) {
if (current instanceof WebClientResponseException ex) {
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java
index b3bb711af1d4..e61f43e1689d 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java
@@ -255,7 +255,7 @@ public void testExecuteWithAdminFallback() {
}
@Test
- public void testCacheIsUsedForAdminFallbackClient() {
+ public void testAdminFallbackExecution() {
final AiProxyHandle handle = new AiProxyHandle();
final AiCommonConfig primaryConfig = new AiCommonConfig();
primaryConfig.setProvider(AiModelProviderEnum.OPEN_AI.getName());
@@ -277,7 +277,7 @@ public void testCacheIsUsedForAdminFallbackClient() {
when(configService.resolveAdminFallbackConfig(primaryConfig, handle)).thenReturn(Optional.of(fallbackConfig));
when(executorService.executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class))).thenReturn(Mono.just(responseEntity));
- // Execute the test - focus on successful execution rather than cache verification
+ // Execute the test - verify admin fallback path is exercised
StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule))
.verifyComplete();
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCacheTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCacheTest.java
new file mode 100644
index 000000000000..11bed2897ddb
--- /dev/null
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCacheTest.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shenyu.plugin.ai.proxy.enhanced.cache;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.springframework.ai.openai.api.OpenAiApi;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertSame;
+
+class OpenAiApiCacheTest {
+
+ private OpenAiApiCache cache;
+
+ @BeforeEach
+ void setUp() {
+ cache = new OpenAiApiCache();
+ cache.clearAll();
+ }
+
+ @Test
+ void testComputeIfAbsentReusesExisting() {
+ final OpenAiApi api = createTestApi("https://api.openai.com", "key1");
+ final OpenAiApi cached = cache.computeIfAbsent("key1", () -> api);
+ final OpenAiApi reused = cache.computeIfAbsent("key1", () -> createTestApi("https://api.openai.com", "key1-other"));
+
+ assertSame(cached, reused);
+ assertEquals(1, cache.size());
+ }
+
+ @Test
+ void testComputeIfAbsentCreatesNewForDifferentKey() {
+ final OpenAiApi api1 = cache.computeIfAbsent("key1", () -> createTestApi("https://api.openai.com", "key1"));
+ final OpenAiApi api2 = cache.computeIfAbsent("key2", () -> createTestApi("https://api.deepseek.com", "key2"));
+
+ // Different keys produce different instances
+ assertEquals(2, cache.size());
+ }
+
+ @Test
+ void testRemoveBySelectorId() {
+ cache.computeIfAbsent("sel1|main_123", () -> createTestApi("https://a.com", "k1"));
+ cache.computeIfAbsent("sel1|adminFallback_456", () -> createTestApi("https://b.com", "k2"));
+ cache.computeIfAbsent("sel2|main_789", () -> createTestApi("https://c.com", "k3"));
+
+ assertEquals(3, cache.size());
+
+ cache.remove("sel1");
+
+ assertEquals(1, cache.size());
+ }
+
+ @Test
+ void testRemoveWithNullSelectorIdDoesNotThrow() {
+ cache.computeIfAbsent("key1", () -> createTestApi("https://a.com", "k1"));
+ cache.remove(null);
+ assertEquals(1, cache.size());
+ }
+
+ @Test
+ void testClearAll() {
+ cache.computeIfAbsent("key1", () -> createTestApi("https://a.com", "k1"));
+ cache.computeIfAbsent("key2", () -> createTestApi("https://b.com", "k2"));
+
+ assertEquals(2, cache.size());
+
+ cache.clearAll();
+
+ assertEquals(0, cache.size());
+ }
+
+ private OpenAiApi createTestApi(final String baseUrl, final String apiKey) {
+ return OpenAiApi.builder().baseUrl(baseUrl).apiKey(apiKey).build();
+ }
+}
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLoggerTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLoggerTest.java
index 46d21f43c9a1..85193687d1df 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLoggerTest.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLoggerTest.java
@@ -70,4 +70,19 @@ void testLogWithNullExceptionDoesNotThrow() {
Logger log = mock(Logger.class);
assertDoesNotThrow(() -> UpstreamErrorLogger.logUpstreamError(log, null, "stream"));
}
+
+ @Test
+ void testLogWithLongBodyIsTruncated() {
+ Logger log = mock(Logger.class);
+ final String longBody = "x".repeat(1024);
+ WebClientResponseException ex = WebClientResponseException.create(
+ 500, "Internal Server Error", null, longBody.getBytes(StandardCharsets.UTF_8), StandardCharsets.UTF_8);
+
+ UpstreamErrorLogger.logUpstreamError(log, ex, "stream");
+
+ final String expectedTruncated = longBody.substring(0, 512)
+ + "...(truncated, total " + longBody.length() + " chars)";
+ verify(log).error("[AiProxy] {} failed, status={}, upstreamBody={}",
+ "stream", HttpStatus.INTERNAL_SERVER_ERROR, expectedTruncated, ex);
+ }
}