Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions spring-ai-client-chat/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@
<version>${project.version}</version>
</dependency>

<dependency>
<groupId>io.modelcontextprotocol.sdk</groupId>
<artifactId>mcp-json-jackson2</artifactId>
<version>${mcp.sdk.version}</version>
</dependency>

<dependency>
<groupId>com.fasterxml.jackson.module</groupId>
<artifactId>jackson-module-jsonSchema</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,20 @@

package org.springframework.ai.chat.client.advisor;

import java.util.List;
import java.util.function.Predicate;

import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.CallAdvisor;
import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
* Utilities to work with advisors.
*
* @author Christian Tzolov
*/
public final class AdvisorUtils {

Expand All @@ -46,4 +52,31 @@ public static Predicate<ChatClientResponse> onFinishReason() {
};
}

/**
* Creates a new CallAdvisorChain that contains only the advisors positioned after the
* specified advisor in the original chain.
* <p>
* The advisor is located with {@code List.indexOf} on the list returned by
* {@code getCallAdvisors()}, so the first matching entry is used. Every advisor from
* the following position up to the end of that list is pushed into a new
* DefaultAroundAdvisorChain builder that is created with the observation registry of
* the original chain. When the specified advisor is the last entry, the remaining
* sub-list is empty.
* @param callAdvisorChain the original CallAdvisorChain; must not be null
* @param after the CallAdvisor after which to copy the chain; must not be null
* @return a new CallAdvisorChain containing all advisors after the specified advisor
* @throws IllegalArgumentException if the specified advisor is not part of the chain
*/
public static CallAdvisorChain copyChainAfterAdvisor(CallAdvisorChain callAdvisorChain, CallAdvisor after) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be a method on CallAdvisorChain instead? I feel like it would be more readable

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you have a point

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will refactor it


Assert.notNull(callAdvisorChain, "callAdvisorChain must not be null");
Assert.notNull(after, "The after call advisor must not be null");

List<CallAdvisor> callAdvisors = callAdvisorChain.getCallAdvisors();
int afterAdvisorIndex = callAdvisors.indexOf(after);

if (afterAdvisorIndex < 0) {
throw new IllegalArgumentException("The specified advisor is not part of the chain: " + after.getName());
}

var remainingCallAdvisors = callAdvisors.subList(afterAdvisorIndex + 1, callAdvisors.size());

return DefaultAroundAdvisorChain.builder(callAdvisorChain.getObservationRegistry())
.pushAll(remainingCallAdvisors)
.build();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,11 @@ public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest
return new ChatClientMessageAggregator().aggregateChatClientResponse(chatClientResponses, this::logResponse);
}

private void logRequest(ChatClientRequest request) {
protected void logRequest(ChatClientRequest request) {
logger.debug("request: {}", this.requestToString.apply(request));
}

private void logResponse(ChatClientResponse chatClientResponse) {
protected void logResponse(ChatClientResponse chatClientResponse) {
logger.debug("response: {}", this.responseToString.apply(chatClientResponse.chatResponse()));
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,317 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.chat.client.advisor;

import java.lang.reflect.Type;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;

import com.fasterxml.jackson.core.type.TypeReference;
import io.modelcontextprotocol.json.TypeRef;
import io.modelcontextprotocol.json.jackson.JacksonMcpJsonMapper;
import io.modelcontextprotocol.json.schema.JsonSchemaValidator.ValidationResponse;
import io.modelcontextprotocol.json.schema.jackson.DefaultJsonSchemaValidator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
import org.springframework.ai.chat.client.advisor.api.CallAdvisor;
import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.StreamAdvisor;
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.util.json.JsonParser;
import org.springframework.ai.util.json.schema.JsonSchemaGenerator;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.util.Assert;

/**
* Recursive Advisor that validates the structured JSON output of a chat client entity
* response against a generated JSON schema for the expected output type.
* <p>
* If the validation fails, the advisor will repeat the call up to a specified number of
* attempts.
* <p>
* Note: This advisor does not support streaming responses; when used in a streaming
* context it returns a Flux that signals an UnsupportedOperationException.
*
* @author Christian Tzolov
*/
public final class StructuredOutputValidationAdvisor implements CallAdvisor, StreamAdvisor {

private static final Logger logger = LoggerFactory.getLogger(StructuredOutputValidationAdvisor.class);

private static final TypeRef<HashMap<String, Object>> MAP_TYPE_REF = new TypeRef<>() {
};

/**
* Set the order close to Ordered.LOWEST_PRECEDENCE to ensure an advisor is executed
* toward the last (but before the model call) in the chain (last for request
* processing, first for response processing).
*
* https://docs.spring.io/spring-ai/reference/api/advisors.html#_advisor_order
*/
private final int advisorOrder;

/**
* The JSON schema used for validation.
*/
private final Map<String, Object> jsonSchema;

/**
* The JSON schema validator.
*/
private final DefaultJsonSchemaValidator jsonvalidator;

private final int maxRepeatAttempts;

/**
* Creates the advisor, validating the arguments and preparing the JSON schema used
* for validation.
* <p>
* The schema text is generated from the output type with
* JsonSchemaGenerator.generateForType and then parsed into a map with a
* JacksonMcpJsonMapper that wraps JsonParser.getObjectMapper().
* @param advisorOrder the advisor order; must be strictly between
* BaseAdvisor.HIGHEST_PRECEDENCE and BaseAdvisor.LOWEST_PRECEDENCE
* @param outputType the expected output type the schema is generated for; must not
* be null
* @param repeatAttempts the maximum number of repeat attempts; must be greater than
* or equal to 0
* @throws IllegalArgumentException if the generated schema text cannot be parsed
*/
private StructuredOutputValidationAdvisor(int advisorOrder, Type outputType, int repeatAttempts) {
// NOTE(review): advisorOrder is declared as a primitive int, so the boxed value
// passed to this check is never null and the check cannot fail.
Assert.notNull(advisorOrder, "advisorOrder must not be null");
Assert.notNull(outputType, "outputType must not be null");
// Both bounds are exclusive: the two precedence constants themselves are rejected.
Assert.isTrue(advisorOrder > BaseAdvisor.HIGHEST_PRECEDENCE && advisorOrder < BaseAdvisor.LOWEST_PRECEDENCE,
"advisorOrder must be between HIGHEST_PRECEDENCE and LOWEST_PRECEDENCE");
// Zero is accepted as a valid number of repeat attempts.
Assert.isTrue(repeatAttempts >= 0, "repeatAttempts must be greater than or equal to 0");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a value of zero effectively makes this infinite, given the do {} while() construct. Is this what we want?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the repeatAttempts is set to 0 this effectively mean never repeat.
E.g. the loop is count = 0; do { count++ ...} while (count <= maxAttempt); and if the maxAttempt is 0 the loop should run only once?


this.advisorOrder = advisorOrder;

// A new validator instance is created for every advisor instance.
this.jsonvalidator = new DefaultJsonSchemaValidator();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can be initialised above, and maybe even made static AFAICT

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I wonted to make the objectmaper configurable and aligned with the default mapper used in spring ai (JsonParser.getObjectMapper()) but forgot to add it.
Will update the code


// Generate the schema text for the expected output type.
String jsonSchemaText = JsonSchemaGenerator.generateForType(outputType);

// NOTE(review): the generated schema is logged at INFO level each time an advisor
// is constructed; confirm that this level is intended.
logger.info("Generated JSON Schema:\n" + jsonSchemaText);

var jsonMapper = new JacksonMcpJsonMapper(JsonParser.getObjectMapper());

// Parse the schema text into the map form that is later handed to the validator.
try {
this.jsonSchema = jsonMapper.readValue(jsonSchemaText, MAP_TYPE_REF);
}
catch (Exception e) {
// Any parsing failure is reported as an invalid argument, keeping the cause.
throw new IllegalArgumentException("Failed to parse JSON schema", e);
}

this.maxRepeatAttempts = repeatAttempts;
}

/**
 * Returns the fixed display name of this advisor.
 * @return the advisor name
 */
@SuppressWarnings("null")
@Override
public String getName() {
	final String advisorName = "Structured Output Validation Advisor";
	return advisorName;
}

/**
 * Returns the order value this advisor was configured with.
 * @return the advisor order
 */
@Override
public int getOrder() {
	final int configuredOrder = this.advisorOrder;
	return configuredOrder;
}

/**
* Calls the remainder of the advisor chain and validates the returned output against
* the JSON schema, repeating the call while validation fails.
* <p>
* Each iteration builds a new chain that holds only the advisors positioned after this
* advisor (via AdvisorUtils.copyChainAfterAdvisor) and invokes it with the current
* request. A response that carries tool calls is not validated; in that case the
* validation flag keeps the value it had before (initially true).
* <p>
* When validation fails, the validator error message is appended to the user message
* of the original request, so messages from earlier failed attempts do not
* accumulate in the prompt. The counter is incremented before every call and the loop
* continues while validation failed and the counter is less than or equal to
* maxRepeatAttempts, so the chain is invoked at most maxRepeatAttempts + 1 times and
* exactly once when maxRepeatAttempts is 0.
* @param chatClientRequest the request to send down the chain; must not be null
* @param callAdvisorChain the chain that contains this advisor; must not be null
* @return the response of the last call, which may still have failed validation when
* all attempts are used up
*/
@SuppressWarnings("null")
@Override
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) {
Assert.notNull(callAdvisorChain, "callAdvisorChain must not be null");
Assert.notNull(chatClientRequest, "chatClientRequest must not be null");

ChatClientResponse chatClientResponse = null;

// Number of chain invocations performed so far.
var repeatCounter = new AtomicInteger(0);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not a regular int ???

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you're right. Not sure why i thought that some concurrency can occur. will make it int


boolean isValidationSuccess = true;

var processedChatClientRequest = chatClientRequest;

do {
// Before Call
repeatCounter.incrementAndGet();

// Next Call
chatClientResponse = AdvisorUtils.copyChainAfterAdvisor(callAdvisorChain, this)
.nextCall(processedChatClientRequest);

// After Call

// We should not validate tool call requests, only the content of the final
// response.
if (chatClientResponse.chatResponse() == null || !chatClientResponse.chatResponse().hasToolCalls()) {

ValidationResponse validationResponse = this.validateOutputSchema(chatClientResponse);

isValidationSuccess = validationResponse.valid();

if (!isValidationSuccess) {

// Add the validation error message to the next user message
// to let the LLM fix its output.
// Note: We could also consider adding the previous invalid output.
// However, this might lead to confusion and more complex prompts.
// Instead, we rely on the LLM to generate a new output based on the
// validation error.
logger.warn("JSON validation failed: " + validationResponse);

String validationErrorMessage = "Output JSON validation failed because of: "
+ validationResponse.errorMessage();

Prompt augmentedPrompt = chatClientRequest.prompt()
.augmentUserMessage(userMessage -> userMessage.mutate()
.text(userMessage.getText() + System.lineSeparator() + validationErrorMessage)
.build());

processedChatClientRequest = chatClientRequest.mutate().prompt(augmentedPrompt).build();
}
}
}
while (!isValidationSuccess && repeatCounter.get() <= this.maxRepeatAttempts);

return chatClientResponse;
}

/**
 * Validates the output text of the given response against the JSON schema held by
 * this advisor.
 * <p>
 * The text is read from {@code chatResponse().getResult().getOutput().getText()}. If
 * any element of that chain is {@code null}, the response is reported as invalid
 * without calling the validator.
 * @param chatClientResponse the response whose output text is validated
 * @return the validator response for the output text, or an invalid response when
 * the output text is missing
 */
@SuppressWarnings("null")
private ValidationResponse validateOutputSchema(ChatClientResponse chatClientResponse) {

	// Walk the accessor chain once, stopping at the first missing element.
	var chatResponse = chatClientResponse.chatResponse();
	var result = (chatResponse != null) ? chatResponse.getResult() : null;
	var output = (result != null) ? result.getOutput() : null;
	String json = (output != null) ? output.getText() : null;

	if (json == null) {
		logger.warn("ChatClientResponse is missing required json output for validation.");
		return ValidationResponse.asInvalid("Missing required json output for validation.");
	}

	// TODO: should we consider validation for multiple results?

	// This method only knows the configured maximum, not how many attempts remain,
	// so the message reports the maximum instead of claiming a remaining count.
	logger.debug("Validating JSON output against schema. Max repeat attempts: {}", this.maxRepeatAttempts);

	return this.jsonvalidator.validate(this.jsonSchema, json);
}

/**
 * Rejects streaming usage: no call is made on the given chain and the returned Flux
 * signals an error.
 * @param chatClientRequest the request; not used
 * @param streamAdvisorChain the stream advisor chain; not used
 * @return a Flux that signals an UnsupportedOperationException
 */
@SuppressWarnings("null")
@Override
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
		StreamAdvisorChain streamAdvisorChain) {

	final var streamingNotSupported = new UnsupportedOperationException(
			"The Structured Output Validation Advisor does not support streaming.");
	return Flux.error(streamingNotSupported);
}

/**
 * Entry point for configuring a StructuredOutputValidationAdvisor.
 * @return a fresh Builder holding the default settings
 */
public static Builder builder() {
	final Builder freshBuilder = new Builder();
	return freshBuilder;
}

/**
* Builder class for StructuredOutputValidationAdvisor.
*/
public final static class Builder {

/**
* Set the order close to Ordered.LOWEST_PRECEDENCE to ensure an advisor is
* executed toward the last (but before the model call) in the chain (last for
* request processing, first for response processing).
*
* https://docs.spring.io/spring-ai/reference/api/advisors.html#_advisor_order
*/
private int advisorOrder = BaseAdvisor.LOWEST_PRECEDENCE - 2000;

private Type outputType;

private int maxRepeatAttempts = 3;

private Builder() {
}

/**
* Sets the advisor order.
* @param advisorOrder the advisor order
* @return this builder
*/
public Builder advisorOrder(int advisorOrder) {
this.advisorOrder = advisorOrder;
return this;
}

/**
* Sets the output type using a Type.
* @param outputType the output type
* @return this builder
*/
public Builder outputType(Type outputType) {
this.outputType = outputType;
return this;
}

/**
* Sets the output type using a TypeRef.
* @param <T> the type parameter
* @param outputType the output type
* @return this builder
*/
public <T> Builder outputType(TypeRef<T> outputType) {
this.outputType = outputType.getType();
return this;
}

/**
* Sets the output type using a TypeReference.
* @param <T> the type parameter
* @param outputType the output type
* @return this builder
*/
public <T> Builder outputType(TypeReference<T> outputType) {
this.outputType = outputType.getType();
return this;
}

/**
* Sets the output type using a ParameterizedTypeReference.
* @param <T> the type parameter
* @param outputType the output type
* @return this builder
*/
public <T> Builder outputType(ParameterizedTypeReference<T> outputType) {
this.outputType = outputType.getType();
return this;
}

/**
* Sets the number of repeat attempts.
* @param repeatAttempts the number of repeat attempts
* @return this builder
*/
public Builder maxRepeatAttempts(int repeatAttempts) {
this.maxRepeatAttempts = repeatAttempts;
return this;
}

/**
* Builds the StructuredOutputValidationAdvisor.
* @return a new StructuredOutputValidationAdvisor instance
* @throws IllegalArgumentException if outputType is not set
*/
public StructuredOutputValidationAdvisor build() {
if (this.outputType == null) {
throw new IllegalArgumentException("outputType must be set");
}
return new StructuredOutputValidationAdvisor(this.advisorOrder, this.outputType, this.maxRepeatAttempts);
}

}

}
Loading