Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.ChatClientCustomizer;
import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationConvention;
import org.springframework.ai.chat.client.observation.ChatClientObservationContext;
import org.springframework.ai.chat.client.observation.ChatClientObservationConvention;
import org.springframework.ai.chat.client.observation.ChatClientPromptContentObservationHandler;
Expand Down Expand Up @@ -84,11 +85,12 @@ ChatClientBuilderConfigurer chatClientBuilderConfigurer(ObjectProvider<ChatClien
@ConditionalOnMissingBean
ChatClient.Builder chatClientBuilder(ChatClientBuilderConfigurer chatClientBuilderConfigurer, ChatModel chatModel,
ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<ChatClientObservationConvention> observationConvention) {

ObjectProvider<ChatClientObservationConvention> chatClientObservationConvention,
ObjectProvider<AdvisorObservationConvention> advisorObservationConvention) {
ChatClient.Builder builder = ChatClient.builder(chatModel,
observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP),
observationConvention.getIfUnique(() -> null));
chatClientObservationConvention.getIfUnique(() -> null),
advisorObservationConvention.getIfUnique(() -> null));
return chatClientBuilderConfigurer.configure(builder);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.client.advisor.api.Advisor;
import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationConvention;
import org.springframework.ai.chat.client.observation.ChatClientObservationConvention;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.model.ChatModel;
Expand Down Expand Up @@ -65,22 +66,46 @@ static ChatClient create(ChatModel chatModel, ObservationRegistry observationReg
return create(chatModel, observationRegistry, null);
}

/**
* @deprecated in favor of
* {@link #create(ChatModel, ObservationRegistry, ChatClientObservationConvention, AdvisorObservationConvention)}.
*/
@Deprecated(since = "1.1.0", forRemoval = true)
static ChatClient create(ChatModel chatModel, ObservationRegistry observationRegistry,
@Nullable ChatClientObservationConvention chatClientObservationConvention) {
return create(chatModel, observationRegistry, chatClientObservationConvention, null);
}

static ChatClient create(ChatModel chatModel, ObservationRegistry observationRegistry,
@Nullable ChatClientObservationConvention observationConvention) {
@Nullable ChatClientObservationConvention chatClientObservationConvention,
@Nullable AdvisorObservationConvention advisorObservationConvention) {
Assert.notNull(chatModel, "chatModel cannot be null");
Assert.notNull(observationRegistry, "observationRegistry cannot be null");
return builder(chatModel, observationRegistry, observationConvention).build();
return builder(chatModel, observationRegistry, chatClientObservationConvention, advisorObservationConvention)
.build();
}

static Builder builder(ChatModel chatModel) {
return builder(chatModel, ObservationRegistry.NOOP, null);
}

/**
* @deprecated in favor of
* {@link #builder(ChatModel, ObservationRegistry, ChatClientObservationConvention, AdvisorObservationConvention)}.
*/
@Deprecated(since = "1.1.0", forRemoval = true)
static Builder builder(ChatModel chatModel, ObservationRegistry observationRegistry,
@Nullable ChatClientObservationConvention chatClientObservationConvention) {
return builder(chatModel, observationRegistry, chatClientObservationConvention, null);
}

static Builder builder(ChatModel chatModel, ObservationRegistry observationRegistry,
@Nullable ChatClientObservationConvention customObservationConvention) {
@Nullable ChatClientObservationConvention chatClientObservationConvention,
@Nullable AdvisorObservationConvention advisorObservationConvention) {
Assert.notNull(chatModel, "chatModel cannot be null");
Assert.notNull(observationRegistry, "observationRegistry cannot be null");
return new DefaultChatClientBuilder(chatModel, observationRegistry, customObservationConvention);
return new DefaultChatClientBuilder(chatModel, observationRegistry, chatClientObservationConvention,
advisorObservationConvention);
}

ChatClientRequestSpec prompt();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.springframework.ai.chat.client.advisor.DefaultAroundAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.Advisor;
import org.springframework.ai.chat.client.advisor.api.BaseAdvisorChain;
import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationConvention;
import org.springframework.ai.chat.client.observation.ChatClientObservationContext;
import org.springframework.ai.chat.client.observation.ChatClientObservationConvention;
import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation;
Expand Down Expand Up @@ -513,7 +514,9 @@ private ChatClientResponse doGetObservableChatClientResponse(ChatClientRequest c
// CHECKSTYLE:OFF
var chatClientResponse = observation.observe(() -> {
// Apply the advisor chain that terminates with the ChatModelCallAdvisor.
return this.advisorChain.nextCall(chatClientRequest);
var response = this.advisorChain.nextCall(chatClientRequest);
observationContext.setChatClientResponse(response);
return response;
});
// CHECKSTYLE:ON
return chatClientResponse != null ? chatClientResponse : ChatClientResponse.builder().build();
Expand Down Expand Up @@ -608,7 +611,10 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe

private final ObservationRegistry observationRegistry;

private final ChatClientObservationConvention observationConvention;
private final ChatClientObservationConvention chatClientObservationConvention;

@Nullable
private final AdvisorObservationConvention advisorObservationConvention;

private final ChatModel chatModel;

Expand Down Expand Up @@ -649,18 +655,34 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe
DefaultChatClientRequestSpec(DefaultChatClientRequestSpec ccr) {
this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.userMetadata, ccr.systemText, ccr.systemParams,
ccr.systemMetadata, ccr.toolCallbacks, ccr.messages, ccr.toolNames, ccr.media, ccr.chatOptions,
ccr.advisors, ccr.advisorParams, ccr.observationRegistry, ccr.observationConvention,
ccr.toolContext, ccr.templateRenderer);
ccr.advisors, ccr.advisorParams, ccr.observationRegistry, ccr.chatClientObservationConvention,
ccr.toolContext, ccr.templateRenderer, ccr.advisorObservationConvention);
}

/**
* @deprecated in favor of the other constructor.
*/
@Deprecated(since = "1.1.0", forRemoval = true)
public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userText,
Map<String, Object> userParams, Map<String, Object> userMetadata, @Nullable String systemText,
Map<String, Object> systemParams, Map<String, Object> systemMetadata, List<ToolCallback> toolCallbacks,
List<Message> messages, List<String> toolNames, List<Media> media, @Nullable ChatOptions chatOptions,
List<Advisor> advisors, Map<String, Object> advisorParams, ObservationRegistry observationRegistry,
@Nullable ChatClientObservationConvention observationConvention, Map<String, Object> toolContext,
@Nullable TemplateRenderer templateRenderer) {
@Nullable ChatClientObservationConvention chatClientObservationConvention,
Map<String, Object> toolContext, @Nullable TemplateRenderer templateRenderer) {
this(chatModel, userText, userParams, userMetadata, systemText, systemParams, systemMetadata, toolCallbacks,
messages, toolNames, media, chatOptions, advisors, advisorParams, observationRegistry,
chatClientObservationConvention, toolContext, templateRenderer, null);
}

public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userText,
Map<String, Object> userParams, Map<String, Object> userMetadata, @Nullable String systemText,
Map<String, Object> systemParams, Map<String, Object> systemMetadata, List<ToolCallback> toolCallbacks,
List<Message> messages, List<String> toolNames, List<Media> media, @Nullable ChatOptions chatOptions,
List<Advisor> advisors, Map<String, Object> advisorParams, ObservationRegistry observationRegistry,
@Nullable ChatClientObservationConvention chatClientObservationConvention,
Map<String, Object> toolContext, @Nullable TemplateRenderer templateRenderer,
@Nullable AdvisorObservationConvention advisorObservationConvention) {
Assert.notNull(chatModel, "chatModel cannot be null");
Assert.notNull(userParams, "userParams cannot be null");
Assert.notNull(userMetadata, "userMetadata cannot be null");
Expand Down Expand Up @@ -694,10 +716,11 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userTe
this.advisors.addAll(advisors);
this.advisorParams.putAll(advisorParams);
this.observationRegistry = observationRegistry;
this.observationConvention = observationConvention != null ? observationConvention
: DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION;
this.chatClientObservationConvention = chatClientObservationConvention != null
? chatClientObservationConvention : DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION;
this.toolContext.putAll(toolContext);
this.templateRenderer = templateRenderer != null ? templateRenderer : DEFAULT_TEMPLATE_RENDERER;
this.advisorObservationConvention = advisorObservationConvention;
}

@Nullable
Expand Down Expand Up @@ -770,7 +793,8 @@ public TemplateRenderer getTemplateRenderer() {
@Override
public Builder mutate() {
DefaultChatClientBuilder builder = (DefaultChatClientBuilder) ChatClient
.builder(this.chatModel, this.observationRegistry, this.observationConvention)
.builder(this.chatModel, this.observationRegistry, this.chatClientObservationConvention,
this.advisorObservationConvention)
.defaultTemplateRenderer(this.templateRenderer)
.defaultToolCallbacks(this.toolCallbacks)
.defaultToolContext(this.toolContext)
Expand Down Expand Up @@ -990,14 +1014,14 @@ public ChatClientRequestSpec templateRenderer(TemplateRenderer templateRenderer)
public CallResponseSpec call() {
BaseAdvisorChain advisorChain = buildAdvisorChain();
return new DefaultCallResponseSpec(DefaultChatClientUtils.toChatClientRequest(this), advisorChain,
this.observationRegistry, this.observationConvention);
this.observationRegistry, this.chatClientObservationConvention);
}

@Override
public StreamResponseSpec stream() {
BaseAdvisorChain advisorChain = buildAdvisorChain();
return new DefaultStreamResponseSpec(DefaultChatClientUtils.toChatClientRequest(this), advisorChain,
this.observationRegistry, this.observationConvention);
this.observationRegistry, this.chatClientObservationConvention);
}

private BaseAdvisorChain buildAdvisorChain() {
Expand All @@ -1006,7 +1030,10 @@ private BaseAdvisorChain buildAdvisorChain() {
this.advisors.add(ChatModelCallAdvisor.builder().chatModel(this.chatModel).build());
this.advisors.add(ChatModelStreamAdvisor.builder().chatModel(this.chatModel).build());

return DefaultAroundAdvisorChain.builder(this.observationRegistry).pushAll(this.advisors).build();
return DefaultAroundAdvisorChain.builder(this.observationRegistry)
.observationConvention(this.advisorObservationConvention)
.pushAll(this.advisors)
.build();
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.springframework.ai.chat.client.ChatClient.PromptUserSpec;
import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec;
import org.springframework.ai.chat.client.advisor.api.Advisor;
import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationConvention;
import org.springframework.ai.chat.client.observation.ChatClientObservationConvention;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.model.ChatModel;
Expand Down Expand Up @@ -60,13 +61,24 @@ public class DefaultChatClientBuilder implements Builder {
this(chatModel, ObservationRegistry.NOOP, null);
}

/**
* @deprecated in favor of
* {@link #DefaultChatClientBuilder(ChatModel, ObservationRegistry, ChatClientObservationConvention, AdvisorObservationConvention)}.
*/
@Deprecated(since = "1.1.0", forRemoval = true)
public DefaultChatClientBuilder(ChatModel chatModel, ObservationRegistry observationRegistry,
@Nullable ChatClientObservationConvention customObservationConvention) {
@Nullable ChatClientObservationConvention chatClientObservationConvention) {
this(chatModel, observationRegistry, chatClientObservationConvention, null);
}

public DefaultChatClientBuilder(ChatModel chatModel, ObservationRegistry observationRegistry,
@Nullable ChatClientObservationConvention chatClientObservationConvention,
@Nullable AdvisorObservationConvention advisorObservationConvention) {
Assert.notNull(chatModel, "the " + ChatModel.class.getName() + " must be non-null");
Assert.notNull(observationRegistry, "the " + ObservationRegistry.class.getName() + " must be non-null");
this.defaultRequest = new DefaultChatClientRequestSpec(chatModel, null, Map.of(), Map.of(), null, Map.of(),
Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), observationRegistry,
customObservationConvention, Map.of(), null);
chatClientObservationConvention, Map.of(), null, advisorObservationConvention);
}

public ChatClient build() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,8 @@
import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationConvention;
import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationDocumentation;
import org.springframework.ai.chat.client.advisor.observation.DefaultAdvisorObservationConvention;
import org.springframework.ai.template.TemplateRenderer;
import org.springframework.ai.template.st.StTemplateRenderer;
import org.springframework.core.OrderComparator;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

Expand All @@ -55,8 +54,6 @@ public class DefaultAroundAdvisorChain implements BaseAdvisorChain {

public static final AdvisorObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultAdvisorObservationConvention();

private static final TemplateRenderer DEFAULT_TEMPLATE_RENDERER = StTemplateRenderer.builder().build();

private final List<CallAdvisor> originalCallAdvisors;

private final List<StreamAdvisor> originalStreamAdvisors;
Expand All @@ -67,8 +64,10 @@ public class DefaultAroundAdvisorChain implements BaseAdvisorChain {

private final ObservationRegistry observationRegistry;

private final AdvisorObservationConvention observationConvention;

DefaultAroundAdvisorChain(ObservationRegistry observationRegistry, Deque<CallAdvisor> callAdvisors,
Deque<StreamAdvisor> streamAdvisors) {
Deque<StreamAdvisor> streamAdvisors, @Nullable AdvisorObservationConvention observationConvention) {

Assert.notNull(observationRegistry, "the observationRegistry must be non-null");
Assert.notNull(callAdvisors, "the callAdvisors must be non-null");
Expand All @@ -79,6 +78,8 @@ public class DefaultAroundAdvisorChain implements BaseAdvisorChain {
this.streamAdvisors = streamAdvisors;
this.originalCallAdvisors = List.copyOf(callAdvisors);
this.originalStreamAdvisors = List.copyOf(streamAdvisors);
this.observationConvention = observationConvention != null ? observationConvention
: DEFAULT_OBSERVATION_CONVENTION;
}

public static Builder builder(ObservationRegistry observationRegistry) {
Expand All @@ -102,8 +103,13 @@ public ChatClientResponse nextCall(ChatClientRequest chatClientRequest) {
.build();

return AdvisorObservationDocumentation.AI_ADVISOR
.observation(null, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry)
.observe(() -> advisor.adviseCall(chatClientRequest, this));
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationRegistry)
.observe(() -> {
var chatClientResponse = advisor.adviseCall(chatClientRequest, this);
observationContext.setChatClientResponse(chatClientResponse);
return chatClientResponse;
});
}

@Override
Expand All @@ -123,7 +129,7 @@ public Flux<ChatClientResponse> nextStream(ChatClientRequest chatClientRequest)
.order(advisor.getOrder())
.build();

var observation = AdvisorObservationDocumentation.AI_ADVISOR.observation(null,
var observation = AdvisorObservationDocumentation.AI_ADVISOR.observation(this.observationConvention,
DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry);

observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
Expand Down Expand Up @@ -160,12 +166,20 @@ public static class Builder {

private final Deque<StreamAdvisor> streamAdvisors;

@Nullable
private AdvisorObservationConvention observationConvention;

public Builder(ObservationRegistry observationRegistry) {
this.observationRegistry = observationRegistry;
this.callAdvisors = new ConcurrentLinkedDeque<>();
this.streamAdvisors = new ConcurrentLinkedDeque<>();
}

public Builder observationConvention(@Nullable AdvisorObservationConvention observationConvention) {
this.observationConvention = observationConvention;
return this;
}

public Builder push(Advisor advisor) {
Assert.notNull(advisor, "the advisor must be non-null");
return this.pushAll(List.of(advisor));
Expand Down Expand Up @@ -214,7 +228,8 @@ private void reOrder() {
}

public DefaultAroundAdvisorChain build() {
return new DefaultAroundAdvisorChain(this.observationRegistry, this.callAdvisors, this.streamAdvisors);
return new DefaultAroundAdvisorChain(this.observationRegistry, this.callAdvisors, this.streamAdvisors,
this.observationConvention);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import org.springframework.ai.chat.client.ChatClientAttributes;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.Advisor;
import org.springframework.ai.observation.AiOperationMetadata;
import org.springframework.ai.observation.conventions.AiOperationType;
Expand All @@ -48,6 +49,9 @@ public class ChatClientObservationContext extends Observation.Context {

private final boolean stream;

@Nullable
private ChatClientResponse chatClientResponse;

ChatClientObservationContext(ChatClientRequest chatClientRequest, List<? extends Advisor> advisors,
boolean isStream) {
Assert.notNull(chatClientRequest, "chatClientRequest cannot be null");
Expand Down Expand Up @@ -78,6 +82,15 @@ public boolean isStream() {
return this.stream;
}

@Nullable
public ChatClientResponse getChatClientResponse() {
return this.chatClientResponse;
}

public void setChatClientResponse(@Nullable ChatClientResponse chatClientResponse) {
this.chatClientResponse = chatClientResponse;
}

@Nullable
public String getFormat() {
if (this.request.context().get(ChatClientAttributes.OUTPUT_FORMAT.getKey()) instanceof String format) {
Expand Down
Loading