diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/config/AiCommonConfig.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/config/AiCommonConfig.java
index c3d30221efe0..71a62e15636f 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/config/AiCommonConfig.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/config/AiCommonConfig.java
@@ -50,7 +50,7 @@ public class AiCommonConfig {
/**
* temperature.
*/
- private Double temperature = 0.8;
+ private Double temperature;
/**
* max tokens.
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapter.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapter.java
new file mode 100644
index 000000000000..73244c838f82
--- /dev/null
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapter.java
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shenyu.plugin.ai.common.protocol;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.node.ObjectNode;
+import org.apache.shenyu.common.exception.ShenyuException;
+import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
+
+import java.util.Objects;
+
+/**
+ * Adapts between OpenAI Chat Completions wire format and internal representations.
+ *
+ * <p>Note: deserialization into Spring AI's {@link ChatCompletionRequest} preserves fields
+ * modeled by that class (messages, model, temperature, maxTokens, stream, tools, etc.).
+ * Provider-specific extension fields not modeled by {@code ChatCompletionRequest} are dropped
+ * during deserialization. For the OpenAI-compatible providers this plugin currently supports,
+ * all required fields are covered.
+ */
+public final class OpenAiProtocolAdapter {
+
+    private static final Logger LOG = LoggerFactory.getLogger(OpenAiProtocolAdapter.class);
+
+    /**
+     * Shared JSON mapper. Unknown properties are ignored explicitly so that fields not modeled by
+     * {@link ChatCompletionRequest} are dropped (as the class contract states) instead of failing
+     * deserialization, independent of how the target type itself is annotated.
+     */
+    private static final com.fasterxml.jackson.databind.ObjectMapper MAPPER =
+            new com.fasterxml.jackson.databind.ObjectMapper()
+                    .configure(com.fasterxml.jackson.databind.DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
+
+    /**
+     * OpenAI API field name constants.
+     */
+    private static final String FIELD_MODEL = "model";
+
+    private static final String FIELD_TEMPERATURE = "temperature";
+
+    private static final String FIELD_STREAM = "stream";
+
+    private static final String FIELD_MAX_COMPLETION_TOKENS = "max_completion_tokens";
+
+    private static final String FIELD_MAX_TOKENS = "max_tokens";
+
+    private OpenAiProtocolAdapter() {
+    }
+
+    /**
+     * Resolve the stream flag: client request body takes priority,
+     * falls back to the provided default value.
+     *
+     * <p>The fallback is used when the body is null, empty, not parseable as JSON, has no
+     * non-null {@code stream} field, or carries a {@code stream} value that cannot be
+     * read as a boolean.
+     *
+     * @param requestBody the raw JSON request body
+     * @param fallbackStream the default stream value from admin config, may be null (treated as false)
+     * @return true if streaming, false otherwise
+     */
+    public static boolean resolveStream(final String requestBody, final Boolean fallbackStream) {
+        final boolean fallback = Boolean.TRUE.equals(fallbackStream);
+        if (Objects.isNull(requestBody) || requestBody.isEmpty()) {
+            return fallback;
+        }
+        final JsonNode root = parseStrict(requestBody);
+        if (Objects.isNull(root) || !root.hasNonNull(FIELD_STREAM)) {
+            return fallback;
+        }
+        // asBoolean(default) keeps real booleans, "true"/"false" text and numbers working,
+        // while values that cannot be coerced no longer silently become false.
+        return root.get(FIELD_STREAM).asBoolean(fallback);
+    }
+
+    /**
+     * Parse raw request body directly into ChatCompletionRequest, preserving fields
+     * modeled by Spring AI (including reasoning_content in assistant messages).
+     *
+     * <p>Spring AI's createRequest() loses reasoning_content, refusal, and annotations
+     * when reconstructing ChatCompletionMessage from AssistantMessage.
+     * This method avoids that loss by deserializing the raw JSON directly.
+     *
+     * <p>This overload applies no admin config: model, temperature, max_tokens and
+     * max_completion_tokens are passed through exactly as the client sent them.
+     * Only the stream field is overwritten with the given value.
+     *
+     * @param requestBody the raw JSON request body in OpenAI Chat Completions format
+     * @param stream whether this is a streaming request (sets the stream field)
+     * @return a ChatCompletionRequest with modeled fields preserved from the original request
+     * @throws ShenyuException if the body is null, empty, not a JSON object, or cannot be mapped
+     */
+    public static ChatCompletionRequest toChatCompletionRequest(final String requestBody, final boolean stream) {
+        return toChatCompletionRequest(requestBody, stream, null);
+    }
+
+    /**
+     * Parse raw request body directly into ChatCompletionRequest, preserving modeled fields.
+     * When fallbackConfig is provided, its non-null fields (model, temperature, maxTokens)
+     * override the client request values, matching the original ChatModel-based fallback behavior
+     * where the fallback ChatModel's config takes precedence.
+     *
+     * <p>An empty model string is treated like null and does not override. A non-null maxTokens
+     * is written to both max_tokens and max_completion_tokens, replacing whatever the client sent
+     * for either field. The stream field is always overwritten with the given value.
+     *
+     * @param requestBody the raw JSON request body in OpenAI Chat Completions format
+     * @param stream whether this is a streaming request
+     * @param fallbackConfig the config whose non-null fields override client values, may be null
+     * @return a ChatCompletionRequest with modeled fields preserved
+     * @throws ShenyuException if the body is null, empty, not a JSON object, or cannot be mapped
+     */
+    public static ChatCompletionRequest toChatCompletionRequest(final String requestBody,
+                                                                final boolean stream, final AiCommonConfig fallbackConfig) {
+        if (Objects.isNull(requestBody) || requestBody.isEmpty()) {
+            throw new ShenyuException("Request body must not be empty");
+        }
+        final JsonNode root = parseStrict(requestBody);
+        if (Objects.isNull(root) || !root.isObject()) {
+            throw new ShenyuException("Invalid request body: expected a JSON object");
+        }
+        final ObjectNode mutableRoot = (ObjectNode) root;
+
+        if (Objects.nonNull(fallbackConfig)) {
+            if (Objects.nonNull(fallbackConfig.getModel()) && !fallbackConfig.getModel().isEmpty()) {
+                mutableRoot.put(FIELD_MODEL, fallbackConfig.getModel());
+            }
+            if (Objects.nonNull(fallbackConfig.getTemperature())) {
+                mutableRoot.put(FIELD_TEMPERATURE, fallbackConfig.getTemperature());
+            }
+            if (Objects.nonNull(fallbackConfig.getMaxTokens())) {
+                // Both spellings are populated so the limit applies whichever field the provider reads.
+                mutableRoot.put(FIELD_MAX_TOKENS, fallbackConfig.getMaxTokens());
+                mutableRoot.put(FIELD_MAX_COMPLETION_TOKENS, fallbackConfig.getMaxTokens());
+            }
+        }
+
+        mutableRoot.put(FIELD_STREAM, stream);
+
+        try {
+            return MAPPER.treeToValue(mutableRoot, ChatCompletionRequest.class);
+        } catch (Exception e) {
+            LOG.error("[AiProxy] Failed to parse request body into ChatCompletionRequest", e);
+            throw new ShenyuException("Failed to parse request body into ChatCompletionRequest", e);
+        }
+    }
+
+    /**
+     * Parse the given text as a JSON tree without throwing.
+     *
+     * @param json the raw JSON text
+     * @return the parsed tree, or null when the text is not valid JSON
+     */
+    private static JsonNode parseStrict(final String json) {
+        try {
+            return MAPPER.readTree(json);
+        } catch (JsonProcessingException e) {
+            // Callers decide how to react to unparseable input (fallback or rejection).
+            return null;
+        }
+    }
+}
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/FallbackStrategy.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/FallbackStrategy.java
deleted file mode 100644
index 08c68c4f46ec..000000000000
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/FallbackStrategy.java
+++ /dev/null
@@ -1,49 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.shenyu.plugin.ai.common.strategy;
-
-import org.springframework.ai.chat.client.ChatClient;
-import org.springframework.ai.chat.model.ChatResponse;
-import reactor.core.publisher.Flux;
-import reactor.core.publisher.Mono;
-
-/**
- * Fallback strategy.
- */
-public interface FallbackStrategy {
-
- /**
- * Execute the fallback strategy.
- *
- * @param fallbackClient the pre-configured and cached chat client for fallback
- * @param requestBody the original request body as a string
- * @param originalError the original error that triggered the fallback
- * @return a Mono containing the fallback ChatResponse
- */
- Mono fallback(ChatClient fallbackClient, String requestBody, Throwable originalError);
-
- /**
- * Execute the fallback strategy for stream.
- *
- * @param fallbackClient the pre-configured and cached chat client for fallback
- * @param requestBody the original request body as a string
- * @param originalError the original error that triggered the fallback
- * @return a Flux containing the fallback ChatResponse
- */
- Flux fallbackStream(ChatClient fallbackClient, String requestBody, Throwable originalError);
-}
\ No newline at end of file
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/SimpleModelFallbackStrategy.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/SimpleModelFallbackStrategy.java
deleted file mode 100644
index 94b3653e654d..000000000000
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/strategy/SimpleModelFallbackStrategy.java
+++ /dev/null
@@ -1,68 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.shenyu.plugin.ai.common.strategy;
-
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import org.springframework.ai.chat.client.ChatClient;
-import org.springframework.ai.chat.model.ChatResponse;
-import reactor.core.publisher.Flux;
-import reactor.core.publisher.Mono;
-import reactor.core.scheduler.Schedulers;
-
-/**
- * A fallback strategy that simply executes a call with a pre-configured client.
- */
-public final class SimpleModelFallbackStrategy implements FallbackStrategy {
-
- /**
- * The constant INSTANCE.
- */
- public static final FallbackStrategy INSTANCE = new SimpleModelFallbackStrategy();
-
- private static final Logger LOG = LoggerFactory.getLogger(SimpleModelFallbackStrategy.class);
-
- private SimpleModelFallbackStrategy() {
- }
-
- @Override
- public Mono fallback(final ChatClient fallbackClient, final String requestBody, final Throwable originalError) {
- LOG.warn("Executing simple model fallback strategy due to error: {}", originalError.getMessage());
-
- return Mono.fromCallable(() -> fallbackClient.prompt()
- .user(requestBody)
- .call()
- .chatResponse())
- .subscribeOn(Schedulers.boundedElastic())
- .doOnSuccess(response -> LOG.info("Fallback call successful."))
- .doOnError(fallbackError -> LOG.error("Fallback call also failed.", fallbackError));
- }
-
- @Override
- public Flux fallbackStream(final ChatClient fallbackClient, final String requestBody, final Throwable originalError) {
- LOG.warn("Executing simple model fallback stream strategy due to error: {}", originalError.getMessage());
-
- return Flux.defer(() -> fallbackClient.prompt()
- .user(requestBody)
- .stream()
- .chatResponse())
- .subscribeOn(Schedulers.boundedElastic())
- .doOnComplete(() -> LOG.info("Fallback stream completed successfully."))
- .doOnError(fallbackError -> LOG.error("Fallback stream also failed.", fallbackError));
- }
-}
\ No newline at end of file
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/test/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapterTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/test/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapterTest.java
new file mode 100644
index 000000000000..e843d95f1d0b
--- /dev/null
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/test/java/org/apache/shenyu/plugin/ai/common/protocol/OpenAiProtocolAdapterTest.java
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shenyu.plugin.ai.common.protocol;
+
+import org.apache.shenyu.common.exception.ShenyuException;
+import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig;
+import org.junit.jupiter.api.Test;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+/**
+ * Test for {@link OpenAiProtocolAdapter}.
+ */
+public final class OpenAiProtocolAdapterTest {
+
+    /** Minimal request body: one user message, no model, temperature, token limits or stream flag. */
+    private static final String BASE_BODY = "{\"messages\":[{\"role\":\"user\",\"content\":\"hello\"}]}";
+
+    @Test
+    void testNullRequestBody() {
+        assertThrows(ShenyuException.class,
+            () -> OpenAiProtocolAdapter.toChatCompletionRequest(null, false));
+    }
+
+    @Test
+    void testEmptyRequestBody() {
+        assertThrows(ShenyuException.class,
+            () -> OpenAiProtocolAdapter.toChatCompletionRequest("", false));
+    }
+
+    @Test
+    void testNonObjectJsonThrows() {
+        // Valid JSON that is not an object must be rejected, not cast.
+        assertThrows(ShenyuException.class,
+            () -> OpenAiProtocolAdapter.toChatCompletionRequest("[1,2,3]", false));
+    }
+
+    @Test
+    void testStreamFlagSet() {
+        final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, true);
+        assertNotNull(req);
+        assertTrue(req.stream());
+    }
+
+    @Test
+    void testResolveStreamClientTrueOverridesFallbackFalse() {
+        final String body = "{\"model\":\"m\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"stream\":true}";
+        assertTrue(OpenAiProtocolAdapter.resolveStream(body, false));
+    }
+
+    @Test
+    void testResolveStreamClientFalseOverridesFallbackTrue() {
+        final String body = "{\"model\":\"m\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"stream\":false}";
+        assertFalse(OpenAiProtocolAdapter.resolveStream(body, true));
+    }
+
+    @Test
+    void testResolveStreamFallbackWhenClientMissing() {
+        assertTrue(OpenAiProtocolAdapter.resolveStream(BASE_BODY, true));
+        assertFalse(OpenAiProtocolAdapter.resolveStream(BASE_BODY, false));
+    }
+
+    @Test
+    void testResolveStreamFallbackWhenClientMissingAndConfigNull() {
+        assertFalse(OpenAiProtocolAdapter.resolveStream(BASE_BODY, null));
+    }
+
+    @Test
+    void testResolveStreamFallbackWhenBodyEmpty() {
+        assertTrue(OpenAiProtocolAdapter.resolveStream("", true));
+        assertFalse(OpenAiProtocolAdapter.resolveStream("", false));
+    }
+
+    @Test
+    void testResolveStreamFallbackWhenBodyNull() {
+        assertTrue(OpenAiProtocolAdapter.resolveStream(null, true));
+        assertFalse(OpenAiProtocolAdapter.resolveStream(null, false));
+    }
+
+    @Test
+    void testStreamFlagUnset() {
+        final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false);
+        assertNotNull(req);
+        assertFalse(req.stream());
+    }
+
+    @Test
+    void testMaxCompletionTokensPreservedAsIs() {
+        final String body = "{\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"max_completion_tokens\":100}";
+        final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false);
+        assertNotNull(req);
+        assertEquals(100, req.maxCompletionTokens());
+        // Without a config nothing may be copied into the sibling field.
+        assertNull(req.maxTokens());
+    }
+
+    @Test
+    void testFallbackModelOverridesClientModel() {
+        final String body = "{\"model\":\"client-model\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}]}";
+        final AiCommonConfig config = new AiCommonConfig();
+        config.setModel("fallback-model");
+
+        final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false, config);
+        assertEquals("fallback-model", req.model());
+    }
+
+    @Test
+    void testFallbackModelWhenClientMissing() {
+        final AiCommonConfig config = new AiCommonConfig();
+        config.setModel("admin-model");
+
+        final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false, config);
+        assertEquals("admin-model", req.model());
+    }
+
+    @Test
+    void testNoFallbackModelWhenConfigIsNull() {
+        final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false, null);
+        assertNotNull(req);
+        assertNull(req.model());
+    }
+
+    @Test
+    void testNoFallbackWhenConfigModelIsNull() {
+        final AiCommonConfig config = new AiCommonConfig();
+        config.setModel(null);
+
+        final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false, config);
+        assertNotNull(req);
+        assertNull(req.model());
+    }
+
+    @Test
+    void testNoFallbackWhenConfigModelIsEmpty() {
+        final AiCommonConfig config = new AiCommonConfig();
+        config.setModel("");
+
+        final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false, config);
+        assertNotNull(req);
+        // An empty model must be ignored rather than written into the request.
+        assertNull(req.model());
+    }
+
+    @Test
+    void testFallbackTemperatureOverridesClient() {
+        final String body = "{\"model\":\"m\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"temperature\":0.5}";
+        final AiCommonConfig config = new AiCommonConfig();
+        config.setTemperature(0.9);
+
+        final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false, config);
+        assertEquals(0.9, req.temperature(), 0.001);
+    }
+
+    @Test
+    void testFallbackTemperatureWhenClientMissing() {
+        final AiCommonConfig config = new AiCommonConfig();
+        config.setTemperature(0.7);
+
+        final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false, config);
+        assertEquals(0.7, req.temperature(), 0.001);
+    }
+
+    @Test
+    void testNoFallbackWhenConfigTemperatureIsNull() {
+        final AiCommonConfig config = new AiCommonConfig();
+        config.setTemperature(null);
+
+        final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false, config);
+        assertNotNull(req);
+        assertNull(req.temperature());
+    }
+
+    @Test
+    void testFallbackMaxTokensOverridesClient() {
+        final String body = "{\"model\":\"m\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"max_tokens\":200}";
+        final AiCommonConfig config = new AiCommonConfig();
+        config.setMaxTokens(500);
+
+        final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false, config);
+        assertEquals(500, req.maxTokens());
+        assertEquals(500, req.maxCompletionTokens());
+    }
+
+    @Test
+    void testFallbackMaxTokensWhenClientMissing() {
+        final AiCommonConfig config = new AiCommonConfig();
+        config.setMaxTokens(500);
+
+        final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false, config);
+        assertEquals(500, req.maxTokens());
+        assertEquals(500, req.maxCompletionTokens());
+    }
+
+    @Test
+    void testNoFallbackMaxTokensWhenConfigNull() {
+        final AiCommonConfig config = new AiCommonConfig();
+        config.setMaxTokens(null);
+
+        final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, false, config);
+        assertNotNull(req);
+        assertNull(req.maxTokens());
+        assertNull(req.maxCompletionTokens());
+    }
+
+    @Test
+    void testAllFallbackFieldsAppliedWhenClientMissingAll() {
+        final AiCommonConfig config = new AiCommonConfig();
+        config.setModel("admin-model");
+        config.setTemperature(0.3);
+        config.setMaxTokens(1024);
+
+        final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(BASE_BODY, true, config);
+        assertEquals("admin-model", req.model());
+        assertEquals(0.3, req.temperature(), 0.001);
+        assertEquals(1024, req.maxTokens());
+        assertEquals(1024, req.maxCompletionTokens());
+        assertTrue(req.stream());
+    }
+
+    @Test
+    void testAllFallbackFieldsOverrideClient() {
+        final String body = "{\"model\":\"client-model\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],"
+            + "\"temperature\":0.1,\"max_tokens\":50}";
+        final AiCommonConfig config = new AiCommonConfig();
+        config.setModel("fallback-model");
+        config.setTemperature(0.9);
+        config.setMaxTokens(9999);
+
+        final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false, config);
+        assertEquals("fallback-model", req.model());
+        assertEquals(0.9, req.temperature(), 0.001);
+        assertEquals(9999, req.maxTokens());
+        assertEquals(9999, req.maxCompletionTokens());
+    }
+
+    @Test
+    void testPartialFallbackOverrideWithPartialClient() {
+        final String body = "{\"model\":\"client-model\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"temperature\":0.1}";
+        final AiCommonConfig config = new AiCommonConfig();
+        config.setModel("fallback-model");
+        config.setTemperature(0.9);
+        config.setMaxTokens(2048);
+
+        final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false, config);
+        assertEquals("fallback-model", req.model());
+        assertEquals(0.9, req.temperature(), 0.001);
+        assertEquals(2048, req.maxTokens());
+        assertEquals(2048, req.maxCompletionTokens());
+    }
+
+    @Test
+    void testMalformedJsonThrows() {
+        assertThrows(ShenyuException.class,
+            () -> OpenAiProtocolAdapter.toChatCompletionRequest("not valid json{{{", false));
+    }
+
+    @Test
+    void testMaxCompletionTokensNonNumberThrows() {
+        final String body = "{\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}],\"max_completion_tokens\":\"abc\"}";
+        assertThrows(ShenyuException.class,
+            () -> OpenAiProtocolAdapter.toChatCompletionRequest(body, false));
+    }
+
+    @Test
+    void testResolveStreamMalformedJsonFallsBack() {
+        assertFalse(OpenAiProtocolAdapter.resolveStream("not json{{{", false));
+        assertTrue(OpenAiProtocolAdapter.resolveStream("not json{{{", true));
+    }
+
+    @Test
+    void testAnnotationsRoundTrip() throws Exception {
+        final com.fasterxml.jackson.databind.ObjectMapper mapper = new com.fasterxml.jackson.databind.ObjectMapper();
+        final String body = "{\"messages\":["
+            + "{\"role\":\"user\",\"content\":\"hello\"},"
+            + "{\"role\":\"assistant\",\"content\":\"see link\","
+            + "\"annotations\":[{\"type\":\"url_citation\","
+            + "\"url_citation\":{\"url\":\"https://example.com\",\"title\":\"Example\",\"start_index\":0,\"end_index\":8}}]}"
+            + "]}";
+        final ChatCompletionRequest req = OpenAiProtocolAdapter.toChatCompletionRequest(body, false);
+        assertNotNull(req);
+
+        // Serialize again and inspect the tree so the check does not depend on Spring AI accessor names.
+        final String serialized = mapper.writeValueAsString(req);
+        final com.fasterxml.jackson.databind.JsonNode root = mapper.readTree(serialized);
+        final com.fasterxml.jackson.databind.JsonNode assistantMsg = root.get("messages").get(1);
+        assertTrue(assistantMsg.has("annotations"));
+        final com.fasterxml.jackson.databind.JsonNode annotations = assistantMsg.get("annotations");
+        assertEquals(1, annotations.size());
+        assertEquals("url_citation", annotations.get(0).get("type").asText());
+        assertEquals("https://example.com", annotations.get(0).get("url_citation").get("url").asText());
+    }
+}
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java
index 687c51233f95..916431437acd 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java
@@ -21,25 +21,26 @@
import org.apache.shenyu.common.dto.RuleData;
import org.apache.shenyu.common.dto.SelectorData;
import org.apache.shenyu.common.dto.convert.rule.AiProxyHandle;
-import org.apache.shenyu.common.enums.AiModelProviderEnum;
import org.apache.shenyu.common.enums.PluginEnum;
import org.apache.shenyu.common.utils.JsonUtils;
import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig;
-import org.apache.shenyu.plugin.ai.common.spring.ai.registry.AiModelFactoryRegistry;
+import org.apache.shenyu.plugin.ai.common.protocol.OpenAiProtocolAdapter;
import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.AiProxyApiKeyCache;
-import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.ChatClientCache;
+import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.OpenAiApiCache;
import org.apache.shenyu.plugin.ai.proxy.enhanced.handler.AiProxyPluginHandler;
import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyConfigService;
import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyExecutorService;
+import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyExecutorService.FallbackContext;
+import org.apache.shenyu.plugin.ai.proxy.enhanced.service.UpstreamErrorLogger;
import org.apache.shenyu.plugin.api.ShenyuPluginChain;
import org.apache.shenyu.plugin.api.utils.WebFluxResultUtils;
import org.apache.shenyu.plugin.base.AbstractShenyuPlugin;
import org.apache.shenyu.plugin.base.utils.CacheKeyUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import org.springframework.ai.chat.client.ChatClient;
-import org.springframework.ai.chat.model.ChatModel;
-import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.openai.api.OpenAiApi;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.http.HttpHeaders;
@@ -65,26 +66,18 @@ public class AiProxyPlugin extends AbstractShenyuPlugin {
*/
private static final long MAX_REQUEST_BODY_SIZE_BYTES = 5 * 1024 * 1024L;
- private final AiModelFactoryRegistry aiModelFactoryRegistry;
-
private final AiProxyConfigService aiProxyConfigService;
private final AiProxyExecutorService aiProxyExecutorService;
- private final ChatClientCache chatClientCache;
-
private final AiProxyPluginHandler aiProxyPluginHandler;
+ /**
+ * Create the AI proxy plugin from its collaborating services.
+ * Each argument is stored in the field of the same name.
+ *
+ * @param aiProxyConfigService the AI proxy config service
+ * @param aiProxyExecutorService the AI proxy executor service
+ * @param aiProxyPluginHandler the AI proxy plugin handler
+ */
public AiProxyPlugin(
- final AiModelFactoryRegistry aiModelFactoryRegistry,
final AiProxyConfigService aiProxyConfigService,
final AiProxyExecutorService aiProxyExecutorService,
- final ChatClientCache chatClientCache,
final AiProxyPluginHandler aiProxyPluginHandler) {
- this.aiModelFactoryRegistry = aiModelFactoryRegistry;
this.aiProxyConfigService = aiProxyConfigService;
this.aiProxyExecutorService = aiProxyExecutorService;
- this.chatClientCache = chatClientCache;
this.aiProxyPluginHandler = aiProxyPluginHandler;
}
@@ -148,7 +141,8 @@ protected Mono doExecute(
}
}
- if (Boolean.TRUE.equals(primaryConfig.getStream())) {
+ final boolean stream = OpenAiProtocolAdapter.resolveStream(requestBody, primaryConfig.getStream());
+ if (stream) {
return handleStreamRequest(exchange, selector, requestBody, primaryConfig, selectorHandle);
}
return handleNonStreamRequest(exchange, selector, requestBody, primaryConfig, selectorHandle);
@@ -161,23 +155,26 @@ private Mono handleStreamRequest(
final String requestBody,
final AiCommonConfig primaryConfig,
final AiProxyHandle selectorHandle) {
- final ChatClient mainClient = createMainChatClient(selector.getId(), primaryConfig);
- final String prompt = aiProxyConfigService.extractPrompt(requestBody);
- final Optional fallbackClient = resolveFallbackClient(primaryConfig, selectorHandle,
- selector.getId(), requestBody);
+ final OpenAiApi mainApi = getCachedOpenAiApi(selector.getId(), "main", primaryConfig);
+ final ChatCompletionRequest request = OpenAiProtocolAdapter.toChatCompletionRequest(requestBody, true, primaryConfig);
+ final Optional fallbackCtx = resolveFallbackContext(
+ selector.getId(), primaryConfig, selectorHandle, requestBody);
final ServerHttpResponse response = exchange.getResponse();
response.getHeaders().setContentType(MediaType.TEXT_EVENT_STREAM);
- final Flux chatResponseFlux = aiProxyExecutorService.executeStream(mainClient, fallbackClient,
- prompt);
+ final Flux chunkFlux = aiProxyExecutorService.executeDirectStream(
+ mainApi, fallbackCtx, request, requestBody, true);
- final Flux sseFlux = chatResponseFlux.map(
- chatResponse -> {
- final String json = JsonUtils.toJson(chatResponse);
+ final Flux sseFlux = chunkFlux.map(
+ chunk -> {
+ final String json = JsonUtils.toJson(chunk);
final String sseData = "data: " + json + "\n\n";
return response.bufferFactory()
.wrap(sseData.getBytes(StandardCharsets.UTF_8));
- });
+ }).concatWith(Mono.fromSupplier(() ->
+ response.bufferFactory()
+ .wrap("data: [DONE]\n\n".getBytes(StandardCharsets.UTF_8))))
+ .doOnError(e -> logUpstreamError(e, "stream"));
return response.writeWith(sseFlux);
}
@@ -188,24 +185,26 @@ private Mono handleNonStreamRequest(
final String requestBody,
final AiCommonConfig primaryConfig,
final AiProxyHandle selectorHandle) {
- final ChatClient mainClient = createMainChatClient(selector.getId(), primaryConfig);
- final String prompt = aiProxyConfigService.extractPrompt(requestBody);
- final Optional fallbackClient = resolveFallbackClient(primaryConfig, selectorHandle,
- selector.getId(), requestBody);
+ final OpenAiApi mainApi = getCachedOpenAiApi(selector.getId(), "main", primaryConfig);
+ final ChatCompletionRequest request = OpenAiProtocolAdapter.toChatCompletionRequest(requestBody, false, primaryConfig);
+ final Optional fallbackCtx = resolveFallbackContext(
+ selector.getId(), primaryConfig, selectorHandle, requestBody);
return aiProxyExecutorService
- .execute(mainClient, fallbackClient, prompt)
+ .executeDirectCall(mainApi, fallbackCtx, request, requestBody)
.flatMap(
- response -> {
- byte[] jsonBytes = JsonUtils.toJson(response).getBytes(StandardCharsets.UTF_8);
+ responseEntity -> {
+ final String responseJson = JsonUtils.toJson(responseEntity.getBody());
+ byte[] jsonBytes = responseJson.getBytes(StandardCharsets.UTF_8);
return WebFluxResultUtils.result(exchange, jsonBytes);
- });
+ })
+ .doOnError(e -> logUpstreamError(e, "non-stream"));
}
- private Optional resolveFallbackClient(
+ private Optional resolveFallbackContext(
+ final String selectorId,
final AiCommonConfig primaryConfig,
final AiProxyHandle selectorHandle,
- final String selectorId,
final String requestBody) {
return aiProxyConfigService
.resolveDynamicFallbackConfig(primaryConfig, requestBody)
@@ -214,86 +213,47 @@ private Optional resolveFallbackClient(
if (LOG.isDebugEnabled()) {
LOG.debug("[AiProxy] dynamic fallback config: {}", cfg);
}
- return createDynamicFallbackClient(cfg);
+ return new FallbackContext(createOpenAiApi(cfg), cfg);
})
- .or(
- () -> aiProxyConfigService
- .resolveAdminFallbackConfig(primaryConfig, selectorHandle)
- .map(adminFallbackConfig -> {
- LOG.info("[AiProxy] use admin fallback");
- if (LOG.isDebugEnabled()) {
- LOG.debug("[AiProxy] admin fallback config: {}", adminFallbackConfig);
- }
- return createAdminFallbackClient(selectorId, adminFallbackConfig);
- }));
+ .or(() -> aiProxyConfigService
+ .resolveAdminFallbackConfig(primaryConfig, selectorHandle)
+ .map(adminFallbackConfig -> {
+ LOG.info("[AiProxy] use admin fallback");
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("[AiProxy] admin fallback config: {}", adminFallbackConfig);
+ }
+ return new FallbackContext(
+ getCachedOpenAiApi(selectorId, "adminFallback", adminFallbackConfig),
+ adminFallbackConfig);
+ }));
+ }
+
+ private OpenAiApi getCachedOpenAiApi(final String selectorId, final String type, final AiCommonConfig config) {
+ final OpenAiApiCache apiCache = OpenAiApiCache.getInstance();
+ return apiCache.computeIfAbsent(selectorId + "|" + type + "_" + generateConfigCacheKey(config), () -> createOpenAiApi(config));
+ }
- /**
- * Generate cache key based on config fields excluding apiKey.
- * This ensures cache consistency even when apiKey is updated at runtime.
- *
- * @param config the config
- * @return cache key hash
- */
private int generateConfigCacheKey(final AiCommonConfig config) {
return Objects.hash(
- config.getProvider(),
config.getBaseUrl(),
+ config.getApiKey(),
config.getModel(),
config.getTemperature(),
- config.getMaxTokens(),
- config.getStream()
- // Explicitly exclude apiKey to avoid cache misses when apiKey changes
+ config.getMaxTokens()
);
}
- private ChatClient createMainChatClient(final String selectorId, final AiCommonConfig config) {
- final int configHash = generateConfigCacheKey(config);
- final String cacheKey = selectorId + "|main_" + configHash;
- return chatClientCache.computeIfAbsent(
- cacheKey,
- () -> {
- LOG.info("Creating and caching main model for selector: {}, key: {}", selectorId, cacheKey);
- return createChatModel(config);
- });
- }
-
- private ChatClient createAdminFallbackClient(
- final String selectorId, final AiCommonConfig fallbackConfig) {
- final int configHash = generateConfigCacheKey(fallbackConfig);
- final String fallbackCacheKey = selectorId + "|adminFallback_" + configHash;
- return chatClientCache.computeIfAbsent(
- fallbackCacheKey,
- () -> {
- LOG.info(
- "Creating and caching admin fallback model for selector: {}, key: {}",
- selectorId, fallbackCacheKey);
- return createChatModel(fallbackConfig);
- });
- }
-
- private ChatClient createDynamicFallbackClient(final AiCommonConfig fallbackConfig) {
- LOG.info("Creating non-cached dynamic fallback model.");
- return ChatClient.builder(createChatModel(fallbackConfig)).build();
- }
-
- private ChatModel createChatModel(final AiCommonConfig config) {
- if (LOG.isDebugEnabled()) {
- LOG.debug("Creating chat model with config: {}", config);
- }
- final AiModelProviderEnum provider = AiModelProviderEnum.getByName(config.getProvider());
- if (Objects.isNull(provider)) {
- throw new IllegalArgumentException(
- "Invalid AI model provider in config: " + config.getProvider());
+ private OpenAiApi createOpenAiApi(final AiCommonConfig config) {
+ if (Objects.isNull(config.getBaseUrl()) || config.getBaseUrl().isBlank()) { // whitespace-only is rejected too
+ throw new IllegalArgumentException("baseUrl must not be empty");
}
- final var factory = aiModelFactoryRegistry.getFactory(provider);
- if (Objects.isNull(factory)) {
- throw new IllegalArgumentException(
- "AI model factory not found for provider: " + provider.getName());
+ if (Objects.isNull(config.getApiKey()) || config.getApiKey().isBlank()) { // whitespace-only is rejected too
+ throw new IllegalArgumentException("apiKey must not be empty");
}
- return Objects.requireNonNull(
- factory.createAiModel(config),
- "The AI model created by the factory must not be null");
+ return OpenAiApi.builder()
+ .baseUrl(config.getBaseUrl())
+ .apiKey(config.getApiKey())
+ .build();
}
@Override
@@ -305,4 +265,8 @@ public int getOrder() {
public String named() {
return PluginEnum.AI_PROXY.getName();
}
+
+ private void logUpstreamError(final Throwable e, final String mode) {
+ UpstreamErrorLogger.logUpstreamError(LOG, e, mode);
+ }
}
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/ChatClientCache.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/ChatClientCache.java
deleted file mode 100644
index 473ace98fa91..000000000000
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/ChatClientCache.java
+++ /dev/null
@@ -1,143 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.shenyu.plugin.ai.proxy.enhanced.cache;
-
-import org.springframework.ai.chat.client.ChatClient;
-import org.springframework.ai.chat.model.ChatModel;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.util.Map;
-import java.util.Objects;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.function.Supplier;
-
-/**
- * This is ChatClient cache.
- */
-public final class ChatClientCache {
-
- private static final Logger LOG = LoggerFactory.getLogger(ChatClientCache.class);
-
- private static final int MAX_CACHE_SIZE = getCacheSize();
-
- private final Map chatClientMap = new ConcurrentHashMap<>();
-
- private final AtomicBoolean evictionInProgress = new AtomicBoolean(false);
-
- /**
- * Instantiates a new Chat client cache.
- */
- public ChatClientCache() {
- }
-
- private static int getCacheSize() {
- String value = System.getProperty("shenyu.plugin.ai.proxy.enhanced.cache.maxSize",
- System.getenv("SHENYU_PLUGIN_AI_PROXY_ENHANCED_CACHE_MAXSIZE"));
- if (Objects.nonNull(value)) {
- try {
- return Integer.parseInt(value);
- } catch (NumberFormatException e) {
- LoggerFactory.getLogger(ChatClientCache.class)
- .warn("[ChatClientCache] Invalid cache size '{}', using default 500.", value);
- }
- }
- return 500;
- }
-
- /**
- * Gets client or compute if absent.
- *
- * @param key the key
- * @param chatModelSupplier the chat model supplier
- * @return the chat client
- */
- public ChatClient computeIfAbsent(final String key, final Supplier chatModelSupplier) {
- // Check size before computing, but use synchronized block to prevent race conditions
- final int currentSize = chatClientMap.size();
- if (currentSize > MAX_CACHE_SIZE) {
- // Use atomic flag to ensure only one thread performs eviction
- if (evictionInProgress.compareAndSet(false, true)) {
- try {
- synchronized (chatClientMap) {
- // Double-check after acquiring lock
- if (chatClientMap.size() > MAX_CACHE_SIZE) {
- evictOldestEntries();
- }
- }
- } finally {
- evictionInProgress.set(false);
- }
- }
- }
- return chatClientMap.computeIfAbsent(key, k -> ChatClient.builder(chatModelSupplier.get()).build());
- }
-
- /**
- * Evict oldest entries when cache size exceeds limit.
- * Removes approximately 25% of entries to avoid thundering herd problem.
- */
- private void evictOldestEntries() {
- final int currentSize = chatClientMap.size();
- if (currentSize <= MAX_CACHE_SIZE) {
- return;
- }
-
- // Evict 25% of entries, but at least 10 entries
- final int evictCount = Math.max(10, currentSize / 4);
- LOG.warn("[ChatClientCache] Cache size {} exceeded limit {}, evicting {} oldest entries",
- currentSize, MAX_CACHE_SIZE, evictCount);
-
- // Since ConcurrentHashMap doesn't maintain insertion order,
- // we evict entries based on iteration order (which is somewhat arbitrary but better than clearing all)
- int removed = 0;
- for (final String key : chatClientMap.keySet()) {
- if (removed >= evictCount) {
- break;
- }
- chatClientMap.remove(key);
- removed++;
- }
-
- LOG.info("[ChatClientCache] Evicted {} entries, cache size now: {}", removed, chatClientMap.size());
- }
-
- /**
- * Removes all cached clients associated with a selector ID (by prefix matching
- * "selectorId|").
- *
- * @param selectorId the selector id
- */
- public void remove(final String selectorId) {
- if (java.util.Objects.isNull(selectorId)) {
- return;
- }
- final String prefix = selectorId + "|";
- chatClientMap.keySet().removeIf(k -> k.equals(selectorId) || k.startsWith(prefix));
- LOG.info("[ChatClientCache] invalidate selectorId={} (by prefix)", selectorId);
- }
-
- /**
- * Clear all cached clients.
- */
- public void clearAll() {
- chatClientMap.clear();
- LOG.info("[ChatClientCache] cleared all cached clients");
- }
-}
\ No newline at end of file
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCache.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCache.java
new file mode 100644
index 000000000000..d75aab480f32
--- /dev/null
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCache.java
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shenyu.plugin.ai.proxy.enhanced.cache;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.ai.openai.api.OpenAiApi;
+
+import java.util.Map;
+import java.util.Objects;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.function.Supplier;
+
+/**
+ * Process-wide cache of {@link OpenAiApi} instances, keyed by a caller-supplied string.
+ * Caching avoids recreating an OpenAiApi on every request.
+ *
+ * <p>The maximum size is read once, at class initialisation, from the system property
+ * {@code shenyu.plugin.ai.proxy.enhanced.cache.maxSize} or, when that is unset, from the
+ * environment variable {@code SHENYU_PLUGIN_AI_PROXY_ENHANCED_CACHE_MAXSIZE}.
+ */
+public final class OpenAiApiCache {
+
+    private static final Logger LOG = LoggerFactory.getLogger(OpenAiApiCache.class);
+
+    private static final int DEFAULT_MAX_CACHE_SIZE = 500;
+
+    private static final OpenAiApiCache INSTANCE = new OpenAiApiCache();
+
+    private static final int MAX_CACHE_SIZE = getCacheSize();
+
+    private final Map<String, OpenAiApi> openAiApiMap = new ConcurrentHashMap<>();
+
+    private final AtomicBoolean evictionInProgress = new AtomicBoolean(false);
+
+    /**
+     * Instantiates a new OpenAiApi cache.
+     */
+    public OpenAiApiCache() {
+    }
+
+    /**
+     * Gets instance.
+     *
+     * @return singleton
+     */
+    public static OpenAiApiCache getInstance() {
+        return INSTANCE;
+    }
+
+    /**
+     * Resolves the maximum cache size from the system property, falling back to the
+     * environment variable and finally to {@value #DEFAULT_MAX_CACHE_SIZE}.
+     * Unparseable and non-positive values are rejected: with a limit of zero or less the
+     * size check in {@link #computeIfAbsent} would evict whenever the cache is non-empty.
+     *
+     * @return the maximum number of cached entries, always positive
+     */
+    private static int getCacheSize() {
+        final String value = System.getProperty("shenyu.plugin.ai.proxy.enhanced.cache.maxSize",
+                System.getenv("SHENYU_PLUGIN_AI_PROXY_ENHANCED_CACHE_MAXSIZE"));
+        if (Objects.isNull(value)) {
+            return DEFAULT_MAX_CACHE_SIZE;
+        }
+        try {
+            final int size = Integer.parseInt(value.trim());
+            if (size > 0) {
+                return size;
+            }
+            LOG.warn("[OpenAiApiCache] Non-positive cache size '{}', using default {}.", value, DEFAULT_MAX_CACHE_SIZE);
+        } catch (NumberFormatException e) {
+            LOG.warn("[OpenAiApiCache] Invalid cache size '{}', using default {}.", value, DEFAULT_MAX_CACHE_SIZE);
+        }
+        return DEFAULT_MAX_CACHE_SIZE;
+    }
+
+    /**
+     * Gets OpenAiApi or compute if absent.
+     * When the cache has grown past the limit, one caller at a time evicts entries first.
+     *
+     * @param key the cache key (callers in this plugin use selectorId|type_configHash)
+     * @param openAiApiSupplier the supplier to create OpenAiApi if absent
+     * @return the cached or newly created OpenAiApi
+     */
+    public OpenAiApi computeIfAbsent(final String key, final Supplier<OpenAiApi> openAiApiSupplier) {
+        final int currentSize = openAiApiMap.size();
+        if (currentSize > MAX_CACHE_SIZE) {
+            // The flag lets only one thread evict; the others skip straight to the lookup.
+            if (evictionInProgress.compareAndSet(false, true)) {
+                try {
+                    synchronized (openAiApiMap) {
+                        if (openAiApiMap.size() > MAX_CACHE_SIZE) {
+                            evictEntries();
+                        }
+                    }
+                } finally {
+                    evictionInProgress.set(false);
+                }
+            }
+        }
+        return openAiApiMap.computeIfAbsent(key, k -> openAiApiSupplier.get());
+    }
+
+    /**
+     * Evict arbitrary entries when cache size exceeds limit.
+     * Note: ConcurrentHashMap iteration order is undefined, so evicted entries are arbitrary,
+     * not guaranteed to be the oldest. Removes approximately 25% of entries to avoid
+     * thundering herd problem.
+     */
+    private void evictEntries() {
+        final int currentSize = openAiApiMap.size();
+        if (currentSize <= MAX_CACHE_SIZE) {
+            return;
+        }
+
+        final int evictCount = Math.max(10, currentSize / 4);
+        LOG.warn("[OpenAiApiCache] Cache size {} exceeded limit {}, evicting {} arbitrary entries",
+                currentSize, MAX_CACHE_SIZE, evictCount);
+
+        int removed = 0;
+        for (final String key : openAiApiMap.keySet()) {
+            if (removed >= evictCount) {
+                break;
+            }
+            openAiApiMap.remove(key);
+            removed++;
+        }
+
+        LOG.info("[OpenAiApiCache] Evicted {} entries, cache size now: {}", removed, openAiApiMap.size());
+    }
+
+    /**
+     * Removes all cached OpenAiApi instances associated with a selector ID
+     * (by prefix matching "selectorId|").
+     *
+     * @param selectorId the selector id
+     */
+    public void remove(final String selectorId) {
+        if (Objects.isNull(selectorId)) {
+            return;
+        }
+        final String prefix = selectorId + "|";
+        openAiApiMap.keySet().removeIf(k -> k.equals(selectorId) || k.startsWith(prefix));
+        LOG.info("[OpenAiApiCache] invalidate selectorId={} (by prefix)", selectorId);
+    }
+
+    /**
+     * Clear all cached OpenAiApi instances.
+     */
+    public void clearAll() {
+        openAiApiMap.clear();
+        LOG.info("[OpenAiApiCache] cleared all cached OpenAiApi instances");
+    }
+
+    /**
+     * Gets the current cache size (for testing / monitoring).
+     *
+     * @return the number of cached entries
+     */
+    public int size() {
+        return openAiApiMap.size();
+    }
+}
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java
index e4e3f470ed7c..40dfc31fa77f 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandler.java
@@ -23,7 +23,8 @@
import org.apache.shenyu.common.dto.convert.rule.AiProxyHandle;
import org.apache.shenyu.common.enums.PluginEnum;
import org.apache.shenyu.common.utils.GsonUtils;
-import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.ChatClientCache;
+import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.AiProxyApiKeyCache;
+import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.OpenAiApiCache;
import org.apache.shenyu.plugin.base.cache.CommonHandleCache;
import org.apache.shenyu.plugin.base.handler.PluginDataHandler;
import org.apache.shenyu.plugin.base.utils.CacheKeyUtils;
@@ -37,10 +38,7 @@ public class AiProxyPluginHandler implements PluginDataHandler {
private final CommonHandleCache selectorCachedHandle = new CommonHandleCache<>();
- private final ChatClientCache chatClientCache;
-
- public AiProxyPluginHandler(final ChatClientCache chatClientCache) {
- this.chatClientCache = chatClientCache;
+ public AiProxyPluginHandler() {
}
@Override
@@ -53,10 +51,13 @@ public void handlerPlugin(final PluginData pluginData) {
@Override
public void handlerSelector(final SelectorData selectorData) {
// Invalidate the cache first when the selector is updated.
- chatClientCache.remove(selectorData.getId());
+ OpenAiApiCache.getInstance().remove(selectorData.getId());
// Do NOT remove AiProxyApiKeyCache here. Admin will push updated AI_PROXY_API_KEY events
// with refreshed realApiKey after selector changes. Removing here introduces a window of misses.
- if (Objects.isNull(selectorData.getHandle())) {
+ if (Objects.isNull(selectorData.getHandle()) || selectorData.getHandle().isEmpty()) {
+ // Clear stale cached handle when selector handle is removed/cleared
+ selectorCachedHandle
+ .removeHandle(CacheKeyUtils.INST.getKey(selectorData.getId(), Constants.DEFAULT_RULE));
return;
}
AiProxyHandle aiProxyHandle = GsonUtils.getInstance().fromJson(selectorData.getHandle(), AiProxyHandle.class);
@@ -68,7 +69,8 @@ public void handlerSelector(final SelectorData selectorData) {
@Override
public void removeSelector(final SelectorData selectorData) {
// Invalidate the cache when the selector is removed.
- chatClientCache.remove(selectorData.getId());
+ OpenAiApiCache.getInstance().remove(selectorData.getId());
+ AiProxyApiKeyCache.getInstance().removeBySelectorId(selectorData.getId());
selectorCachedHandle
.removeHandle(CacheKeyUtils.INST.getKey(selectorData.getId(), Constants.DEFAULT_RULE));
}
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java
index 71ab0c4c49d1..534f81f0c7f7 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java
@@ -17,18 +17,24 @@
package org.apache.shenyu.plugin.ai.proxy.enhanced.service;
-import org.apache.shenyu.plugin.ai.common.strategy.SimpleModelFallbackStrategy;
+import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig;
+import org.apache.shenyu.plugin.ai.common.protocol.OpenAiProtocolAdapter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import org.springframework.ai.chat.client.ChatClient;
-import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.openai.api.OpenAiApi;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
import org.springframework.ai.retry.NonTransientAiException;
+import org.springframework.http.ResponseEntity;
+import org.springframework.web.reactive.function.client.WebClientResponseException;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
import reactor.util.retry.Retry;
import java.time.Duration;
+import java.util.Objects;
import java.util.Optional;
/**
@@ -39,78 +45,111 @@ public class AiProxyExecutorService {
private static final Logger LOG = LoggerFactory.getLogger(AiProxyExecutorService.class);
/**
- * Execute the AI call with retry and fallback.
+ * Execute a streaming AI call directly via {@link OpenAiApi}, bypassing Spring AI's
+ * {@code createRequest()} which loses fields like {@code reasoning_content}.
*
- * @param mainClient the main chat client
- * @param fallbackClientOpt the optional fallback chat client
- * @param requestBody the request body
- * @return a Mono containing the ChatResponse
+ * @param mainApi the main OpenAiApi
+ * @param fallbackCtxOpt the optional fallback context (api + config)
+ * @param request the ChatCompletionRequest with all fields preserved
+ * @param requestBody the original request body for rebuilding fallback request
+ * @param stream whether this is a streaming request
+ * @return a Flux of ChatCompletionChunk
*/
- public Mono execute(final ChatClient mainClient, final Optional fallbackClientOpt, final String requestBody) {
- final Mono mainCall = doChatCall(mainClient, requestBody);
+ public Flux executeDirectStream(final OpenAiApi mainApi,
+ final Optional fallbackCtxOpt, final ChatCompletionRequest request,
+ final String requestBody, final boolean stream) {
+ return mainApi.chatCompletionStream(request)
+ .doOnError(e -> UpstreamErrorLogger.logUpstreamError(LOG, e, "direct stream"))
+ .retryWhen(Retry.max(1)
+ .filter(AiProxyExecutorService::isRetryable)
+ .onRetryExhaustedThrow((retryBackoffSpec, retrySignal) -> {
+ LOG.warn("Direct stream retry exhausted. Triggering fallback.",
+ retrySignal.failure());
+ return new NonTransientAiException(
+ "Direct stream failed after 1 retry. Triggering fallback.",
+ retrySignal.failure());
+ }))
+ .onErrorResume(e -> handleDirectFallbackStream(e, fallbackCtxOpt, requestBody, stream));
+ }
- return mainCall
+ /**
+ * Execute a non-streaming AI call directly via {@link OpenAiApi}.
+ *
+ * @param mainApi the main OpenAiApi
+ * @param fallbackCtxOpt the optional fallback context (api + config)
+ * @param request the ChatCompletionRequest with all fields preserved
+ * @param requestBody the original request body for rebuilding fallback request
+ * @return a Mono of ResponseEntity containing ChatCompletion
+ */
+ public Mono> executeDirectCall(final OpenAiApi mainApi,
+ final Optional fallbackCtxOpt, final ChatCompletionRequest request,
+ final String requestBody) {
+ return Mono.fromCallable(() -> mainApi.chatCompletionEntity(request))
+ .subscribeOn(Schedulers.boundedElastic())
+ .doOnError(e -> UpstreamErrorLogger.logUpstreamError(LOG, e, "direct call"))
.retryWhen(Retry.backoff(3, Duration.ofSeconds(1))
- .filter(throwable -> !(throwable instanceof NonTransientAiException))
+ .filter(AiProxyExecutorService::isRetryable)
.onRetryExhaustedThrow((retryBackoffSpec, retrySignal) -> {
- LOG.warn("Retries exhausted for AI call after {} attempts.",
+ LOG.warn("Direct call retries exhausted after {} attempts. Triggering fallback.",
retrySignal.totalRetries(), retrySignal.failure());
- return new NonTransientAiException("Retries exhausted. Triggering fallback.",
+ return new NonTransientAiException("Direct call retries exhausted. Triggering fallback.",
retrySignal.failure());
}))
- .onErrorResume(NonTransientAiException.class,
- throwable -> handleFallback(throwable, fallbackClientOpt, requestBody));
- }
-
- protected Mono doChatCall(final ChatClient client, final String requestBody) {
- return Mono.fromCallable(() -> client.prompt().user(requestBody).call().chatResponse())
- .subscribeOn(Schedulers.boundedElastic());
+ .onErrorResume(e -> handleDirectFallbackCall(e, fallbackCtxOpt, requestBody));
}
- private Mono handleFallback(final Throwable throwable, final Optional fallbackClientOpt, final String requestBody) {
- LOG.warn("AI main call failed or retries exhausted, attempting to fallback...", throwable);
+ private Flux handleDirectFallbackStream(final Throwable throwable,
+ final Optional fallbackCtxOpt, final String requestBody, final boolean stream) {
+ LOG.warn("Main direct stream failed, attempting fallback...", throwable);
- if (fallbackClientOpt.isEmpty()) {
- return Mono.error(throwable);
+ if (fallbackCtxOpt.isEmpty()) {
+ return Flux.error(throwable);
}
- return SimpleModelFallbackStrategy.INSTANCE.fallback(fallbackClientOpt.get(), requestBody, throwable);
+ final FallbackContext ctx = fallbackCtxOpt.get();
+ LOG.info("Using fallback OpenAiApi for direct stream");
+ final ChatCompletionRequest fallbackRequest = OpenAiProtocolAdapter.toChatCompletionRequest(
+ requestBody, stream, ctx.config());
+ return ctx.api().chatCompletionStream(fallbackRequest);
}
- /**
- * Execute the AI call with retry and fallback.
- *
- * @param mainClient the main chat client
- * @param fallbackClientOpt the optional fallback chat client
- * @param requestBody the request body
- * @return a Flux containing the ChatResponse
- */
- public Flux executeStream(final ChatClient mainClient, final Optional fallbackClientOpt, final String requestBody) {
- final Flux mainStream = doChatStream(mainClient, requestBody);
+ private Mono> handleDirectFallbackCall(final Throwable throwable,
+ final Optional fallbackCtxOpt, final String requestBody) {
+ LOG.warn("Main direct call failed, attempting fallback...", throwable);
- return mainStream
- .retryWhen(Retry.max(1)
- .onRetryExhaustedThrow((retryBackoffSpec, retrySignal) -> {
- LOG.warn("Retrying stream once failed. Attempts: {}. Triggering fallback.",
- retrySignal.totalRetries(), retrySignal.failure());
- return new NonTransientAiException("Stream failed after 1 retry. Triggering fallback.", retrySignal.failure());
- }))
- .onErrorResume(NonTransientAiException.class,
- throwable -> handleFallbackStream(throwable, fallbackClientOpt, requestBody));
- }
+ if (fallbackCtxOpt.isEmpty()) {
+ return Mono.error(throwable);
+ }
- protected Flux doChatStream(final ChatClient client, final String requestBody) {
- return Flux.defer(() -> client.prompt().user(requestBody).stream().chatResponse())
+ final FallbackContext ctx = fallbackCtxOpt.get();
+ LOG.info("Using fallback OpenAiApi for direct call");
+ final ChatCompletionRequest fallbackRequest = OpenAiProtocolAdapter.toChatCompletionRequest(
+ requestBody, false, ctx.config());
+ return Mono.fromCallable(() -> ctx.api().chatCompletionEntity(fallbackRequest))
.subscribeOn(Schedulers.boundedElastic());
}
- private Flux handleFallbackStream(final Throwable throwable, final Optional fallbackClientOpt, final String requestBody) {
- LOG.warn("AI main stream failed or retries exhausted, attempting to fallback...", throwable);
-
- if (fallbackClientOpt.isEmpty()) {
- return Flux.error(throwable);
+ /**
+ * Determine if the error is retryable.
+ * Retries HTTP 429/5xx responses and any failure with no WebClientResponseException in its cause chain.
+ * Non-retryable: NonTransientAiException and every other upstream HTTP status (e.g. 400/401/403/404).
+ */
+ private static boolean isRetryable(final Throwable throwable) {
+ if (throwable instanceof NonTransientAiException) {
+ return false;
}
+ final WebClientResponseException webClientEx = UpstreamErrorLogger.findWebClientResponseException(throwable);
+ if (Objects.nonNull(webClientEx)) {
+ final int status = webClientEx.getStatusCode().value();
+ return status == 429 || status >= 500;
+ }
+ return true;
+ }
- return SimpleModelFallbackStrategy.INSTANCE.fallbackStream(fallbackClientOpt.get(), requestBody, throwable);
+ /**
+ * Container for fallback context: the OpenAiApi instance and the fallback config
+ * used to rebuild the request with fallback-specific model/temperature/maxTokens.
+ */
+ public record FallbackContext(OpenAiApi api, AiCommonConfig config) {
}
-}
\ No newline at end of file
+}
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLogger.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLogger.java
new file mode 100644
index 000000000000..06a08e0c378a
--- /dev/null
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLogger.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shenyu.plugin.ai.proxy.enhanced.service;
+
+import org.slf4j.Logger;
+import org.springframework.web.reactive.function.client.WebClientResponseException;
+
+import java.util.Objects;
+
+/**
+ * Shared utility for logging upstream AI service errors with WebClientResponseException details.
+ */
+public final class UpstreamErrorLogger {
+
+    private static final int MAX_BODY_LOG_LENGTH = 512;
+
+    private UpstreamErrorLogger() {
+    }
+
+    /**
+     * Log an upstream failure at error level. When a {@link WebClientResponseException} is
+     * present in the cause chain, its status code and (truncated) response body are included.
+     *
+     * @param log the logger to write to
+     * @param e the failure; nothing is logged when it is null
+     * @param mode a short label for the call path, placed in the log message
+     */
+    public static void logUpstreamError(final Logger log, final Throwable e, final String mode) {
+        if (Objects.isNull(e)) {
+            return;
+        }
+        final WebClientResponseException webClientEx = findWebClientResponseException(e);
+        if (Objects.nonNull(webClientEx)) {
+            log.error("[AiProxy] {} failed, status={}, upstreamBody={}",
+                    mode, webClientEx.getStatusCode(), truncateBody(webClientEx.getResponseBodyAsString()), e);
+        } else {
+            log.error("[AiProxy] {} failed", mode, e);
+        }
+    }
+
+    /**
+     * Cap the logged body at {@value #MAX_BODY_LOG_LENGTH} characters; when truncation
+     * happens the original length is appended so the reader knows how much was cut.
+     *
+     * @param body the upstream response body, may be null
+     * @return the body, its truncated form, or the literal "null"
+     */
+    private static String truncateBody(final String body) {
+        if (Objects.isNull(body)) {
+            return "null";
+        }
+        if (body.length() <= MAX_BODY_LOG_LENGTH) {
+            return body;
+        }
+        return body.substring(0, MAX_BODY_LOG_LENGTH) + "...(truncated, total " + body.length() + " chars)";
+    }
+
+    /**
+     * Find the {@link WebClientResponseException} in the exception chain.
+     * A cyclic cause chain is detected and ends the search instead of looping forever.
+     *
+     * @param e the root exception
+     * @return the WebClientResponseException if found, null otherwise
+     */
+    public static WebClientResponseException findWebClientResponseException(final Throwable e) {
+        Throwable current = e;
+        // Trails 'current' at half speed; if 'current' ever lands on it the chain is a loop.
+        Throwable trailing = e;
+        boolean advanceTrailing = false;
+        while (Objects.nonNull(current)) {
+            if (current instanceof WebClientResponseException ex) {
+                return ex;
+            }
+            current = current.getCause();
+            if (current == trailing) {
+                return null;
+            }
+            if (advanceTrailing) {
+                trailing = trailing.getCause();
+            }
+            advanceTrailing = !advanceTrailing;
+        }
+        return null;
+    }
+}
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriber.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriber.java
index 4f63f5b53bb8..bde74f48668e 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriber.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriber.java
@@ -19,7 +19,7 @@
import org.apache.shenyu.common.dto.ProxyApiKeyData;
import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.AiProxyApiKeyCache;
-import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.ChatClientCache;
+import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.OpenAiApiCache;
import org.apache.shenyu.sync.data.api.AiProxyApiKeyDataSubscriber;
import java.util.Objects;
@@ -29,12 +29,6 @@
*/
public final class CommonAiProxyApiKeyDataSubscriber implements AiProxyApiKeyDataSubscriber {
- private final ChatClientCache chatClientCache;
-
- public CommonAiProxyApiKeyDataSubscriber(final ChatClientCache chatClientCache) {
- this.chatClientCache = chatClientCache;
- }
-
@Override
public void onSubscribe(final ProxyApiKeyData data) {
if (Objects.isNull(data) || Objects.isNull(data.getProxyApiKey())) {
@@ -54,6 +48,6 @@ public void unSubscribe(final ProxyApiKeyData data) {
@Override
public void refresh() {
AiProxyApiKeyCache.getInstance().refresh();
- chatClientCache.clearAll();
+ OpenAiApiCache.getInstance().clearAll();
}
}
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java
index 30e84abc0809..e61f43e1689d 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java
@@ -23,10 +23,7 @@
import org.apache.shenyu.common.dto.convert.rule.AiProxyHandle;
import org.apache.shenyu.common.enums.AiModelProviderEnum;
import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig;
-import org.apache.shenyu.plugin.ai.common.spring.ai.AiModelFactory;
-import org.apache.shenyu.plugin.ai.common.spring.ai.registry.AiModelFactoryRegistry;
import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.AiProxyApiKeyCache;
-import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.ChatClientCache;
import org.apache.shenyu.plugin.ai.proxy.enhanced.handler.AiProxyPluginHandler;
import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyConfigService;
import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyExecutorService;
@@ -43,21 +40,22 @@
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness;
-import org.springframework.ai.chat.client.ChatClient;
-import org.springframework.ai.chat.model.ChatModel;
-import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.openai.api.OpenAiApi;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
import org.springframework.context.ApplicationContext;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
+import org.springframework.http.ResponseEntity;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.web.server.ServerWebExchange;
+import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import java.util.Optional;
import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@@ -75,27 +73,12 @@ public class AiProxyPluginTest {
private static final String REQUEST_BODY = "{\"messages\":[{\"role\":\"user\",\"content\":\"Hello\"}]}";
- @Mock
- private AiModelFactoryRegistry registry;
-
@Mock
private AiProxyConfigService configService;
@Mock
private AiProxyExecutorService executorService;
- @Mock
- private ChatClientCache chatClientCache;
-
- @Mock
- private AiModelFactory modelFactory;
-
- @Mock
- private ChatModel chatModel;
-
- @Mock
- private ChatClient chatClient;
-
private AiProxyPluginHandler aiProxyPluginHandler;
private AiProxyPlugin plugin;
@@ -110,8 +93,8 @@ public class AiProxyPluginTest {
@BeforeEach
public void setUp() {
- aiProxyPluginHandler = new AiProxyPluginHandler(chatClientCache);
- plugin = new AiProxyPlugin(registry, configService, executorService, chatClientCache, aiProxyPluginHandler);
+ aiProxyPluginHandler = new AiProxyPluginHandler();
+ plugin = new AiProxyPlugin(configService, executorService, aiProxyPluginHandler);
selector = new SelectorData();
selector.setId(SELECTOR_ID);
@@ -129,14 +112,6 @@ public void setUp() {
when(applicationContext.getBean(ShenyuResult.class)).thenReturn(resultMock);
SpringBeanUtils.getInstance().setApplicationContext(applicationContext);
- // Common mock behavior for model creation and cache
- when(registry.getFactory(any(AiModelProviderEnum.class))).thenReturn(modelFactory);
- when(modelFactory.createAiModel(any(AiCommonConfig.class))).thenReturn(chatModel);
- when(chatClientCache.computeIfAbsent(anyString(), any())).thenAnswer(invocation -> {
- // Always return the mock ChatClient to avoid any UnsupportedOperationException
- return chatClient;
- });
-
// mock static
apiKeyCacheMockedStatic = mockStatic(AiProxyApiKeyCache.class);
}
@@ -150,13 +125,14 @@ public void tearDown() {
private void setupSuccessMocks(final AiProxyHandle handle, final AiCommonConfig primaryConfig, final Optional fallbackConfig) {
aiProxyPluginHandler.getSelectorCachedHandle().cachedHandle(CacheKeyUtils.INST.getKey(SELECTOR_ID, Constants.DEFAULT_RULE), handle);
- final ChatResponse chatResponse = mock(ChatResponse.class);
+ final ChatCompletion chatCompletion = mock(ChatCompletion.class);
+ final ResponseEntity responseEntity = ResponseEntity.ok(chatCompletion);
when(configService.resolvePrimaryConfig(handle)).thenReturn(primaryConfig);
when(configService.resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY)).thenReturn(fallbackConfig);
when(configService.resolveAdminFallbackConfig(primaryConfig, handle)).thenReturn(fallbackConfig);
- when(configService.extractPrompt(anyString())).thenAnswer(invocation -> invocation.getArgument(0));
- when(executorService.execute(any(), any(), any())).thenReturn(Mono.just(chatResponse));
+ when(executorService.executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class))).thenReturn(Mono.just(responseEntity));
+ when(executorService.executeDirectStream(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class), any(Boolean.class))).thenReturn(Flux.empty());
}
@Test
@@ -164,6 +140,8 @@ public void testExecuteSuccess() {
final AiProxyHandle handle = new AiProxyHandle();
final AiCommonConfig primaryConfig = new AiCommonConfig();
primaryConfig.setProvider(AiModelProviderEnum.OPEN_AI.getName());
+ primaryConfig.setBaseUrl("https://api.openai.com");
+ primaryConfig.setApiKey("test-key");
setupSuccessMocks(handle, primaryConfig, Optional.empty());
StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule))
@@ -173,7 +151,7 @@ public void testExecuteSuccess() {
verify(configService).resolvePrimaryConfig(handle);
verify(configService).resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY);
verify(configService).resolveAdminFallbackConfig(primaryConfig, handle);
- verify(executorService).execute(any(), any(), any());
+ verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class));
}
@Test
@@ -181,8 +159,12 @@ public void testExecuteWithDynamicFallback() {
final AiProxyHandle handle = new AiProxyHandle();
final AiCommonConfig primaryConfig = new AiCommonConfig();
primaryConfig.setProvider(AiModelProviderEnum.OPEN_AI.getName());
+ primaryConfig.setBaseUrl("https://api.openai.com");
+ primaryConfig.setApiKey("test-key");
final AiCommonConfig fallbackConfig = new AiCommonConfig();
fallbackConfig.setProvider(AiModelProviderEnum.DEEP_SEEK.getName());
+ fallbackConfig.setBaseUrl("https://api.deepseek.com");
+ fallbackConfig.setApiKey("fallback-key");
setupSuccessMocks(handle, primaryConfig, Optional.of(fallbackConfig));
when(configService.resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY)).thenReturn(Optional.of(fallbackConfig));
@@ -190,7 +172,7 @@ public void testExecuteWithDynamicFallback() {
StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule))
.verifyComplete();
- verify(executorService).execute(any(ChatClient.class), any(Optional.class), any());
+ verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class));
}
@Test
@@ -199,6 +181,7 @@ public void testExecuteWithValidProxyApiKey() {
handle.setProxyEnabled("true");
final AiCommonConfig primaryConfig = new AiCommonConfig();
primaryConfig.setProvider(AiModelProviderEnum.OPEN_AI.getName());
+ primaryConfig.setBaseUrl("https://api.openai.com");
primaryConfig.setApiKey("original-key");
// setup request with proxy key
@@ -225,6 +208,8 @@ public void testExecuteWithInvalidProxyApiKey() {
handle.setProxyEnabled("true");
final AiCommonConfig primaryConfig = new AiCommonConfig();
primaryConfig.setProvider(AiModelProviderEnum.OPEN_AI.getName());
+ primaryConfig.setBaseUrl("https://api.openai.com");
+ primaryConfig.setApiKey("test-key");
// setup request with proxy key
exchange = MockServerWebExchange.from(MockServerHttpRequest.post("/test").header(Constants.X_API_KEY, "proxy-key-invalid").body(REQUEST_BODY));
@@ -252,8 +237,12 @@ public void testExecuteWithAdminFallback() {
final AiProxyHandle handle = new AiProxyHandle();
final AiCommonConfig primaryConfig = new AiCommonConfig();
primaryConfig.setProvider(AiModelProviderEnum.OPEN_AI.getName());
+ primaryConfig.setBaseUrl("https://api.openai.com");
+ primaryConfig.setApiKey("test-key");
final AiCommonConfig fallbackConfig = new AiCommonConfig();
fallbackConfig.setProvider(AiModelProviderEnum.DEEP_SEEK.getName());
+ fallbackConfig.setBaseUrl("https://api.deepseek.com");
+ fallbackConfig.setApiKey("fallback-key");
setupSuccessMocks(handle, primaryConfig, Optional.empty());
when(configService.resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY)).thenReturn(Optional.empty());
@@ -262,61 +251,57 @@ public void testExecuteWithAdminFallback() {
StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule))
.verifyComplete();
- verify(executorService).execute(any(ChatClient.class), any(Optional.class), any());
+ verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class));
}
@Test
- public void testCacheIsUsedForAdminFallbackClient() {
+ public void testAdminFallbackExecution() {
final AiProxyHandle handle = new AiProxyHandle();
final AiCommonConfig primaryConfig = new AiCommonConfig();
primaryConfig.setProvider(AiModelProviderEnum.OPEN_AI.getName());
+ primaryConfig.setBaseUrl("https://api.openai.com");
+ primaryConfig.setApiKey("test-key");
final AiCommonConfig fallbackConfig = new AiCommonConfig();
fallbackConfig.setProvider(AiModelProviderEnum.DEEP_SEEK.getName());
+ fallbackConfig.setBaseUrl("https://api.deepseek.com");
+ fallbackConfig.setApiKey("fallback-key");
// Cache the handle for the test
aiProxyPluginHandler.getSelectorCachedHandle().cachedHandle(CacheKeyUtils.INST.getKey(SELECTOR_ID, Constants.DEFAULT_RULE), handle);
- final ChatResponse chatResponse = mock(ChatResponse.class);
+ final ChatCompletion chatCompletion = mock(ChatCompletion.class);
+ final ResponseEntity responseEntity = ResponseEntity.ok(chatCompletion);
// Setup all necessary mocks
when(configService.resolvePrimaryConfig(handle)).thenReturn(primaryConfig);
when(configService.resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY)).thenReturn(Optional.empty());
when(configService.resolveAdminFallbackConfig(primaryConfig, handle)).thenReturn(Optional.of(fallbackConfig));
- when(configService.extractPrompt(anyString())).thenAnswer(invocation -> invocation.getArgument(0));
- when(executorService.execute(any(), any(), any())).thenReturn(Mono.just(chatResponse));
+ when(executorService.executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class))).thenReturn(Mono.just(responseEntity));
- // Execute the test - focus on successful execution rather than cache verification
+ // Execute the test - verify admin fallback path is exercised
StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule))
.verifyComplete();
-
+
// Verify that the configuration methods were called correctly
verify(configService).resolvePrimaryConfig(handle);
verify(configService).resolveAdminFallbackConfig(primaryConfig, handle);
- verify(executorService).execute(any(), any(), any());
+ verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class));
}
@Test
- public void testCreateChatModelThrowsException() {
+ public void testCreateOpenAiApiThrowsException() {
final AiProxyHandle handle = new AiProxyHandle();
final AiCommonConfig primaryConfig = new AiCommonConfig();
- primaryConfig.setProvider("InvalidProvider");
-
- // Cache the handle for the test
+ primaryConfig.setBaseUrl(null);
+
aiProxyPluginHandler.getSelectorCachedHandle().cachedHandle(CacheKeyUtils.INST.getKey(SELECTOR_ID, Constants.DEFAULT_RULE), handle);
-
- // Mock config service to return the invalid config
+
when(configService.resolvePrimaryConfig(handle)).thenReturn(primaryConfig);
when(configService.resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY)).thenReturn(Optional.empty());
when(configService.resolveAdminFallbackConfig(primaryConfig, handle)).thenReturn(Optional.empty());
- when(configService.extractPrompt(anyString())).thenAnswer(invocation -> invocation.getArgument(0));
-
- // Mock registry to return null factory for invalid provider - this should cause IllegalArgumentException
- when(registry.getFactory(any())).thenReturn(null);
-
- // Mock executorService to return a proper Mono to avoid NullPointerException
- when(executorService.execute(any(), any(), any())).thenReturn(Mono.error(new IllegalArgumentException("AI model factory not found")));
StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule))
- .expectError(IllegalArgumentException.class)
+ .expectErrorMatches(e -> e instanceof IllegalArgumentException
+ && e.getMessage().contains("baseUrl must not be empty"))
.verify();
}
@@ -325,10 +310,12 @@ public void testExecutorServiceError() {
final AiProxyHandle handle = new AiProxyHandle();
final AiCommonConfig primaryConfig = new AiCommonConfig();
primaryConfig.setProvider(AiModelProviderEnum.OPEN_AI.getName());
+ primaryConfig.setBaseUrl("https://api.openai.com");
+ primaryConfig.setApiKey("test-key");
final RuntimeException exception = new RuntimeException("AI execution failed");
setupSuccessMocks(handle, primaryConfig, Optional.empty());
- when(executorService.execute(any(), any(), any())).thenReturn(Mono.error(exception));
+ when(executorService.executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class))).thenReturn(Mono.error(exception));
StepVerifier.create(plugin.doExecute(exchange, mock(ShenyuPluginChain.class), selector, rule))
.expectErrorMatches(exception::equals)
@@ -337,6 +324,6 @@ public void testExecutorServiceError() {
verify(configService).resolvePrimaryConfig(handle);
verify(configService).resolveDynamicFallbackConfig(primaryConfig, REQUEST_BODY);
verify(configService).resolveAdminFallbackConfig(primaryConfig, handle);
- verify(executorService).execute(any(), any(), any());
+ verify(executorService).executeDirectCall(any(OpenAiApi.class), any(Optional.class), any(ChatCompletionRequest.class), any(String.class));
}
}
\ No newline at end of file
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCacheTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCacheTest.java
new file mode 100644
index 000000000000..11bed2897ddb
--- /dev/null
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/OpenAiApiCacheTest.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shenyu.plugin.ai.proxy.enhanced.cache;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.springframework.ai.openai.api.OpenAiApi;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotSame;
+import static org.junit.jupiter.api.Assertions.assertSame;
+
+/**
+ * Unit tests for {@link OpenAiApiCache}.
+ */
+class OpenAiApiCacheTest {
+
+    private OpenAiApiCache cache;
+
+    @BeforeEach
+    void setUp() {
+        cache = new OpenAiApiCache();
+        // Make sure every test starts from an empty cache.
+        cache.clearAll();
+    }
+
+    /**
+     * A second lookup with the same key must return the first instance
+     * instead of invoking the new supplier's result.
+     */
+    @Test
+    void testComputeIfAbsentReusesExisting() {
+        final OpenAiApi api = createTestApi("https://api.openai.com", "key1");
+        final OpenAiApi cached = cache.computeIfAbsent("key1", () -> api);
+        final OpenAiApi reused = cache.computeIfAbsent("key1", () -> createTestApi("https://api.openai.com", "key1-other"));
+
+        assertSame(cached, reused);
+        assertEquals(1, cache.size());
+    }
+
+    /**
+     * Distinct keys must yield distinct cached instances.
+     */
+    @Test
+    void testComputeIfAbsentCreatesNewForDifferentKey() {
+        final OpenAiApi api1 = cache.computeIfAbsent("key1", () -> createTestApi("https://api.openai.com", "key1"));
+        final OpenAiApi api2 = cache.computeIfAbsent("key2", () -> createTestApi("https://api.deepseek.com", "key2"));
+
+        // Different keys produce different instances, each stored under its own entry.
+        assertNotSame(api1, api2);
+        assertEquals(2, cache.size());
+    }
+
+    /**
+     * Removing by selector id drops the entries stored under that selector
+     * (keys of the form {@code selectorId|suffix} here) and leaves other selectors untouched.
+     */
+    @Test
+    void testRemoveBySelectorId() {
+        cache.computeIfAbsent("sel1|main_123", () -> createTestApi("https://a.com", "k1"));
+        cache.computeIfAbsent("sel1|adminFallback_456", () -> createTestApi("https://b.com", "k2"));
+        final OpenAiApi survivor = cache.computeIfAbsent("sel2|main_789", () -> createTestApi("https://c.com", "k3"));
+
+        assertEquals(3, cache.size());
+
+        cache.remove("sel1");
+
+        assertEquals(1, cache.size());
+        // The remaining entry must be the one belonging to the other selector:
+        // looking it up again returns the original instance and adds nothing.
+        assertSame(survivor, cache.computeIfAbsent("sel2|main_789", () -> createTestApi("https://c.com", "k3-other")));
+        assertEquals(1, cache.size());
+    }
+
+    /**
+     * A null selector id is tolerated and removes nothing.
+     */
+    @Test
+    void testRemoveWithNullSelectorIdDoesNotThrow() {
+        cache.computeIfAbsent("key1", () -> createTestApi("https://a.com", "k1"));
+        cache.remove(null);
+        assertEquals(1, cache.size());
+    }
+
+    /**
+     * clearAll empties the cache regardless of how many entries it holds.
+     */
+    @Test
+    void testClearAll() {
+        cache.computeIfAbsent("key1", () -> createTestApi("https://a.com", "k1"));
+        cache.computeIfAbsent("key2", () -> createTestApi("https://b.com", "k2"));
+
+        assertEquals(2, cache.size());
+
+        cache.clearAll();
+
+        assertEquals(0, cache.size());
+    }
+
+    /**
+     * Build a standalone OpenAiApi client for use as a cache value.
+     *
+     * @param baseUrl the upstream base url
+     * @param apiKey the api key
+     * @return a newly built OpenAiApi
+     */
+    private OpenAiApi createTestApi(final String baseUrl, final String apiKey) {
+        return OpenAiApi.builder().baseUrl(baseUrl).apiKey(apiKey).build();
+    }
+}
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandlerTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandlerTest.java
new file mode 100644
index 000000000000..5aa2a7a5edf2
--- /dev/null
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/handler/AiProxyPluginHandlerTest.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shenyu.plugin.ai.proxy.enhanced.handler;
+
+import org.apache.shenyu.common.constant.Constants;
+import org.apache.shenyu.common.dto.SelectorData;
+import org.apache.shenyu.common.dto.convert.rule.AiProxyHandle;
+import org.apache.shenyu.common.enums.PluginEnum;
+import org.apache.shenyu.plugin.base.utils.CacheKeyUtils;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNull;
+
+/**
+ * Unit tests for {@link AiProxyPluginHandler}.
+ */
+class AiProxyPluginHandlerTest {
+
+    private AiProxyPluginHandler handler;
+
+    @BeforeEach
+    void setUp() {
+        handler = new AiProxyPluginHandler();
+    }
+
+    @Test
+    void testPluginNamed() {
+        assertEquals(PluginEnum.AI_PROXY.getName(), handler.pluginNamed());
+    }
+
+    @Test
+    void testHandlerSelectorCachesHandle() {
+        handler.handlerSelector(selectorOf("sel-1",
+                "{\"provider\":\"open_ai\",\"baseUrl\":\"https://api.openai.com\",\"apiKey\":\"sk-test\",\"model\":\"gpt-4\"}"));
+
+        AiProxyHandle stored = cachedHandleFor("sel-1");
+        assertNotNull(stored);
+        assertEquals("open_ai", stored.getProvider());
+        assertEquals("https://api.openai.com", stored.getBaseUrl());
+        assertEquals("sk-test", stored.getApiKey());
+        assertEquals("gpt-4", stored.getModel());
+    }
+
+    @Test
+    void testHandlerSelectorWithNullHandle() {
+        handler.handlerSelector(selectorOf("sel-2", null));
+
+        assertNull(cachedHandleFor("sel-2"));
+    }
+
+    @Test
+    void testHandlerSelectorWithEmptyHandle() {
+        handler.handlerSelector(selectorOf("sel-3", ""));
+
+        assertNull(cachedHandleFor("sel-3"));
+    }
+
+    @Test
+    void testHandlerSelectorWithFallbackNormalize() {
+        handler.handlerSelector(selectorOf("sel-4",
+                "{\"provider\":\"open_ai\",\"fallbackEnabled\":\"true\",\"fallbackModel\":\"fallback-model\"}"));
+
+        AiProxyHandle stored = cachedHandleFor("sel-4");
+        assertNotNull(stored);
+        assertNotNull(stored.getFallbackConfig());
+        assertEquals("fallback-model", stored.getFallbackConfig().getModel());
+    }
+
+    @Test
+    void testRemoveSelector() {
+        SelectorData target = selectorOf("sel-5", "{\"provider\":\"open_ai\",\"model\":\"gpt-4\"}");
+
+        handler.handlerSelector(target);
+        assertNotNull(cachedHandleFor("sel-5"));
+
+        handler.removeSelector(target);
+        assertNull(cachedHandleFor("sel-5"));
+    }
+
+    @Test
+    void testHandlerPluginDoesNothing() {
+        // Passing null plugin data must simply return without throwing.
+        handler.handlerPlugin(null);
+    }
+
+    /**
+     * Build a selector fixture carrying the given id and raw handle string.
+     *
+     * @param id the selector id
+     * @param handleJson the handle content, may be null or empty
+     * @return the populated selector
+     */
+    private static SelectorData selectorOf(final String id, final String handleJson) {
+        SelectorData fixture = new SelectorData();
+        fixture.setId(id);
+        fixture.setHandle(handleJson);
+        return fixture;
+    }
+
+    /**
+     * Look up the handle the handler cached for a selector under the default rule key.
+     *
+     * @param selectorId the selector id
+     * @return the cached handle, or null when nothing is cached
+     */
+    private AiProxyHandle cachedHandleFor(final String selectorId) {
+        String cacheKey = CacheKeyUtils.INST.getKey(selectorId, Constants.DEFAULT_RULE);
+        return handler.getSelectorCachedHandle().obtainHandle(cacheKey);
+    }
+}
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorServiceTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorServiceTest.java
index 7bd2bbbdac0d..38597d0dbedf 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorServiceTest.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorServiceTest.java
@@ -17,24 +17,24 @@
package org.apache.shenyu.plugin.ai.proxy.enhanced.service;
+import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
-import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
-import org.springframework.ai.chat.client.ChatClient;
-import org.springframework.ai.chat.model.ChatModel;
-import org.springframework.ai.chat.model.ChatResponse;
-import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.openai.api.OpenAiApi;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
import org.springframework.ai.retry.NonTransientAiException;
-import org.springframework.web.client.RestClientException;
+import org.springframework.http.ResponseEntity;
+import reactor.core.publisher.Flux;
import reactor.test.StepVerifier;
import java.util.Optional;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@@ -42,102 +42,106 @@
@ExtendWith(MockitoExtension.class)
public class AiProxyExecutorServiceTest {
- @Mock
- private ChatModel mainChatModel;
-
- @Mock
- private ChatModel fallbackChatModel;
-
- private ChatClient mainClient;
-
- private ChatClient fallbackClient;
+ private static final String REQUEST_BODY = "{\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}]}";
private AiProxyExecutorService executorService;
@BeforeEach
void setUp() {
- executorService = spy(new AiProxyExecutorService());
- mainClient = ChatClient.create(mainChatModel);
- fallbackClient = ChatClient.create(fallbackChatModel);
+ executorService = new AiProxyExecutorService();
}
@Test
- void testExecuteSuccessOnFirstAttempt() {
- final ChatResponse successResponse = mock(ChatResponse.class);
- when(mainChatModel.call(any(Prompt.class))).thenReturn(successResponse);
-
- StepVerifier.create(executorService.execute(mainClient, Optional.empty(), "request"))
- .expectNext(successResponse)
+ void testExecuteDirectStreamSuccess() {
+ final OpenAiApi mainApi = mock(OpenAiApi.class);
+ final ChatCompletionChunk chunk = mock(ChatCompletionChunk.class);
+ final ChatCompletionRequest request = mock(ChatCompletionRequest.class);
+ when(mainApi.chatCompletionStream(request)).thenReturn(Flux.just(chunk));
+
+ StepVerifier.create(executorService.executeDirectStream(mainApi, Optional.empty(), request, REQUEST_BODY, true))
+ .expectNext(chunk)
.verifyComplete();
- verify(mainChatModel, times(1)).call(any(Prompt.class));
+ verify(mainApi, times(1)).chatCompletionStream(request);
}
@Test
- void testExecuteRetryOnRestClientExceptionAndSucceed() {
- final ChatResponse successResponse = mock(ChatResponse.class);
- when(mainChatModel.call(any(Prompt.class)))
- .thenThrow(new RestClientException("transient error"))
- .thenReturn(successResponse);
-
- StepVerifier.create(executorService.execute(mainClient, Optional.empty(), "request"))
- .expectNext(successResponse)
- .verifyComplete();
+ void testExecuteDirectStreamErrorWithNoFallback() {
+ final OpenAiApi mainApi = mock(OpenAiApi.class);
+ final ChatCompletionRequest request = mock(ChatCompletionRequest.class);
+ when(mainApi.chatCompletionStream(request)).thenAnswer(inv -> Flux.error(new RuntimeException("upstream error")));
- verify(mainChatModel, times(2)).call(any(Prompt.class));
+ StepVerifier.create(executorService.executeDirectStream(mainApi, Optional.empty(), request, REQUEST_BODY, true))
+ .expectError(NonTransientAiException.class)
+ .verify();
}
@Test
- void testExecuteRetryExhaustedThenFallback() {
- final ChatResponse fallbackResponse = mock(ChatResponse.class);
- when(mainChatModel.call(any(Prompt.class))).thenThrow(new RestClientException("transient error"));
- when(fallbackChatModel.call(any(Prompt.class))).thenReturn(fallbackResponse);
+ void testExecuteDirectStreamFallbackSuccess() {
+ final OpenAiApi mainApi = mock(OpenAiApi.class);
+ final OpenAiApi fallbackApi = mock(OpenAiApi.class);
+ final ChatCompletionRequest request = mock(ChatCompletionRequest.class);
+ final ChatCompletionChunk fallbackChunk = mock(ChatCompletionChunk.class);
- StepVerifier.create(executorService.execute(mainClient, Optional.of(fallbackClient), "request"))
- .expectNext(fallbackResponse)
+ when(mainApi.chatCompletionStream(request)).thenAnswer(inv -> Flux.error(new RuntimeException("upstream error")));
+ when(fallbackApi.chatCompletionStream(any(ChatCompletionRequest.class))).thenReturn(Flux.just(fallbackChunk));
+
+ final AiCommonConfig fallbackConfig = new AiCommonConfig();
+ fallbackConfig.setModel("fallback-model");
+ final AiProxyExecutorService.FallbackContext ctx = new AiProxyExecutorService.FallbackContext(fallbackApi, fallbackConfig);
+
+ StepVerifier.create(executorService.executeDirectStream(mainApi, Optional.of(ctx), request, REQUEST_BODY, true))
+ .expectNext(fallbackChunk)
.verifyComplete();
- verify(mainChatModel, times(4)).call(any(Prompt.class));
- verify(fallbackChatModel, times(1)).call(any(Prompt.class));
+ verify(fallbackApi, times(1)).chatCompletionStream(any(ChatCompletionRequest.class));
}
@Test
- void testExecuteNonTransientExceptionTriggersFallbackDirectly() {
- final ChatResponse fallbackResponse = mock(ChatResponse.class);
- when(mainChatModel.call(any(Prompt.class))).thenThrow(new NonTransientAiException("non-transient error"));
- when(fallbackChatModel.call(any(Prompt.class))).thenReturn(fallbackResponse);
-
- StepVerifier.create(executorService.execute(mainClient, Optional.of(fallbackClient), "request"))
- .expectNext(fallbackResponse)
+ void testExecuteDirectCallSuccess() {
+ final OpenAiApi mainApi = mock(OpenAiApi.class);
+ final ChatCompletionRequest request = mock(ChatCompletionRequest.class);
+ final ChatCompletion completion = mock(ChatCompletion.class);
+ final ResponseEntity responseEntity = ResponseEntity.ok(completion);
+ when(mainApi.chatCompletionEntity(request)).thenReturn(responseEntity);
+
+ StepVerifier.create(executorService.executeDirectCall(mainApi, Optional.empty(), request, REQUEST_BODY))
+ .expectNext(responseEntity)
.verifyComplete();
- verify(mainChatModel, times(1)).call(any(Prompt.class));
- verify(fallbackChatModel, times(1)).call(any(Prompt.class));
+ verify(mainApi, times(1)).chatCompletionEntity(request);
}
@Test
- void testExecuteFallbackFails() {
- final RestClientException fallbackException = new RestClientException("fallback failed");
- when(mainChatModel.call(any(Prompt.class))).thenThrow(new NonTransientAiException("non-transient error"));
- when(fallbackChatModel.call(any(Prompt.class))).thenThrow(fallbackException);
+ void testExecuteDirectCallErrorWithNoFallback() {
+ final OpenAiApi mainApi = mock(OpenAiApi.class);
+ final ChatCompletionRequest request = mock(ChatCompletionRequest.class);
+ when(mainApi.chatCompletionEntity(request)).thenThrow(new RuntimeException("upstream error"));
- StepVerifier.create(executorService.execute(mainClient, Optional.of(fallbackClient), "request"))
- .expectErrorMatches(e -> e == fallbackException)
+ StepVerifier.create(executorService.executeDirectCall(mainApi, Optional.empty(), request, REQUEST_BODY))
+ .expectError(NonTransientAiException.class)
.verify();
-
- verify(mainChatModel, times(1)).call(any(Prompt.class));
- verify(fallbackChatModel, times(1)).call(any(Prompt.class));
}
@Test
- void testExecuteNoFallbackProvided() {
- final NonTransientAiException mainException = new NonTransientAiException("non-transient error");
- when(mainChatModel.call(any(Prompt.class))).thenThrow(mainException);
+ void testExecuteDirectCallFallbackSuccess() {
+ final OpenAiApi mainApi = mock(OpenAiApi.class);
+ final OpenAiApi fallbackApi = mock(OpenAiApi.class);
+ final ChatCompletionRequest request = mock(ChatCompletionRequest.class);
+ final ChatCompletion fallbackCompletion = mock(ChatCompletion.class);
+ final ResponseEntity fallbackResponse = ResponseEntity.ok(fallbackCompletion);
- StepVerifier.create(executorService.execute(mainClient, Optional.empty(), "request"))
- .expectErrorMatches(e -> e == mainException)
- .verify();
+ when(mainApi.chatCompletionEntity(request)).thenThrow(new RuntimeException("upstream error"));
+ when(fallbackApi.chatCompletionEntity(any(ChatCompletionRequest.class))).thenReturn(fallbackResponse);
+
+ final AiCommonConfig fallbackConfig = new AiCommonConfig();
+ fallbackConfig.setModel("fallback-model");
+ final AiProxyExecutorService.FallbackContext ctx = new AiProxyExecutorService.FallbackContext(fallbackApi, fallbackConfig);
+
+ StepVerifier.create(executorService.executeDirectCall(mainApi, Optional.of(ctx), request, REQUEST_BODY))
+ .expectNext(fallbackResponse)
+ .verifyComplete();
- verify(mainChatModel, times(1)).call(any(Prompt.class));
+ verify(fallbackApi, times(1)).chatCompletionEntity(any(ChatCompletionRequest.class));
}
}
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLoggerTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLoggerTest.java
new file mode 100644
index 000000000000..85193687d1df
--- /dev/null
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/UpstreamErrorLoggerTest.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shenyu.plugin.ai.proxy.enhanced.service;
+
+import org.junit.jupiter.api.Test;
+import org.slf4j.Logger;
+import org.springframework.http.HttpStatus;
+import org.springframework.web.reactive.function.client.WebClientResponseException;
+
+import java.nio.charset.StandardCharsets;
+
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+
+class UpstreamErrorLoggerTest {
+
+ @Test
+ void testLogWithWebClientResponseException() {
+ Logger log = mock(Logger.class);
+ WebClientResponseException ex = WebClientResponseException.create(
+ 429, "Too Many Requests", null, "rate limited".getBytes(StandardCharsets.UTF_8), StandardCharsets.UTF_8);
+
+ UpstreamErrorLogger.logUpstreamError(log, ex, "stream");
+
+ verify(log).error("[AiProxy] {} failed, status={}, upstreamBody={}",
+ "stream", HttpStatus.TOO_MANY_REQUESTS, "rate limited", ex);
+ }
+
+ @Test
+ void testLogWithWrappedWebClientResponseException() {
+ Logger log = mock(Logger.class);
+ WebClientResponseException inner = WebClientResponseException.create(
+ 500, "Internal Server Error", null, "server error".getBytes(StandardCharsets.UTF_8), StandardCharsets.UTF_8);
+ RuntimeException wrapped = new RuntimeException("call failed", inner);
+
+ UpstreamErrorLogger.logUpstreamError(log, wrapped, "non-stream");
+
+ verify(log).error("[AiProxy] {} failed, status={}, upstreamBody={}",
+ "non-stream", HttpStatus.INTERNAL_SERVER_ERROR, "server error", wrapped);
+ }
+
+ @Test
+ void testLogWithGenericException() {
+ Logger log = mock(Logger.class);
+ RuntimeException ex = new RuntimeException("generic error");
+
+ UpstreamErrorLogger.logUpstreamError(log, ex, "stream");
+
+ verify(log).error("[AiProxy] {} failed", "stream", ex);
+ }
+
+ @Test
+ void testLogWithNullExceptionDoesNotThrow() {
+ Logger log = mock(Logger.class);
+ assertDoesNotThrow(() -> UpstreamErrorLogger.logUpstreamError(log, null, "stream"));
+ }
+
+ @Test
+ void testLogWithLongBodyIsTruncated() {
+ Logger log = mock(Logger.class);
+ final String longBody = "x".repeat(1024);
+ WebClientResponseException ex = WebClientResponseException.create(
+ 500, "Internal Server Error", null, longBody.getBytes(StandardCharsets.UTF_8), StandardCharsets.UTF_8);
+
+ UpstreamErrorLogger.logUpstreamError(log, ex, "stream");
+
+ final String expectedTruncated = longBody.substring(0, 512)
+ + "...(truncated, total " + longBody.length() + " chars)";
+ verify(log).error("[AiProxy] {} failed, status={}, upstreamBody={}",
+ "stream", HttpStatus.INTERNAL_SERVER_ERROR, expectedTruncated, ex);
+ }
+}
diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriberTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriberTest.java
index c863ba20f48a..4827c5337240 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriberTest.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/subscriber/CommonAiProxyApiKeyDataSubscriberTest.java
@@ -19,17 +19,12 @@
import org.apache.shenyu.common.dto.ProxyApiKeyData;
import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.AiProxyApiKeyCache;
-import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.ChatClientCache;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
-import static org.mockito.Mockito.mock;
-
class CommonAiProxyApiKeyDataSubscriberTest {
- private final ChatClientCache chatClientCache = mock(ChatClientCache.class);
-
@AfterEach
void cleanup() {
AiProxyApiKeyCache.getInstance().refresh();
@@ -37,7 +32,7 @@ void cleanup() {
@Test
void testOnSubscribeCachesEnabled() {
- CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache);
+ CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber();
ProxyApiKeyData data = ProxyApiKeyData.builder()
.selectorId("test-selector")
.proxyApiKey("proxy-sub")
@@ -51,7 +46,7 @@ void testOnSubscribeCachesEnabled() {
@Test
void testOnSubscribeIgnoresDisabled() {
- CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache);
+ CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber();
ProxyApiKeyData data = ProxyApiKeyData.builder()
.selectorId("test-selector")
.proxyApiKey("proxy-disabled")
@@ -65,7 +60,7 @@ void testOnSubscribeIgnoresDisabled() {
@Test
void testUnSubscribeRemoves() {
- CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache);
+ CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber();
ProxyApiKeyData data = ProxyApiKeyData.builder()
.selectorId("test-selector")
.proxyApiKey("proxy-unsub")
@@ -81,7 +76,7 @@ void testUnSubscribeRemoves() {
@Test
void testRefreshClearsCache() {
- CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache);
+ CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber();
ProxyApiKeyData data = ProxyApiKeyData.builder()
.selectorId("test-selector")
.proxyApiKey("proxy-refresh")
@@ -98,14 +93,14 @@ void testRefreshClearsCache() {
@Test
void testOnSubscribeNullDataNoThrow() {
- CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache);
+ CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber();
subscriber.onSubscribe(null);
Assertions.assertEquals(0, AiProxyApiKeyCache.getInstance().size());
}
@Test
void testOnSubscribeNullKeyIgnored() {
- CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache);
+ CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber();
ProxyApiKeyData data = ProxyApiKeyData.builder()
.selectorId("test-selector")
.proxyApiKey(null)
@@ -119,14 +114,14 @@ void testOnSubscribeNullKeyIgnored() {
@Test
void testUnSubscribeNullDataNoThrow() {
- CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache);
+ CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber();
subscriber.unSubscribe(null);
Assertions.assertEquals(0, AiProxyApiKeyCache.getInstance().size());
}
@Test
void testUnSubscribeNullKeyNoThrow() {
- CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache);
+ CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber();
ProxyApiKeyData data = ProxyApiKeyData.builder()
.selectorId("test-selector")
.proxyApiKey(null)
@@ -140,7 +135,7 @@ void testUnSubscribeNullKeyNoThrow() {
@Test
void testDuplicateSubscribeOverrides() {
- CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache);
+ CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber();
ProxyApiKeyData v1 = ProxyApiKeyData.builder()
.selectorId("test-selector")
.proxyApiKey("dup-key")
@@ -161,48 +156,4 @@ void testDuplicateSubscribeOverrides() {
subscriber.onSubscribe(v2);
Assertions.assertEquals("real-2", AiProxyApiKeyCache.getInstance().getRealApiKey("test-selector", "dup-key"));
}
-
- @Test
- void testSubscribeDisabledDoesNotOverrideExisting() {
- CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache);
- ProxyApiKeyData enabled = ProxyApiKeyData.builder()
- .selectorId("test-selector")
- .proxyApiKey("keep-key")
- .realApiKey("real-keep")
- .enabled(Boolean.TRUE)
- .namespaceId("default")
- .build();
- subscriber.onSubscribe(enabled);
- Assertions.assertEquals("real-keep", AiProxyApiKeyCache.getInstance().getRealApiKey("test-selector", "keep-key"));
-
- ProxyApiKeyData disabled = ProxyApiKeyData.builder()
- .selectorId("test-selector")
- .proxyApiKey("keep-key")
- .realApiKey("real-new")
- .enabled(Boolean.FALSE)
- .namespaceId("default")
- .build();
- subscriber.onSubscribe(disabled);
- // disabled subscribe should not override existing cached mapping
- Assertions.assertEquals("real-keep", AiProxyApiKeyCache.getInstance().getRealApiKey("test-selector", "keep-key"));
- }
-
- @Test
- void testVeryLongProxyKey() {
- CommonAiProxyApiKeyDataSubscriber subscriber = new CommonAiProxyApiKeyDataSubscriber(chatClientCache);
- StringBuilder sb = new StringBuilder();
- for (int i = 0; i < 300; i++) {
- sb.append('A' + (i % 26));
- }
- String longKey = sb.toString();
- ProxyApiKeyData data = ProxyApiKeyData.builder()
- .selectorId("test-selector")
- .proxyApiKey(longKey)
- .realApiKey("real-long")
- .enabled(Boolean.TRUE)
- .namespaceId("default")
- .build();
- subscriber.onSubscribe(data);
- Assertions.assertEquals("real-long", AiProxyApiKeyCache.getInstance().getRealApiKey("test-selector", longKey));
- }
-}
\ No newline at end of file
+}
diff --git a/shenyu-spring-boot-starter/shenyu-spring-boot-starter-plugin/shenyu-spring-boot-starter-plugin-ai-proxy/src/main/java/org/apache/shenyu/springboot/starter/plugin/ai/proxy/AiProxyPluginConfiguration.java b/shenyu-spring-boot-starter/shenyu-spring-boot-starter-plugin/shenyu-spring-boot-starter-plugin-ai-proxy/src/main/java/org/apache/shenyu/springboot/starter/plugin/ai/proxy/AiProxyPluginConfiguration.java
index d017cfabd8fc..2bcb70aa1dae 100644
--- a/shenyu-spring-boot-starter/shenyu-spring-boot-starter-plugin/shenyu-spring-boot-starter-plugin-ai-proxy/src/main/java/org/apache/shenyu/springboot/starter/plugin/ai/proxy/AiProxyPluginConfiguration.java
+++ b/shenyu-spring-boot-starter/shenyu-spring-boot-starter-plugin/shenyu-spring-boot-starter-plugin-ai-proxy/src/main/java/org/apache/shenyu/springboot/starter/plugin/ai/proxy/AiProxyPluginConfiguration.java
@@ -22,7 +22,6 @@
import org.apache.shenyu.plugin.ai.common.spring.ai.factory.OpenAiModelFactory;
import org.apache.shenyu.plugin.ai.common.spring.ai.registry.AiModelFactoryRegistry;
import org.apache.shenyu.plugin.ai.proxy.enhanced.AiProxyPlugin;
-import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.ChatClientCache;
import org.apache.shenyu.plugin.ai.proxy.enhanced.handler.AiProxyPluginHandler;
import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyConfigService;
import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyExecutorService;
@@ -48,42 +47,30 @@ public class AiProxyPluginConfiguration {
/**
* Ai proxy plugin.
*
- * @param aiModelFactoryRegistry the aiModelFactoryRegistry
* @param aiProxyConfigService the aiProxyConfigService
* @param aiProxyExecutorService the aiProxyExecutorService
- * @param chatClientCache the chatClientCache
* @param aiProxyPluginHandler the aiProxyPluginHandler
* @return the shenyu plugin
*/
@Bean
public ShenyuPlugin aiProxyPlugin(
- final AiModelFactoryRegistry aiModelFactoryRegistry,
final AiProxyConfigService aiProxyConfigService,
final AiProxyExecutorService aiProxyExecutorService,
- final ChatClientCache chatClientCache,
final AiProxyPluginHandler aiProxyPluginHandler) {
return new AiProxyPlugin(
- aiModelFactoryRegistry,
aiProxyConfigService,
aiProxyExecutorService,
- chatClientCache,
aiProxyPluginHandler);
}
/**
* Ai proxy plugin handler.
*
- * @param chatClientCache the chatClientCache
* @return the shenyu plugin handler
*/
@Bean
- public AiProxyPluginHandler aiProxyPluginHandler(final ChatClientCache chatClientCache) {
- return new AiProxyPluginHandler(chatClientCache);
- }
-
- @Bean
- public ChatClientCache chatClientCache() {
- return new ChatClientCache();
+ public AiProxyPluginHandler aiProxyPluginHandler() {
+ return new AiProxyPluginHandler();
}
@Bean
@@ -131,11 +118,10 @@ public DeepSeekModelFactory deepSeekModelFactory() {
/**
* Ai proxy api key data subscriber.
*
- * @param chatClientCache the chatClientCache
* @return the subscriber
*/
@Bean
- public AiProxyApiKeyDataSubscriber aiProxyApiKeyDataSubscriber(final ChatClientCache chatClientCache) {
- return new CommonAiProxyApiKeyDataSubscriber(chatClientCache);
+ public AiProxyApiKeyDataSubscriber aiProxyApiKeyDataSubscriber() {
+ return new CommonAiProxyApiKeyDataSubscriber();
}
}