diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java index fba44ffd4ce..528d84e9fb6 100644 --- a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java @@ -20,14 +20,11 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.core.scheduler.Schedulers; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.MessageType; @@ -37,15 +34,9 @@ import org.springframework.ai.chat.metadata.DefaultUsage; import org.springframework.ai.chat.metadata.EmptyUsage; import org.springframework.ai.chat.metadata.Usage; -import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.AbstractObservableChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.chat.model.MessageAggregator; -import org.springframework.ai.chat.model.StreamingChatModel; -import org.springframework.ai.chat.observation.ChatModelObservationContext; -import org.springframework.ai.chat.observation.ChatModelObservationConvention; -import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; -import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.deepseek.api.DeepSeekApi; @@ -61,8 +52,6 @@ import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate; -import org.springframework.ai.model.tool.ToolExecutionResult; -import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder; import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.support.UsageCalculator; import org.springframework.ai.tool.definition.ToolDefinition; @@ -72,17 +61,17 @@ import org.springframework.util.CollectionUtils; /** - * {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal DeepSeek} - * backed by {@link DeepSeekApi}. + * DeepSeek chat model implementation backed by {@link DeepSeekApi}. Extends + * {@link AbstractObservableChatModel} to provide observation, retry, and tool calling + * support. 
* * @author Geng Rong + * @author Fu Jian */ -public class DeepSeekChatModel implements ChatModel { +public class DeepSeekChatModel extends AbstractObservableChatModel { private static final Logger logger = LoggerFactory.getLogger(DeepSeekChatModel.class); - private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); - private static final ToolCallingManager DEFAULT_TOOL_CALLING_MANAGER = ToolCallingManager.builder().build(); /** @@ -90,37 +79,11 @@ public class DeepSeekChatModel implements ChatModel { */ private final DeepSeekChatOptions defaultOptions; - /** - * The retry template used to retry the DeepSeek API calls. - */ - public final RetryTemplate retryTemplate; - /** * Low-level access to the DeepSeek API. */ private final DeepSeekApi deepSeekApi; - /** - * Observation registry used for instrumentation. - */ - private final ObservationRegistry observationRegistry; - - /** - * The tool calling manager used to execute tools. - */ - private final ToolCallingManager toolCallingManager; - - /** - * The tool execution eligibility predicate used to determine if a tool can be - * executed. - */ - private final ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate; - - /** - * Conventions to use for generating observations. - */ - private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; - public DeepSeekChatModel(DeepSeekApi deepSeekApi, DeepSeekChatOptions defaultOptions, ToolCallingManager toolCallingManager, RetryTemplate retryTemplate, ObservationRegistry observationRegistry) { @@ -131,198 +94,97 @@ public DeepSeekChatModel(DeepSeekApi deepSeekApi, DeepSeekChatOptions defaultOpt public DeepSeekChatModel(DeepSeekApi deepSeekApi, DeepSeekChatOptions defaultOptions, ToolCallingManager toolCallingManager, RetryTemplate retryTemplate, ObservationRegistry observationRegistry, ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) { + super(observationRegistry, retryTemplate, toolCallingManager, toolExecutionEligibilityPredicate); Assert.notNull(deepSeekApi, "deepSeekApi cannot be null"); Assert.notNull(defaultOptions, "defaultOptions cannot be null"); - Assert.notNull(toolCallingManager, "toolCallingManager cannot be null"); - Assert.notNull(retryTemplate, "retryTemplate cannot be null"); - Assert.notNull(observationRegistry, "observationRegistry cannot be null"); - Assert.notNull(toolExecutionEligibilityPredicate, "toolExecutionEligibilityPredicate cannot be null"); this.deepSeekApi = deepSeekApi; this.defaultOptions = defaultOptions; - this.toolCallingManager = toolCallingManager; - this.retryTemplate = retryTemplate; - this.observationRegistry = observationRegistry; - this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate; } @Override - public ChatResponse call(Prompt prompt) { - Prompt requestPrompt = buildRequestPrompt(prompt); - return this.internalCall(requestPrompt, null); + protected String getProviderName() { + return DeepSeekConstants.PROVIDER_NAME; } - public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) { - + @Override + protected ChatResponse doCall(Prompt prompt, ChatResponse previousChatResponse) { ChatCompletionRequest request = createRequest(prompt, false); - ChatModelObservationContext observationContext = ChatModelObservationContext.builder() - .prompt(prompt) - .provider(DeepSeekConstants.PROVIDER_NAME) - .build(); - - ChatResponse response = 
ChatModelObservationDocumentation.CHAT_MODEL_OPERATION - .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, - this.observationRegistry) - .observe(() -> { - - ResponseEntity completionEntity = this.retryTemplate - .execute(ctx -> this.deepSeekApi.chatCompletionEntity(request)); + ResponseEntity completionEntity = this.deepSeekApi.chatCompletionEntity(request); - var chatCompletion = completionEntity.getBody(); + var chatCompletion = completionEntity.getBody(); - if (chatCompletion == null) { - logger.warn("No chat completion returned for prompt: {}", prompt); - return new ChatResponse(List.of()); - } + if (chatCompletion == null) { + logger.warn("No chat completion returned for prompt: {}", prompt); + return new ChatResponse(List.of()); + } - List choices = chatCompletion.choices(); - if (choices == null) { - logger.warn("No choices returned for prompt: {}", prompt); - return new ChatResponse(List.of()); - } + List choices = chatCompletion.choices(); + if (choices == null) { + logger.warn("No choices returned for prompt: {}", prompt); + return new ChatResponse(List.of()); + } - List generations = choices.stream().map(choice -> { + List generations = choices.stream().map(choice -> { // @formatter:off - Map metadata = Map.of( - "id", chatCompletion.id() != null ? chatCompletion.id() : "", - "role", choice.message().role() != null ? choice.message().role().name() : "", - "index", choice.index(), - "finishReason", choice.finishReason() != null ? choice.finishReason().name() : ""); - // @formatter:on - return buildGeneration(choice, metadata); - }).toList(); - - // Current usage - DeepSeekApi.Usage usage = completionEntity.getBody().usage(); - Usage currentChatResponseUsage = usage != null ? getDefaultUsage(usage) : new EmptyUsage(); - Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(currentChatResponseUsage, - previousChatResponse); - ChatResponse chatResponse = new ChatResponse(generations, - from(completionEntity.getBody(), accumulatedUsage)); - - observationContext.setResponse(chatResponse); - - return chatResponse; - - }); - - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { - var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); - if (toolExecutionResult.returnDirect()) { - // Return tool execution result directly to the client. - return ChatResponse.builder() - .from(response) - .generations(ToolExecutionResult.buildGenerations(toolExecutionResult)) - .build(); - } - else { - // Send the tool execution result back to the model. - return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - response); - } - } + Map metadata = Map.of( + "id", chatCompletion.id() != null ? chatCompletion.id() : "", + "role", choice.message().role() != null ? choice.message().role().name() : "", + "index", choice.index(), + "finishReason", choice.finishReason() != null ? choice.finishReason().name() : ""); + // @formatter:on + return buildGeneration(choice, metadata); + }).toList(); + + // Current usage + DeepSeekApi.Usage usage = completionEntity.getBody().usage(); + Usage currentChatResponseUsage = usage != null ? 
getDefaultUsage(usage) : new EmptyUsage(); + Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(currentChatResponseUsage, previousChatResponse); - return response; + return new ChatResponse(generations, from(completionEntity.getBody(), accumulatedUsage)); } @Override - public Flux stream(Prompt prompt) { - Prompt requestPrompt = buildRequestPrompt(prompt); - return internalStream(requestPrompt, null); - } + protected Flux doStream(Prompt prompt, ChatResponse previousChatResponse) { + ChatCompletionRequest request = createRequest(prompt, true); + + Flux completionChunks = this.deepSeekApi.chatCompletionStream(request); - public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse) { - return Flux.deferContextual(contextView -> { - ChatCompletionRequest request = createRequest(prompt, true); - - Flux completionChunks = this.deepSeekApi.chatCompletionStream(request); - - // For chunked responses, only the first chunk contains the choice role. - // The rest of the chunks with same ID share the same role. - ConcurrentHashMap roleMap = new ConcurrentHashMap<>(); - - final ChatModelObservationContext observationContext = ChatModelObservationContext.builder() - .prompt(prompt) - .provider(DeepSeekConstants.PROVIDER_NAME) - .build(); - - Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( - this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, - this.observationRegistry); - - observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start(); - - Flux chatResponse = completionChunks.map(this::chunkToChatCompletion) - .switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> { - try { - String id = chatCompletion2.id(); - - List generations = chatCompletion2.choices().stream().map(choice -> { - if (choice.message().role() != null) { - roleMap.putIfAbsent(id, choice.message().role().name()); - } - - // @formatter:off - Map metadata = Map.of( - "id", chatCompletion2.id(), - "role", roleMap.getOrDefault(id, ""), - "finishReason", choice.finishReason() != null ? choice.finishReason().name() : "" - ); - // @formatter:on - return buildGeneration(choice, metadata); - }).toList(); - DeepSeekApi.Usage usage = chatCompletion2.usage(); - Usage currentUsage = (usage != null) ? getDefaultUsage(usage) : new EmptyUsage(); - Usage cumulativeUsage = UsageCalculator.getCumulativeUsage(currentUsage, previousChatResponse); - - return new ChatResponse(generations, from(chatCompletion2, cumulativeUsage)); - } - catch (Exception e) { - logger.error("Error processing chat completion", e); - return new ChatResponse(List.of()); - } - - })); + // For chunked responses, only the first chunk contains the choice role. + // The rest of the chunks with same ID share the same role. 
+ ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>(); + + Flux<ChatResponse> chatResponse = completionChunks.map(this::chunkToChatCompletion) + .switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> { + try { + String id = chatCompletion2.id(); // @formatter:off
getDefaultUsage(usage) : new EmptyUsage(); + Usage cumulativeUsage = UsageCalculator.getCumulativeUsage(currentUsage, previousChatResponse); + + return new ChatResponse(generations, from(chatCompletion2, cumulativeUsage)); } - else { - return Flux.just(response); + catch (Exception e) { + logger.error("Error processing chat completion", e); + return new ChatResponse(List.of()); } - }) - .doOnError(observation::error) - .doFinally(s -> observation.stop()) - .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); - // @formatter:on + })); - return new MessageAggregator().aggregate(flux, observationContext::setResponse); - - }); + return chatResponse; } private Generation buildGeneration(Choice choice, Map metadata) { @@ -390,7 +252,8 @@ private DefaultUsage getDefaultUsage(DeepSeekApi.Usage usage) { return new DefaultUsage(usage.promptTokens(), usage.completionTokens(), usage.totalTokens(), usage); } - Prompt buildRequestPrompt(Prompt prompt) { + @Override + protected Prompt buildRequestPrompt(Prompt prompt) { DeepSeekChatOptions runtimeOptions = null; if (prompt.getOptions() != null) { if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) { @@ -495,18 +358,6 @@ private List getFunctionTools(List too }).toList(); } - private ChatOptions buildRequestOptions(DeepSeekApi.ChatCompletionRequest request) { - return ChatOptions.builder() - .model(request.model()) - .frequencyPenalty(request.frequencyPenalty()) - .maxTokens(request.maxTokens()) - .presencePenalty(request.presencePenalty()) - .stopSequences(request.stop()) - .temperature(request.temperature()) - .topP(request.topP()) - .build(); - } - @Override public ChatOptions getDefaultOptions() { return DeepSeekChatOptions.fromOptions(this.defaultOptions); @@ -517,15 +368,6 @@ public String toString() { return "DeepSeekChatModel [defaultOptions=" + this.defaultOptions + "]"; } - /** - * Use the provided convention for reporting observation data - * @param observationConvention The provided convention - */ - public void setObservationConvention(ChatModelObservationConvention observationConvention) { - Assert.notNull(observationConvention, "observationConvention cannot be null"); - this.observationConvention = observationConvention; - } - public static Builder builder() { return new Builder(); } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index 246b7893c4a..8ee4e956d2b 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -24,14 +24,11 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; -import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.core.scheduler.Schedulers; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.MessageType; @@ -43,15 +40,9 @@ import org.springframework.ai.chat.metadata.EmptyUsage; import org.springframework.ai.chat.metadata.RateLimit; import org.springframework.ai.chat.metadata.Usage; -import 
org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.AbstractObservableChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.chat.model.MessageAggregator; -import org.springframework.ai.chat.model.StreamingChatModel; -import org.springframework.ai.chat.observation.ChatModelObservationContext; -import org.springframework.ai.chat.observation.ChatModelObservationConvention; -import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; -import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.content.Media; @@ -60,8 +51,6 @@ import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate; -import org.springframework.ai.model.tool.ToolExecutionResult; -import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion.Choice; @@ -88,8 +77,9 @@ import org.springframework.util.StringUtils; /** - * {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal OpenAI} - * backed by {@link OpenAiApi}. + * OpenAI chat model implementation backed by {@link OpenAiApi}. Extends + * {@link AbstractObservableChatModel} to provide observation, retry, and tool calling + * support. * * @author Mark Pollack * @author Christian Tzolov @@ -106,16 +96,14 @@ * @author Alexandros Pappas * @author Soby Chacko * @author Jonghoon Park - * @see ChatModel - * @see StreamingChatModel + * @author Fu Jian + * @see AbstractObservableChatModel * @see OpenAiApi */ -public class OpenAiChatModel implements ChatModel { +public class OpenAiChatModel extends AbstractObservableChatModel { private static final Logger logger = LoggerFactory.getLogger(OpenAiChatModel.class); - private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); - private static final ToolCallingManager DEFAULT_TOOL_CALLING_MANAGER = ToolCallingManager.builder().build(); /** @@ -123,34 +111,11 @@ public class OpenAiChatModel implements ChatModel { */ private final OpenAiChatOptions defaultOptions; - /** - * The retry template used to retry the OpenAI API calls. - */ - private final RetryTemplate retryTemplate; - /** * Low-level access to the OpenAI API. */ private final OpenAiApi openAiApi; - /** - * Observation registry used for instrumentation. - */ - private final ObservationRegistry observationRegistry; - - private final ToolCallingManager toolCallingManager; - - /** - * The tool execution eligibility predicate used to determine if a tool can be - * executed. - */ - private final ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate; - - /** - * Conventions to use for generating observations. 
- */ - private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; - public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions defaultOptions, ToolCallingManager toolCallingManager, RetryTemplate retryTemplate, ObservationRegistry observationRegistry) { this(openAiApi, defaultOptions, toolCallingManager, retryTemplate, observationRegistry, @@ -160,246 +125,143 @@ public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions defaultOptions, To public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions defaultOptions, ToolCallingManager toolCallingManager, RetryTemplate retryTemplate, ObservationRegistry observationRegistry, ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) { + super(observationRegistry, retryTemplate, toolCallingManager, toolExecutionEligibilityPredicate); Assert.notNull(openAiApi, "openAiApi cannot be null"); Assert.notNull(defaultOptions, "defaultOptions cannot be null"); - Assert.notNull(toolCallingManager, "toolCallingManager cannot be null"); - Assert.notNull(retryTemplate, "retryTemplate cannot be null"); - Assert.notNull(observationRegistry, "observationRegistry cannot be null"); - Assert.notNull(toolExecutionEligibilityPredicate, "toolExecutionEligibilityPredicate cannot be null"); this.openAiApi = openAiApi; this.defaultOptions = defaultOptions; - this.toolCallingManager = toolCallingManager; - this.retryTemplate = retryTemplate; - this.observationRegistry = observationRegistry; - this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate; } @Override - public ChatResponse call(Prompt prompt) { - // Before moving any further, build the final request Prompt, - // merging runtime and default options. - Prompt requestPrompt = buildRequestPrompt(prompt); - return this.internalCall(requestPrompt, null); + protected String getProviderName() { + return OpenAiApiConstants.PROVIDER_NAME; } - public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) { - + @Override + protected ChatResponse doCall(Prompt prompt, ChatResponse previousChatResponse) { ChatCompletionRequest request = createRequest(prompt, false); - ChatModelObservationContext observationContext = ChatModelObservationContext.builder() - .prompt(prompt) - .provider(OpenAiApiConstants.PROVIDER_NAME) - .build(); - - ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION - .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, - this.observationRegistry) - .observe(() -> { - - ResponseEntity completionEntity = this.retryTemplate - .execute(ctx -> this.openAiApi.chatCompletionEntity(request, getAdditionalHttpHeaders(prompt))); + ResponseEntity completionEntity = this.openAiApi.chatCompletionEntity(request, + getAdditionalHttpHeaders(prompt)); - var chatCompletion = completionEntity.getBody(); + var chatCompletion = completionEntity.getBody(); - if (chatCompletion == null) { - logger.warn("No chat completion returned for prompt: {}", prompt); - return new ChatResponse(List.of()); - } + if (chatCompletion == null) { + logger.warn("No chat completion returned for prompt: {}", prompt); + return new ChatResponse(List.of()); + } - List choices = chatCompletion.choices(); - if (choices == null) { - logger.warn("No choices returned for prompt: {}", prompt); - return new ChatResponse(List.of()); - } + List choices = chatCompletion.choices(); + if (choices == null) { + logger.warn("No choices returned for prompt: {}", prompt); + return new 
ChatResponse(List.of()); + } - // @formatter:off - List generations = choices.stream().map(choice -> { - Map metadata = Map.of( - "id", chatCompletion.id() != null ? chatCompletion.id() : "", - "role", choice.message().role() != null ? choice.message().role().name() : "", - "index", choice.index() != null ? choice.index() : 0, - "finishReason", getFinishReasonJson(choice.finishReason()), - "refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "", - "annotations", choice.message().annotations() != null ? choice.message().annotations() : List.of(Map.of())); - return buildGeneration(choice, metadata, request); - }).toList(); - // @formatter:on - - RateLimit rateLimit = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity); - - // Current usage - OpenAiApi.Usage usage = chatCompletion.usage(); - Usage currentChatResponseUsage = usage != null ? getDefaultUsage(usage) : new EmptyUsage(); - Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(currentChatResponseUsage, - previousChatResponse); - ChatResponse chatResponse = new ChatResponse(generations, - from(chatCompletion, rateLimit, accumulatedUsage)); - - observationContext.setResponse(chatResponse); - - return chatResponse; + // @formatter:off + List generations = choices.stream().map(choice -> { + Map metadata = Map.of( + "id", chatCompletion.id() != null ? chatCompletion.id() : "", + "role", choice.message().role() != null ? choice.message().role().name() : "", + "index", choice.index() != null ? choice.index() : 0, + "finishReason", getFinishReasonJson(choice.finishReason()), + "refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "", + "annotations", choice.message().annotations() != null ? choice.message().annotations() : List.of(Map.of())); + return buildGeneration(choice, metadata, request); + }).toList(); + // @formatter:on - }); + RateLimit rateLimit = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity); - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { - var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); - if (toolExecutionResult.returnDirect()) { - // Return tool execution result directly to the client. - return ChatResponse.builder() - .from(response) - .generations(ToolExecutionResult.buildGenerations(toolExecutionResult)) - .build(); - } - else { - // Send the tool execution result back to the model. - return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - response); - } - } + // Current usage + OpenAiApi.Usage usage = chatCompletion.usage(); + Usage currentChatResponseUsage = usage != null ? getDefaultUsage(usage) : new EmptyUsage(); + Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(currentChatResponseUsage, previousChatResponse); - return response; + return new ChatResponse(generations, from(chatCompletion, rateLimit, accumulatedUsage)); } @Override - public Flux stream(Prompt prompt) { - // Before moving any further, build the final request Prompt, - // merging runtime and default options. 
- Prompt requestPrompt = buildRequestPrompt(prompt); - return internalStream(requestPrompt, null); - } - - public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse) { - return Flux.deferContextual(contextView -> { - ChatCompletionRequest request = createRequest(prompt, true); - - if (request.outputModalities() != null - && request.outputModalities().contains(OpenAiApi.OutputModality.AUDIO)) { - logger.warn("Audio output is not supported for streaming requests. Removing audio output."); - throw new IllegalArgumentException("Audio output is not supported for streaming requests."); - } - - if (request.audioParameters() != null) { - logger.warn("Audio parameters are not supported for streaming requests. Removing audio parameters."); - throw new IllegalArgumentException("Audio parameters are not supported for streaming requests."); - } - - Flux completionChunks = this.openAiApi.chatCompletionStream(request, - getAdditionalHttpHeaders(prompt)); + protected Flux doStream(Prompt prompt, ChatResponse previousChatResponse) { + ChatCompletionRequest request = createRequest(prompt, true); - // For chunked responses, only the first chunk contains the choice role. - // The rest of the chunks with same ID share the same role. - ConcurrentHashMap roleMap = new ConcurrentHashMap<>(); + if (request.outputModalities() != null && request.outputModalities().contains(OpenAiApi.OutputModality.AUDIO)) { + logger.warn("Audio output is not supported for streaming requests. Removing audio output."); + throw new IllegalArgumentException("Audio output is not supported for streaming requests."); + } - final ChatModelObservationContext observationContext = ChatModelObservationContext.builder() - .prompt(prompt) - .provider(OpenAiApiConstants.PROVIDER_NAME) - .build(); + if (request.audioParameters() != null) { + logger.warn("Audio parameters are not supported for streaming requests. Removing audio parameters."); + throw new IllegalArgumentException("Audio parameters are not supported for streaming requests."); + } - Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( - this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, - this.observationRegistry); + Flux completionChunks = this.openAiApi.chatCompletionStream(request, + getAdditionalHttpHeaders(prompt)); - observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start(); + // For chunked responses, only the first chunk contains the choice role. + // The rest of the chunks with same ID share the same role. + ConcurrentHashMap roleMap = new ConcurrentHashMap<>(); - // Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse - // the function call handling logic. - Flux chatResponse = completionChunks.map(this::chunkToChatCompletion) - .switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> { - try { - // If an id is not provided, set to "NO_ID" (for compatible APIs). - String id = chatCompletion2.id() == null ? "NO_ID" : chatCompletion2.id(); + // Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse + // the function call handling logic. + Flux chatResponse = completionChunks.map(this::chunkToChatCompletion) + .switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> { + try { + // If an id is not provided, set to "NO_ID" (for compatible APIs). + String id = chatCompletion2.id() == null ? 
"NO_ID" : chatCompletion2.id(); - List generations = chatCompletion2.choices().stream().map(choice -> { // @formatter:off - if (choice.message().role() != null) { - roleMap.putIfAbsent(id, choice.message().role().name()); - } - Map metadata = Map.of( - "id", id, - "role", roleMap.getOrDefault(id, ""), - "index", choice.index() != null ? choice.index() : 0, - "finishReason", getFinishReasonJson(choice.finishReason()), - "refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "", - "annotations", choice.message().annotations() != null ? choice.message().annotations() : List.of()); - return buildGeneration(choice, metadata, request); - }).toList(); - // @formatter:on - OpenAiApi.Usage usage = chatCompletion2.usage(); - Usage currentChatResponseUsage = usage != null ? getDefaultUsage(usage) : new EmptyUsage(); - Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(currentChatResponseUsage, - previousChatResponse); - return new ChatResponse(generations, from(chatCompletion2, null, accumulatedUsage)); - } - catch (Exception e) { - logger.error("Error processing chat completion", e); - return new ChatResponse(List.of()); - } - // When in stream mode and enabled to include the usage, the OpenAI - // Chat completion response would have the usage set only in its - // final response. Hence, the following overlapping buffer is - // created to store both the current and the subsequent response - // to accumulate the usage from the subsequent response. - })) - .buffer(2, 1) - .map(bufferList -> { - ChatResponse firstResponse = bufferList.get(0); - if (request.streamOptions() != null && request.streamOptions().includeUsage()) { - if (bufferList.size() == 2) { - ChatResponse secondResponse = bufferList.get(1); - if (secondResponse != null && secondResponse.getMetadata() != null) { - // This is the usage from the final Chat response for a - // given Chat request. - Usage usage = secondResponse.getMetadata().getUsage(); - if (!UsageCalculator.isEmpty(usage)) { - // Store the usage from the final response to the - // penultimate response for accumulation. - return new ChatResponse(firstResponse.getResults(), - from(firstResponse.getMetadata(), usage)); - } + List generations = chatCompletion2.choices().stream().map(choice -> { // @formatter:off + if (choice.message().role() != null) { + roleMap.putIfAbsent(id, choice.message().role().name()); + } + Map metadata = Map.of( + "id", id, + "role", roleMap.getOrDefault(id, ""), + "index", choice.index() != null ? choice.index() : 0, + "finishReason", getFinishReasonJson(choice.finishReason()), + "refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "", + "annotations", choice.message().annotations() != null ? choice.message().annotations() : List.of()); + return buildGeneration(choice, metadata, request); + }).toList(); + // @formatter:on + OpenAiApi.Usage usage = chatCompletion2.usage(); + Usage currentChatResponseUsage = usage != null ? getDefaultUsage(usage) : new EmptyUsage(); + Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(currentChatResponseUsage, + previousChatResponse); + return new ChatResponse(generations, from(chatCompletion2, null, accumulatedUsage)); + } + catch (Exception e) { + logger.error("Error processing chat completion", e); + return new ChatResponse(List.of()); + } + // When in stream mode and enabled to include the usage, the OpenAI + // Chat completion response would have the usage set only in its + // final response. 
Hence, the following overlapping buffer is + // created to store both the current and the subsequent response + // to accumulate the usage from the subsequent response. + })) + .buffer(2, 1) + .map(bufferList -> { + ChatResponse firstResponse = bufferList.get(0); + if (request.streamOptions() != null && request.streamOptions().includeUsage()) { + if (bufferList.size() == 2) { + ChatResponse secondResponse = bufferList.get(1); + if (secondResponse != null && secondResponse.getMetadata() != null) { + // This is the usage from the final Chat response for a + // given Chat request. + Usage usage = secondResponse.getMetadata().getUsage(); + if (!UsageCalculator.isEmpty(usage)) { + // Store the usage from the final response to the + // penultimate response for accumulation. + return new ChatResponse(firstResponse.getResults(), + from(firstResponse.getMetadata(), usage)); } } } - return firstResponse; - }); - - // @formatter:off - Flux flux = chatResponse.flatMap(response -> { - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { - // FIXME: bounded elastic needs to be used since tool calling - // is currently only synchronous - return Flux.deferContextual(ctx -> { - ToolExecutionResult toolExecutionResult; - try { - ToolCallReactiveContextHolder.setContext(ctx); - toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); - } - finally { - ToolCallReactiveContextHolder.clearContext(); - } - if (toolExecutionResult.returnDirect()) { - // Return tool execution result directly to the client. - return Flux.just(ChatResponse.builder().from(response) - .generations(ToolExecutionResult.buildGenerations(toolExecutionResult)) - .build()); - } - else { - // Send the tool execution result back to the model. 
- return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - response); - } - }).subscribeOn(Schedulers.boundedElastic()); - } - else { - return Flux.just(response); } - }) - .doOnError(observation::error) - .doFinally(s -> observation.stop()) - .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); - // @formatter:on - - return new MessageAggregator().aggregate(flux, observationContext::setResponse); + return firstResponse; + }); - }); + return chatResponse; } private MultiValueMap getAdditionalHttpHeaders(Prompt prompt) { @@ -511,7 +373,8 @@ private DefaultUsage getDefaultUsage(OpenAiApi.Usage usage) { return new DefaultUsage(usage.promptTokens(), usage.completionTokens(), usage.totalTokens(), usage); } - Prompt buildRequestPrompt(Prompt prompt) { + @Override + protected Prompt buildRequestPrompt(Prompt prompt) { // Process runtime options OpenAiChatOptions runtimeOptions = null; if (prompt.getOptions() != null) { @@ -707,15 +570,6 @@ public String toString() { return "OpenAiChatModel [defaultOptions=" + this.defaultOptions + "]"; } - /** - * Use the provided convention for reporting observation data - * @param observationConvention The provided convention - */ - public void setObservationConvention(ChatModelObservationConvention observationConvention) { - Assert.notNull(observationConvention, "observationConvention cannot be null"); - this.observationConvention = observationConvention; - } - public static Builder builder() { return new Builder(); } diff --git a/spring-ai-model/pom.xml b/spring-ai-model/pom.xml index 70874f2d865..34964a5de54 100644 --- a/spring-ai-model/pom.xml +++ b/spring-ai-model/pom.xml @@ -69,12 +69,17 @@ spring-messaging - - io.projectreactor - reactor-core - + + io.projectreactor + reactor-core + + + + org.springframework.retry + spring-retry + - + org.antlr antlr4-runtime diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/model/AbstractObservableChatModel.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/model/AbstractObservableChatModel.java new file mode 100644 index 00000000000..1a15012fa4a --- /dev/null +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/model/AbstractObservableChatModel.java @@ -0,0 +1,272 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.chat.model; + +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; +import reactor.core.publisher.Flux; +import reactor.core.scheduler.Schedulers; + +import org.springframework.ai.chat.observation.ChatModelObservationContext; +import org.springframework.ai.chat.observation.ChatModelObservationConvention; +import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; +import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate; +import org.springframework.ai.model.tool.ToolExecutionResult; +import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.Assert; + +/** + * Abstract base class for ChatModel implementations that provides common functionality + * for observation, retry, and tool calling orchestration. + * + *
+ * <p>
+ * Subclasses need to:
+ * <ul>
+ * <li>Initialize the protected fields in their constructors</li>
+ * <li>Implement {@link #getProviderName()} to return the provider identifier</li>
+ * <li>Implement {@link #doCall(Prompt, ChatResponse)} for the actual API interaction</li>
+ * <li>Implement {@link #doStream(Prompt, ChatResponse)} for streaming API interaction</li>
+ * <li>Implement {@link #buildRequestPrompt(Prompt)} to build the final request prompt by merging options</li>
+ * </ul>
+ * + * @author Fu Jian + * @since 1.1.0 + */ +public abstract class AbstractObservableChatModel implements ChatModel, StreamingChatModel { + + private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); + + protected ObservationRegistry observationRegistry; + + protected RetryTemplate retryTemplate; + + protected ToolCallingManager toolCallingManager; + + protected ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate; + + protected ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; + + /** + * Constructor for AbstractObservableChatModel. + * @param observationRegistry the observation registry + * @param retryTemplate the retry template + * @param toolCallingManager the tool calling manager + * @param toolExecutionEligibilityPredicate the tool execution eligibility predicate + */ + protected AbstractObservableChatModel(ObservationRegistry observationRegistry, RetryTemplate retryTemplate, + ToolCallingManager toolCallingManager, + ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) { + Assert.notNull(observationRegistry, "observationRegistry cannot be null"); + Assert.notNull(retryTemplate, "retryTemplate cannot be null"); + Assert.notNull(toolCallingManager, "toolCallingManager cannot be null"); + Assert.notNull(toolExecutionEligibilityPredicate, "toolExecutionEligibilityPredicate cannot be null"); + + this.observationRegistry = observationRegistry; + this.retryTemplate = retryTemplate; + this.toolCallingManager = toolCallingManager; + this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate; + } + + /** + * Sets the observation convention for this chat model. + * @param observationConvention the observation convention to use + */ + public void setObservationConvention(ChatModelObservationConvention observationConvention) { + this.observationConvention = observationConvention; + } + + @Override + public final ChatResponse call(Prompt prompt) { + Prompt requestPrompt = buildRequestPrompt(prompt); + return call(requestPrompt, null); + } + + /** + * Internal call method that handles observation, retry, and tool calling + * orchestration. This method can be called recursively for tool execution. + * @param prompt the prompt to process + * @param previousChatResponse the previous chat response + * @return the final chat response + */ + protected ChatResponse call(Prompt prompt, ChatResponse previousChatResponse) { + ChatModelObservationContext observationContext = createObservationContext(prompt); + + ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION + .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationRegistry) + .observe(() -> { + ChatResponse chatResponse = this.retryTemplate.execute(ctx -> doCall(prompt, previousChatResponse)); + if (observationContext != null) { + observationContext.setResponse(chatResponse); + } + return chatResponse; + }); + + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response) + && shouldExecuteTools(prompt, response)) { + var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); + if (toolExecutionResult.returnDirect()) { + // Return tool execution result directly to the client. 
+ return ChatResponse.builder() + .from(response) + .generations(ToolExecutionResult.buildGenerations(toolExecutionResult)) + .build(); + } + else { + // Send the tool execution result back to the model. + return call(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), response); + } + } + + return response; + } + + /** + * Creates the observation context for the given prompt. Subclasses can override this + * to customize the observation context. + * @param prompt the prompt + * @return the observation context + */ + protected ChatModelObservationContext createObservationContext(Prompt prompt) { + return ChatModelObservationContext.builder().prompt(prompt).provider(getProviderName()).build(); + } + + /** + * Returns the provider name for observation context. + * @return the provider name (e.g., "openai", "anthropic") + */ + protected abstract String getProviderName(); + + /** + * Performs the actual chat completion API call. Subclasses should implement their + * provider-specific API interaction here, without worrying about observation, retry, + * or tool calling orchestration. + * @param prompt the prompt to send to the model + * @param previousChatResponse the previous chat response tracking, or null if this is + * the first call + * @return the chat response from the model + */ + protected abstract ChatResponse doCall(Prompt prompt, ChatResponse previousChatResponse); + + /** + * Additional condition check for tool execution. Subclasses can override this to add + * provider-specific conditions (e.g., checking finish reasons). + * @param prompt the prompt + * @param response the response + * @return true if tools should be executed + */ + protected boolean shouldExecuteTools(Prompt prompt, ChatResponse response) { + return true; + } + + @Override + public Flux stream(Prompt prompt) { + // Before moving any further, build the final request Prompt, + // merging runtime and default options. + Prompt requestPrompt = buildRequestPrompt(prompt); + return stream(requestPrompt, null); + } + + /** + * Builds the final request prompt by merging runtime options with default options. + * Subclasses should implement this method to handle provider-specific option merging + * and validation logic. + * @param prompt the original prompt with runtime options + * @return the final prompt with merged options ready for API call + */ + protected abstract Prompt buildRequestPrompt(Prompt prompt); + + /** + * Internal stream method that handles observation and tool calling orchestration for + * streaming responses. This method can be called recursively for tool execution. 
+ * @param prompt the prompt to process + * @param previousChatResponse the previous chat response + * @return a flux of chat responses + */ + protected Flux stream(Prompt prompt, ChatResponse previousChatResponse) { + return Flux.deferContextual(contextView -> { + ChatModelObservationContext observationContext = createObservationContext(prompt); + + Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( + this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationRegistry); + + observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start(); + + Flux chatResponseFlux = this.retryTemplate + .execute(ctx -> doStream(prompt, previousChatResponse)); + + Flux flux = chatResponseFlux.flatMap(response -> { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response) + && shouldExecuteTools(prompt, response)) { + // FIXME: bounded elastic needs to be used since tool calling + // is currently only synchronous + return Flux.deferContextual(ctx -> { + ToolExecutionResult toolExecutionResult; + try { + ToolCallReactiveContextHolder.setContext(ctx); + toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); + } + finally { + ToolCallReactiveContextHolder.clearContext(); + } + if (toolExecutionResult.returnDirect()) { + // Return tool execution result directly to the client. + return Flux.just(ChatResponse.builder() + .from(response) + .generations(ToolExecutionResult.buildGenerations(toolExecutionResult)) + .build()); + } + else { + // Send the tool execution result back to the model. + return stream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), + response); + } + }).subscribeOn(Schedulers.boundedElastic()); + } + else { + return Flux.just(response); + } + }) + .doOnError(observation::error) + .doFinally(s -> observation.stop()) + .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); + + return new MessageAggregator().aggregate(flux, observationContext::setResponse); + }); + } + + /** + * Performs the actual streaming chat completion API call. Subclasses should implement + * their provider-specific streaming API interaction here, without worrying about + * observation or tool calling orchestration. Note that retry is handled outside the + * stream for streaming calls. + * @param prompt the prompt to send to the model + * @param previousChatResponse the previous chat response tracking, or null if this is + * the first call + * @return a flux of chat responses from the model + */ + protected abstract Flux doStream(Prompt prompt, ChatResponse previousChatResponse); + +}
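To make the new template-method contract concrete, here is a minimal sketch of a provider subclass built on `AbstractObservableChatModel`. The `EchoChatModel` and its nested `EchoApi` client are hypothetical, invented purely for illustration; a real subclass (like `OpenAiChatModel` above) would merge runtime and default options in `buildRequestPrompt` and map provider payloads in `doCall`/`doStream`.

```java
import java.util.List;

import io.micrometer.observation.ObservationRegistry;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.model.AbstractObservableChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
import org.springframework.retry.support.RetryTemplate;

public class EchoChatModel extends AbstractObservableChatModel {

	/** Hypothetical low-level client, standing in for DeepSeekApi/OpenAiApi. */
	public interface EchoApi {

		String complete(String text);

		Flux<String> streamComplete(String text);

	}

	private final EchoApi echoApi;

	private final ChatOptions defaultOptions;

	public EchoChatModel(EchoApi echoApi, ChatOptions defaultOptions, ToolCallingManager toolCallingManager,
			RetryTemplate retryTemplate, ObservationRegistry observationRegistry,
			ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) {
		// The base class null-checks and stores the shared collaborators.
		super(observationRegistry, retryTemplate, toolCallingManager, toolExecutionEligibilityPredicate);
		this.echoApi = echoApi;
		this.defaultOptions = defaultOptions;
	}

	@Override
	protected String getProviderName() {
		// Feeds the ChatModelObservationContext created by the base class.
		return "echo";
	}

	@Override
	protected Prompt buildRequestPrompt(Prompt prompt) {
		// Real implementations merge runtime options into the defaults here;
		// this sketch simply passes the prompt through.
		return prompt;
	}

	@Override
	protected ChatResponse doCall(Prompt prompt, ChatResponse previousChatResponse) {
		// Plain API interaction only: observation, retry, and tool-call
		// orchestration are handled by AbstractObservableChatModel.call(..).
		String reply = this.echoApi.complete(prompt.getContents());
		return new ChatResponse(List.of(new Generation(new AssistantMessage(reply))));
	}

	@Override
	protected Flux<ChatResponse> doStream(Prompt prompt, ChatResponse previousChatResponse) {
		// One ChatResponse per chunk; the base class wires observation, tool
		// execution, and message aggregation around this flux.
		return this.echoApi.streamComplete(prompt.getContents())
			.map(chunk -> new ChatResponse(List.of(new Generation(new AssistantMessage(chunk)))));
	}

	@Override
	public ChatOptions getDefaultOptions() {
		return this.defaultOptions;
	}

}
```

Subclasses that need provider-specific gating of tool execution can additionally override `shouldExecuteTools(Prompt, ChatResponse)`, which defaults to `true` in the base class.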
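The `buffer(2, 1)` overlap retained in `OpenAiChatModel.doStream` is easy to misread, so here is a standalone sketch of the operator's behavior (illustrative only, not part of this change). Each emitted window pairs an element with its successor, which is what lets the usage-only final chunk be folded back into the penultimate response.

```java
import reactor.core.publisher.Flux;

public class OverlappingBufferDemo {

	public static void main(String[] args) {
		// buffer(maxSize = 2, skip = 1) emits overlapping windows, so every
		// element can be inspected together with the element that follows it.
		Flux.just("chunk-1", "chunk-2", "usage-only-final")
			.buffer(2, 1)
			.map(window -> window.size() == 2
					? window.get(0) + " <- can merge usage from: " + window.get(1)
					: window.get(0) + " (last window, no successor)")
			.subscribe(System.out::println);
		// Output:
		// chunk-1 <- can merge usage from: chunk-2
		// chunk-2 <- can merge usage from: usage-only-final
		// usage-only-final (last window, no successor)
	}

}
```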