diff --git a/spring-ai-core/src/main/java/org/springframework/ai/transformer/SummaryMetadataEnricher.java b/spring-ai-core/src/main/java/org/springframework/ai/transformer/SummaryMetadataEnricher.java index f0a607dbda0..8bf97597388 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/transformer/SummaryMetadataEnricher.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/transformer/SummaryMetadataEnricher.java @@ -16,7 +16,6 @@ package org.springframework.ai.transformer; -import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -94,33 +93,60 @@ public SummaryMetadataEnricher(ChatClient chatClient, List summaryT @Override public List apply(List documents) { + List documentSummaries = summaryDocuments(documents); - List documentSummaries = new ArrayList<>(); - for (Document document : documents) { + for (int i = 0; i < documentSummaries.size(); i++) { + Map summaryMetadata = extractSummaryMetadata(documentSummaries, i); + documents.get(i).getMetadata().putAll(summaryMetadata); + } - var documentContext = document.getFormattedContent(this.metadataMode); + return documents; + } - Prompt prompt = new PromptTemplate(this.summaryTemplate) - .create(Map.of(CONTEXT_STR_PLACEHOLDER, documentContext)); - documentSummaries.add(this.chatClient.call(prompt).getResult().getOutput().getContent()); - } + private List summaryDocuments(List documents) { + return documents.stream() + .map(this::summarySingleDocument) + .toList(); + } - for (int i = 0; i < documentSummaries.size(); i++) { - Map summaryMetadata = new HashMap<>(); - if (i > 0 && this.summaryTypes.contains(SummaryType.PREVIOUS)) { - summaryMetadata.put(PREV_SECTION_SUMMARY_METADATA_KEY, documentSummaries.get(i - 1)); - } - if (i < (documentSummaries.size() - 1) && this.summaryTypes.contains(SummaryType.NEXT)) { - summaryMetadata.put(NEXT_SECTION_SUMMARY_METADATA_KEY, documentSummaries.get(i + 1)); - } - if (this.summaryTypes.contains(SummaryType.CURRENT)) { - summaryMetadata.put(SECTION_SUMMARY_METADATA_KEY, documentSummaries.get(i)); - } + private Map extractSummaryMetadata(List documentSummaries, int nowIndex) { + final int prevIndex = nowIndex - 1; + final int nextIndex = nowIndex + 1; - documents.get(i).getMetadata().putAll(summaryMetadata); + Map summaryMetadata = new HashMap<>(); + + if (nowIndex > 0 && this.summaryTypes.contains(SummaryType.PREVIOUS)) { + summaryMetadata.put(PREV_SECTION_SUMMARY_METADATA_KEY, documentSummaries.get(prevIndex)); + } + if (nowIndex < (documentSummaries.size() - 1) && this.summaryTypes.contains(SummaryType.NEXT)) { + summaryMetadata.put(NEXT_SECTION_SUMMARY_METADATA_KEY, documentSummaries.get(nextIndex)); + } + if (this.summaryTypes.contains(SummaryType.CURRENT)) { + summaryMetadata.put(SECTION_SUMMARY_METADATA_KEY, documentSummaries.get(nowIndex)); } - return documents; + return summaryMetadata; + } + + private String summarySingleDocument(Document document) { + var documentContext = document.getFormattedContent(this.metadataMode); + + Prompt prompt = createPromptByContext(documentContext); + + return getContentFromPrompt(prompt); + } + + private Prompt createPromptByContext(String documentContext) { + return new PromptTemplate(this.summaryTemplate) + .create(Map.of(CONTEXT_STR_PLACEHOLDER, documentContext)); + } + + private String getContentFromPrompt(Prompt prompt) { + return this.chatClient + .call(prompt) + .getResult() + .getOutput() + .getContent(); } }