Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.springframework.ai.anthropic.api.AnthropicApi.Role;
import org.springframework.ai.anthropic.api.AnthropicCacheOptions;
import org.springframework.ai.anthropic.api.AnthropicCacheTtl;
import org.springframework.ai.anthropic.api.CitationDocument;
import org.springframework.ai.anthropic.api.utils.CacheEligibilityResolver;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
Expand Down Expand Up @@ -322,12 +323,13 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion, Usage

List<Generation> generations = new ArrayList<>();
List<AssistantMessage.ToolCall> toolCalls = new ArrayList<>();
CitationContext citationContext = new CitationContext();
for (ContentBlock content : chatCompletion.content()) {
switch (content.type()) {
case TEXT, TEXT_DELTA:
generations.add(new Generation(
AssistantMessage.builder().content(content.text()).properties(Map.of()).build(),
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()));
Generation textGeneration = processTextContent(content, chatCompletion.stopReason(),
citationContext);
generations.add(textGeneration);
break;
case THINKING, THINKING_DELTA:
Map<String, Object> thinkingProperties = new HashMap<>();
Expand Down Expand Up @@ -371,7 +373,101 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion, Usage
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build());
generations.add(toolCallGeneration);
}
return new ChatResponse(generations, this.from(chatCompletion, usage));

// Create response metadata with citation information if present
ChatResponseMetadata.Builder metadataBuilder = ChatResponseMetadata.builder()
.id(chatCompletion.id())
.model(chatCompletion.model())
.usage(usage)
.keyValue("stop-reason", chatCompletion.stopReason())
.keyValue("stop-sequence", chatCompletion.stopSequence())
.keyValue("type", chatCompletion.type());

// Add citation metadata if citations were found
if (citationContext.hasCitations()) {
metadataBuilder.keyValue("citations", citationContext.getAllCitations())
.keyValue("citationCount", citationContext.getTotalCitationCount());
}

ChatResponseMetadata responseMetadata = metadataBuilder.build();

return new ChatResponse(generations, responseMetadata);
}

private Generation processTextContent(ContentBlock content, String stopReason, CitationContext citationContext) {
// Extract citations if present in the content block
if (content.citations() instanceof List) {
try {
@SuppressWarnings("unchecked")
List<Object> citationObjects = (List<Object>) content.citations();

List<Citation> citations = new ArrayList<>();
for (Object citationObj : citationObjects) {
if (citationObj instanceof Map) {
// Convert Map to CitationResponse using manual parsing
AnthropicApi.CitationResponse citationResponse = parseCitationFromMap((Map<?, ?>) citationObj);
citations.add(convertToCitation(citationResponse));
}
else {
logger.warn("Unexpected citation object type: {}. Expected Map but got: {}. Skipping citation.",
citationObj.getClass().getName(), citationObj);
}
}

if (!citations.isEmpty()) {
citationContext.addCitations(citations);
}

}
catch (Exception e) {
logger.warn("Failed to parse citations from content block", e);
}
}

return new Generation(new AssistantMessage(content.text()),
ChatGenerationMetadata.builder().finishReason(stopReason).build());
}

/**
* Parse citation data from Map (typically from JSON deserialization). Assumes all
* required fields are present and of correct types.
* @param citationMap the map containing citation data from API response
* @return parsed CitationResponse
*/
private AnthropicApi.CitationResponse parseCitationFromMap(Map<?, ?> citationMap) {
String type = (String) citationMap.get("type");
String citedText = (String) citationMap.get("cited_text");
Integer documentIndex = (Integer) citationMap.get("document_index");
String documentTitle = (String) citationMap.get("document_title");

Integer startCharIndex = (Integer) citationMap.get("start_char_index");
Integer endCharIndex = (Integer) citationMap.get("end_char_index");
Integer startPageNumber = (Integer) citationMap.get("start_page_number");
Integer endPageNumber = (Integer) citationMap.get("end_page_number");
Integer startBlockIndex = (Integer) citationMap.get("start_block_index");
Integer endBlockIndex = (Integer) citationMap.get("end_block_index");

return new AnthropicApi.CitationResponse(type, citedText, documentIndex, documentTitle, startCharIndex,
endCharIndex, startPageNumber, endPageNumber, startBlockIndex, endBlockIndex);
}

/**
* Convert CitationResponse to Citation object. This method handles the conversion to
* avoid circular dependencies.
*/
private Citation convertToCitation(AnthropicApi.CitationResponse citationResponse) {
return switch (citationResponse.type()) {
case "char_location" -> Citation.ofCharLocation(citationResponse.citedText(),
citationResponse.documentIndex(), citationResponse.documentTitle(),
citationResponse.startCharIndex(), citationResponse.endCharIndex());
case "page_location" -> Citation.ofPageLocation(citationResponse.citedText(),
citationResponse.documentIndex(), citationResponse.documentTitle(),
citationResponse.startPageNumber(), citationResponse.endPageNumber());
case "content_block_location" -> Citation.ofContentBlockLocation(citationResponse.citedText(),
citationResponse.documentIndex(), citationResponse.documentTitle(),
citationResponse.startBlockIndex(), citationResponse.endBlockIndex());
default -> throw new IllegalArgumentException("Unknown citation type: " + citationResponse.type());
};
}

private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) {
Expand Down Expand Up @@ -479,13 +575,22 @@ Prompt buildRequestPrompt(Prompt prompt) {
// Merge cache options that are Json-ignored
requestOptions.setCacheOptions(runtimeOptions.getCacheOptions() != null ? runtimeOptions.getCacheOptions()
: this.defaultOptions.getCacheOptions());

// Merge citation documents that are Json-ignored
if (runtimeOptions.getCitationDocuments() != null && !runtimeOptions.getCitationDocuments().isEmpty()) {
requestOptions.setCitationDocuments(runtimeOptions.getCitationDocuments());
}
else if (this.defaultOptions.getCitationDocuments() != null) {
requestOptions.setCitationDocuments(this.defaultOptions.getCitationDocuments());
}
}
else {
requestOptions.setHttpHeaders(this.defaultOptions.getHttpHeaders());
requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled());
requestOptions.setToolNames(this.defaultOptions.getToolNames());
requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks());
requestOptions.setToolContext(this.defaultOptions.getToolContext());
requestOptions.setCitationDocuments(this.defaultOptions.getCitationDocuments());
}

ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks());
Expand Down Expand Up @@ -610,12 +715,24 @@ private List<AnthropicMessage> buildMessages(Prompt prompt, CacheEligibilityReso
}
}

// Get citation documents from options
List<CitationDocument> citationDocuments = null;
if (prompt.getOptions() instanceof AnthropicChatOptions anthropicOptions) {
citationDocuments = anthropicOptions.getCitationDocuments();
}

List<AnthropicMessage> result = new ArrayList<>();
for (int i = 0; i < allMessages.size(); i++) {
Message message = allMessages.get(i);
MessageType messageType = message.getMessageType();
if (messageType == MessageType.USER) {
List<ContentBlock> contentBlocks = new ArrayList<>();
// Add citation documents to the FIRST user message only
if (i == 0 && citationDocuments != null && !citationDocuments.isEmpty()) {
for (CitationDocument doc : citationDocuments) {
contentBlocks.add(doc.toContentBlock());
}
}
String content = message.getText();
// For conversation history caching, apply cache control to the
// message immediately before the last user message.
Expand Down Expand Up @@ -823,4 +940,30 @@ public AnthropicChatModel build() {

}

/**
* Context object for tracking citations during response processing. Aggregates
* citations from multiple content blocks in a single response.
*/
class CitationContext {

private final List<Citation> allCitations = new ArrayList<>();

public void addCitations(List<Citation> citations) {
this.allCitations.addAll(citations);
}

public boolean hasCitations() {
return !this.allCitations.isEmpty();
}

public List<Citation> getAllCitations() {
return new ArrayList<>(this.allCitations);
}

public int getTotalCitationCount() {
return this.allCitations.size();
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.springframework.ai.anthropic.api.AnthropicApi;
import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest;
import org.springframework.ai.anthropic.api.AnthropicCacheOptions;
import org.springframework.ai.anthropic.api.CitationDocument;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.lang.Nullable;
Expand Down Expand Up @@ -63,6 +64,17 @@ public class AnthropicChatOptions implements ToolCallingChatOptions {
private @JsonProperty("tool_choice") AnthropicApi.ToolChoice toolChoice;
private @JsonProperty("thinking") ChatCompletionRequest.ThinkingConfig thinking;

/**
* Documents to be used for citation-based responses. These documents will be
* converted to ContentBlocks and included in the first user message of the request.
* Citations indicating which parts of these documents were used in the response will
* be returned in the response metadata under the "citations" key.
* @see CitationDocument
* @see Citation
*/
@JsonIgnore
private List<CitationDocument> citationDocuments = new ArrayList<>();

@JsonIgnore
private AnthropicCacheOptions cacheOptions = AnthropicCacheOptions.DISABLED;

Expand Down Expand Up @@ -127,6 +139,8 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions)
.toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null)
.httpHeaders(fromOptions.getHttpHeaders() != null ? new HashMap<>(fromOptions.getHttpHeaders()) : null)
.cacheOptions(fromOptions.getCacheOptions())
.citationDocuments(fromOptions.getCitationDocuments() != null
? new ArrayList<>(fromOptions.getCitationDocuments()) : null)
.build();
}

Expand Down Expand Up @@ -283,6 +297,34 @@ public void setHttpHeaders(Map<String, String> httpHeaders) {
this.httpHeaders = httpHeaders;
}

public List<CitationDocument> getCitationDocuments() {
return this.citationDocuments;
}

public void setCitationDocuments(List<CitationDocument> citationDocuments) {
Assert.notNull(citationDocuments, "Citation documents cannot be null");
this.citationDocuments = citationDocuments;
}

/**
* Validate that all citation documents have consistent citation settings. Anthropic
* requires all documents to have citations enabled if any do.
*/
public void validateCitationConsistency() {
if (this.citationDocuments.isEmpty()) {
return;
}

boolean hasEnabledCitations = this.citationDocuments.stream().anyMatch(CitationDocument::isCitationsEnabled);
boolean hasDisabledCitations = this.citationDocuments.stream().anyMatch(doc -> !doc.isCitationsEnabled());

if (hasEnabledCitations && hasDisabledCitations) {
throw new IllegalArgumentException(
"Anthropic Citations API requires all documents to have consistent citation settings. "
+ "Either enable citations for all documents or disable for all documents.");
}
}

@Override
@SuppressWarnings("unchecked")
public AnthropicChatOptions copy() {
Expand All @@ -308,14 +350,16 @@ public boolean equals(Object o) {
&& Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled)
&& Objects.equals(this.toolContext, that.toolContext)
&& Objects.equals(this.httpHeaders, that.httpHeaders)
&& Objects.equals(this.cacheOptions, that.cacheOptions);
&& Objects.equals(this.cacheOptions, that.cacheOptions)
&& Objects.equals(this.citationDocuments, that.citationDocuments);
}

@Override
public int hashCode() {
return Objects.hash(this.model, this.maxTokens, this.metadata, this.stopSequences, this.temperature, this.topP,
this.topK, this.toolChoice, this.thinking, this.toolCallbacks, this.toolNames,
this.internalToolExecutionEnabled, this.toolContext, this.httpHeaders, this.cacheOptions);
this.internalToolExecutionEnabled, this.toolContext, this.httpHeaders, this.cacheOptions,
this.citationDocuments);
}

public static final class Builder {
Expand Down Expand Up @@ -425,7 +469,40 @@ public Builder cacheOptions(AnthropicCacheOptions cacheOptions) {
return this;
}

/**
* Set citation documents for the request.
* @param citationDocuments List of documents to include for citations
* @return Builder for method chaining
*/
public Builder citationDocuments(List<CitationDocument> citationDocuments) {
this.options.setCitationDocuments(citationDocuments);
return this;
}

/**
* Set citation documents from variable arguments.
* @param documents Variable number of CitationDocument objects
* @return Builder for method chaining
*/
public Builder citationDocuments(CitationDocument... documents) {
Assert.notNull(documents, "Citation documents cannot be null");
this.options.citationDocuments.addAll(Arrays.asList(documents));
return this;
}

/**
* Add a single citation document.
* @param document Citation document to add
* @return Builder for method chaining
*/
public Builder addCitationDocument(CitationDocument document) {
Assert.notNull(document, "Citation document cannot be null");
this.options.citationDocuments.add(document);
return this;
}

public AnthropicChatOptions build() {
this.options.validateCitationConsistency();
return this.options;
}

Expand Down
Loading