diff --git a/document-readers/pdf-reader/pom.xml b/document-readers/pdf-reader/pom.xml index eace8bd6d2b..1a44d4ef382 100644 --- a/document-readers/pdf-reader/pom.xml +++ b/document-readers/pdf-reader/pom.xml @@ -36,6 +36,10 @@ git@github.com:spring-projects/spring-ai.git + + false + + org.springframework.ai diff --git a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/PagePdfDocumentReader.java b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/PagePdfDocumentReader.java index 11fb9933030..724f7ee5c0a 100644 --- a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/PagePdfDocumentReader.java +++ b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/PagePdfDocumentReader.java @@ -16,7 +16,7 @@ package org.springframework.ai.reader.pdf; -import java.awt.*; +import java.awt.Rectangle; import java.io.IOException; import java.util.ArrayList; import java.util.List; @@ -112,7 +112,7 @@ public List get() { for (PDPage page : this.document.getDocumentCatalog().getPages()) { lastPage = page; if (counter % logFrequency == 0 && counter / logFrequency < 10) { - this.logger.info("Processing PDF page: {}", (counter + 1)); + logger.info("Processing PDF page: {}", (counter + 1)); } counter++; @@ -154,7 +154,7 @@ public List get() { readDocuments.add(toDocument(lastPage, pageTextGroupList.stream().collect(Collectors.joining()), startPageNumber, pageNumber)); } - this.logger.info("Processing {} pages", totalPages); + logger.info("Processing {} pages", totalPages); return readDocuments; } diff --git a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/ParagraphPdfDocumentReader.java b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/ParagraphPdfDocumentReader.java index a5943d45d36..cb657238e89 100644 --- a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/ParagraphPdfDocumentReader.java +++ b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/ParagraphPdfDocumentReader.java @@ -16,7 +16,7 @@ package org.springframework.ai.reader.pdf; -import java.awt.*; +import java.awt.Rectangle; import java.util.ArrayList; import java.util.Iterator; import java.util.List; @@ -133,7 +133,7 @@ public List get() { List documents = new ArrayList<>(paragraphs.size()); if (!CollectionUtils.isEmpty(paragraphs)) { - this.logger.info("Start processing paragraphs from PDF"); + logger.info("Start processing paragraphs from PDF"); Iterator itr = paragraphs.iterator(); var current = itr.next(); @@ -152,7 +152,7 @@ public List get() { } } } - this.logger.info("End processing paragraphs from PDF"); + logger.info("End processing paragraphs from PDF"); return documents; } diff --git a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/aot/PdfReaderRuntimeHints.java b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/aot/PdfReaderRuntimeHints.java index ae5b8588fed..4409a8af1e7 100644 --- a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/aot/PdfReaderRuntimeHints.java +++ b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/aot/PdfReaderRuntimeHints.java @@ -45,9 +45,11 @@ public void registerHints(RuntimeHints hints, ClassLoader classLoader) { "/org/apache/pdfbox/resources/icc/**", "/org/apache/pdfbox/resources/text/**", "/org/apache/pdfbox/resources/ttf/**", "/org/apache/pdfbox/resources/version.properties"); - for (var pattern : patterns) - for (var resourceMatch : resolver.getResources(pattern)) + for (var pattern : patterns) { + for (var resourceMatch : resolver.getResources(pattern)) { hints.resources().registerResource(resourceMatch); + } + } } catch (IOException e) { diff --git a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/config/PdfDocumentReaderConfig.java b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/config/PdfDocumentReaderConfig.java index b80ff8e9bb3..d3f804ef51c 100644 --- a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/config/PdfDocumentReaderConfig.java +++ b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/config/PdfDocumentReaderConfig.java @@ -27,7 +27,7 @@ * * @author Christian Tzolov */ -public class PdfDocumentReaderConfig { +public final class PdfDocumentReaderConfig { public static final int ALL_PAGES = 0; @@ -65,7 +65,7 @@ public static PdfDocumentReaderConfig defaultConfig() { return builder().build(); } - public static class Builder { + public static final class Builder { private int pagesPerDocument = 1; diff --git a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/layout/Character.java b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/layout/Character.java new file mode 100644 index 00000000000..b8b29f1aa96 --- /dev/null +++ b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/layout/Character.java @@ -0,0 +1,87 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.reader.pdf.layout; + +class Character { + + private char characterValue; + + private int index; + + private boolean isCharacterPartOfPreviousWord; + + private boolean isFirstCharacterOfAWord; + + private boolean isCharacterAtTheBeginningOfNewLine; + + private boolean isCharacterCloseToPreviousWord; + + Character(char characterValue, int index, boolean isCharacterPartOfPreviousWord, boolean isFirstCharacterOfAWord, + boolean isCharacterAtTheBeginningOfNewLine, boolean isCharacterPartOfASentence) { + this.characterValue = characterValue; + this.index = index; + this.isCharacterPartOfPreviousWord = isCharacterPartOfPreviousWord; + this.isFirstCharacterOfAWord = isFirstCharacterOfAWord; + this.isCharacterAtTheBeginningOfNewLine = isCharacterAtTheBeginningOfNewLine; + this.isCharacterCloseToPreviousWord = isCharacterPartOfASentence; + if (ForkPDFLayoutTextStripper.DEBUG) { + System.out.println(this.toString()); + } + } + + public char getCharacterValue() { + return this.characterValue; + } + + public int getIndex() { + return this.index; + } + + public void setIndex(int index) { + this.index = index; + } + + public boolean isCharacterPartOfPreviousWord() { + return this.isCharacterPartOfPreviousWord; + } + + public boolean isFirstCharacterOfAWord() { + return this.isFirstCharacterOfAWord; + } + + public boolean isCharacterAtTheBeginningOfNewLine() { + return this.isCharacterAtTheBeginningOfNewLine; + } + + public boolean isCharacterCloseToPreviousWord() { + return this.isCharacterCloseToPreviousWord; + } + + public String toString() { + String toString = ""; + toString += this.index; + toString += " "; + toString += this.characterValue; + toString += " isCharacterPartOfPreviousWord=" + this.isCharacterPartOfPreviousWord; + toString += " isFirstCharacterOfAWord=" + this.isFirstCharacterOfAWord; + toString += " isCharacterAtTheBeginningOfNewLine=" + this.isCharacterAtTheBeginningOfNewLine; + toString += " isCharacterPartOfASentence=" + this.isCharacterCloseToPreviousWord; + toString += " isCharacterCloseToPreviousWord=" + this.isCharacterCloseToPreviousWord; + return toString; + } + +} diff --git a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/layout/CharacterFactory.java b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/layout/CharacterFactory.java new file mode 100644 index 00000000000..b3e491d1398 --- /dev/null +++ b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/layout/CharacterFactory.java @@ -0,0 +1,109 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.reader.pdf.layout; + +import org.apache.pdfbox.text.TextPosition; + +class CharacterFactory { + + private TextPosition previousTextPosition; + + private boolean firstCharacterOfLineFound; + + private boolean isCharacterPartOfPreviousWord; + + private boolean isFirstCharacterOfAWord; + + private boolean isCharacterAtTheBeginningOfNewLine; + + private boolean isCharacterCloseToPreviousWord; + + CharacterFactory(boolean firstCharacterOfLineFound) { + this.firstCharacterOfLineFound = firstCharacterOfLineFound; + } + + public Character createCharacterFromTextPosition(final TextPosition textPosition, + final TextPosition previousTextPosition) { + this.setPreviousTextPosition(previousTextPosition); + this.isCharacterPartOfPreviousWord = this.isCharacterPartOfPreviousWord(textPosition); + this.isFirstCharacterOfAWord = this.isFirstCharacterOfAWord(textPosition); + this.isCharacterAtTheBeginningOfNewLine = this.isCharacterAtTheBeginningOfNewLine(textPosition); + this.isCharacterCloseToPreviousWord = this.isCharacterCloseToPreviousWord(textPosition); + char character = this.getCharacterFromTextPosition(textPosition); + int index = (int) textPosition.getX() / ForkPDFLayoutTextStripper.OUTPUT_SPACE_CHARACTER_WIDTH_IN_PT; + return new Character(character, index, this.isCharacterPartOfPreviousWord, this.isFirstCharacterOfAWord, + this.isCharacterAtTheBeginningOfNewLine, this.isCharacterCloseToPreviousWord); + } + + private boolean isCharacterAtTheBeginningOfNewLine(final TextPosition textPosition) { + if (!this.firstCharacterOfLineFound) { + return true; + } + TextPosition previousTextPosition = this.getPreviousTextPosition(); + float previousTextYPosition = previousTextPosition.getY(); + return (Math.round(textPosition.getY()) < Math.round(previousTextYPosition)); + } + + private boolean isFirstCharacterOfAWord(final TextPosition textPosition) { + if (!this.firstCharacterOfLineFound) { + return true; + } + double numberOfSpaces = this.numberOfSpacesBetweenTwoCharacters(this.previousTextPosition, textPosition); + return (numberOfSpaces > 1) || this.isCharacterAtTheBeginningOfNewLine(textPosition); + } + + private boolean isCharacterCloseToPreviousWord(final TextPosition textPosition) { + if (!this.firstCharacterOfLineFound) { + return false; + } + double numberOfSpaces = this.numberOfSpacesBetweenTwoCharacters(this.previousTextPosition, textPosition); + return (numberOfSpaces > 1 && numberOfSpaces <= ForkPDFLayoutTextStripper.OUTPUT_SPACE_CHARACTER_WIDTH_IN_PT); + } + + private boolean isCharacterPartOfPreviousWord(final TextPosition textPosition) { + TextPosition previousTextPosition = this.getPreviousTextPosition(); + if (previousTextPosition.getUnicode().equals(" ")) { + return false; + } + double numberOfSpaces = this.numberOfSpacesBetweenTwoCharacters(previousTextPosition, textPosition); + return (numberOfSpaces <= 1); + } + + private double numberOfSpacesBetweenTwoCharacters(final TextPosition textPosition1, + final TextPosition textPosition2) { + double previousTextXPosition = textPosition1.getX(); + double previousTextWidth = textPosition1.getWidth(); + double previousTextEndXPosition = (previousTextXPosition + previousTextWidth); + double numberOfSpaces = Math.abs(Math.round(textPosition2.getX() - previousTextEndXPosition)); + return numberOfSpaces; + } + + private char getCharacterFromTextPosition(final TextPosition textPosition) { + String string = textPosition.getUnicode(); + char character = string.charAt(0); + return character; + } + + private TextPosition getPreviousTextPosition() { + return this.previousTextPosition; + } + + private void setPreviousTextPosition(final TextPosition previousTextPosition) { + this.previousTextPosition = previousTextPosition; + } + +} diff --git a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/layout/ForkPDFLayoutTextStripper.java b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/layout/ForkPDFLayoutTextStripper.java index ea1980ff667..abb32e140cd 100644 --- a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/layout/ForkPDFLayoutTextStripper.java +++ b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/layout/ForkPDFLayoutTextStripper.java @@ -13,17 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -/* Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ package org.springframework.ai.reader.pdf.layout; @@ -217,274 +206,3 @@ private List getTextLineList() { } } - -class TextLine { - - private static final char SPACE_CHARACTER = ' '; - - private int lineLength; - - private String line; - - private int lastIndex; - - public TextLine(int lineLength) { - this.line = ""; - this.lineLength = lineLength / ForkPDFLayoutTextStripper.OUTPUT_SPACE_CHARACTER_WIDTH_IN_PT; - this.completeLineWithSpaces(); - } - - public void writeCharacterAtIndex(final Character character) { - character.setIndex(this.computeIndexForCharacter(character)); - int index = character.getIndex(); - char characterValue = character.getCharacterValue(); - if (this.indexIsInBounds(index) && this.line.charAt(index) == SPACE_CHARACTER) { - this.line = this.line.substring(0, index) + characterValue - + this.line.substring(index + 1, this.getLineLength()); - } - } - - public int getLineLength() { - return this.lineLength; - } - - public String getLine() { - return this.line; - } - - private int computeIndexForCharacter(final Character character) { - int index = character.getIndex(); - boolean isCharacterPartOfPreviousWord = character.isCharacterPartOfPreviousWord(); - boolean isCharacterAtTheBeginningOfNewLine = character.isCharacterAtTheBeginningOfNewLine(); - boolean isCharacterCloseToPreviousWord = character.isCharacterCloseToPreviousWord(); - - if (!this.indexIsInBounds(index)) { - return -1; - } - else { - if (isCharacterPartOfPreviousWord && !isCharacterAtTheBeginningOfNewLine) { - index = this.findMinimumIndexWithSpaceCharacterFromIndex(index); - } - else if (isCharacterCloseToPreviousWord) { - if (this.line.charAt(index) != SPACE_CHARACTER) { - index = index + 1; - } - else { - index = this.findMinimumIndexWithSpaceCharacterFromIndex(index) + 1; - } - } - index = this.getNextValidIndex(index, isCharacterPartOfPreviousWord); - return index; - } - } - - private boolean isSpaceCharacterAtIndex(int index) { - return this.line.charAt(index) != SPACE_CHARACTER; - } - - private boolean isNewIndexGreaterThanLastIndex(int index) { - int lastIndex = this.getLastIndex(); - return (index > lastIndex); - } - - private int getNextValidIndex(int index, boolean isCharacterPartOfPreviousWord) { - int nextValidIndex = index; - int lastIndex = this.getLastIndex(); - if (!this.isNewIndexGreaterThanLastIndex(index)) { - nextValidIndex = lastIndex + 1; - } - if (!isCharacterPartOfPreviousWord && this.isSpaceCharacterAtIndex(index - 1)) { - nextValidIndex = nextValidIndex + 1; - } - this.setLastIndex(nextValidIndex); - return nextValidIndex; - } - - private int findMinimumIndexWithSpaceCharacterFromIndex(int index) { - int newIndex = index; - while (newIndex >= 0 && this.line.charAt(newIndex) == SPACE_CHARACTER) { - newIndex = newIndex - 1; - } - return newIndex + 1; - } - - private boolean indexIsInBounds(int index) { - return (index >= 0 && index < this.lineLength); - } - - private void completeLineWithSpaces() { - for (int i = 0; i < this.getLineLength(); ++i) { - this.line += SPACE_CHARACTER; - } - } - - private int getLastIndex() { - return this.lastIndex; - } - - private void setLastIndex(int lastIndex) { - this.lastIndex = lastIndex; - } - -} - -class Character { - - private char characterValue; - - private int index; - - private boolean isCharacterPartOfPreviousWord; - - private boolean isFirstCharacterOfAWord; - - private boolean isCharacterAtTheBeginningOfNewLine; - - private boolean isCharacterCloseToPreviousWord; - - public Character(char characterValue, int index, boolean isCharacterPartOfPreviousWord, - boolean isFirstCharacterOfAWord, boolean isCharacterAtTheBeginningOfNewLine, - boolean isCharacterPartOfASentence) { - this.characterValue = characterValue; - this.index = index; - this.isCharacterPartOfPreviousWord = isCharacterPartOfPreviousWord; - this.isFirstCharacterOfAWord = isFirstCharacterOfAWord; - this.isCharacterAtTheBeginningOfNewLine = isCharacterAtTheBeginningOfNewLine; - this.isCharacterCloseToPreviousWord = isCharacterPartOfASentence; - if (ForkPDFLayoutTextStripper.DEBUG) { - System.out.println(this.toString()); - } - } - - public char getCharacterValue() { - return this.characterValue; - } - - public int getIndex() { - return this.index; - } - - public void setIndex(int index) { - this.index = index; - } - - public boolean isCharacterPartOfPreviousWord() { - return this.isCharacterPartOfPreviousWord; - } - - public boolean isFirstCharacterOfAWord() { - return this.isFirstCharacterOfAWord; - } - - public boolean isCharacterAtTheBeginningOfNewLine() { - return this.isCharacterAtTheBeginningOfNewLine; - } - - public boolean isCharacterCloseToPreviousWord() { - return this.isCharacterCloseToPreviousWord; - } - - public String toString() { - String toString = ""; - toString += this.index; - toString += " "; - toString += this.characterValue; - toString += " isCharacterPartOfPreviousWord=" + this.isCharacterPartOfPreviousWord; - toString += " isFirstCharacterOfAWord=" + this.isFirstCharacterOfAWord; - toString += " isCharacterAtTheBeginningOfNewLine=" + this.isCharacterAtTheBeginningOfNewLine; - toString += " isCharacterPartOfASentence=" + this.isCharacterCloseToPreviousWord; - toString += " isCharacterCloseToPreviousWord=" + this.isCharacterCloseToPreviousWord; - return toString; - } - -} - -class CharacterFactory { - - private TextPosition previousTextPosition; - - private boolean firstCharacterOfLineFound; - - private boolean isCharacterPartOfPreviousWord; - - private boolean isFirstCharacterOfAWord; - - private boolean isCharacterAtTheBeginningOfNewLine; - - private boolean isCharacterCloseToPreviousWord; - - public CharacterFactory(boolean firstCharacterOfLineFound) { - this.firstCharacterOfLineFound = firstCharacterOfLineFound; - } - - public Character createCharacterFromTextPosition(final TextPosition textPosition, - final TextPosition previousTextPosition) { - this.setPreviousTextPosition(previousTextPosition); - this.isCharacterPartOfPreviousWord = this.isCharacterPartOfPreviousWord(textPosition); - this.isFirstCharacterOfAWord = this.isFirstCharacterOfAWord(textPosition); - this.isCharacterAtTheBeginningOfNewLine = this.isCharacterAtTheBeginningOfNewLine(textPosition); - this.isCharacterCloseToPreviousWord = this.isCharacterCloseToPreviousWord(textPosition); - char character = this.getCharacterFromTextPosition(textPosition); - int index = (int) textPosition.getX() / ForkPDFLayoutTextStripper.OUTPUT_SPACE_CHARACTER_WIDTH_IN_PT; - return new Character(character, index, this.isCharacterPartOfPreviousWord, this.isFirstCharacterOfAWord, - this.isCharacterAtTheBeginningOfNewLine, this.isCharacterCloseToPreviousWord); - } - - private boolean isCharacterAtTheBeginningOfNewLine(final TextPosition textPosition) { - if (!this.firstCharacterOfLineFound) { - return true; - } - TextPosition previousTextPosition = this.getPreviousTextPosition(); - float previousTextYPosition = previousTextPosition.getY(); - return (Math.round(textPosition.getY()) < Math.round(previousTextYPosition)); - } - - private boolean isFirstCharacterOfAWord(final TextPosition textPosition) { - if (!this.firstCharacterOfLineFound) { - return true; - } - double numberOfSpaces = this.numberOfSpacesBetweenTwoCharacters(this.previousTextPosition, textPosition); - return (numberOfSpaces > 1) || this.isCharacterAtTheBeginningOfNewLine(textPosition); - } - - private boolean isCharacterCloseToPreviousWord(final TextPosition textPosition) { - if (!this.firstCharacterOfLineFound) { - return false; - } - double numberOfSpaces = this.numberOfSpacesBetweenTwoCharacters(this.previousTextPosition, textPosition); - return (numberOfSpaces > 1 && numberOfSpaces <= ForkPDFLayoutTextStripper.OUTPUT_SPACE_CHARACTER_WIDTH_IN_PT); - } - - private boolean isCharacterPartOfPreviousWord(final TextPosition textPosition) { - TextPosition previousTextPosition = this.getPreviousTextPosition(); - if (previousTextPosition.getUnicode().equals(" ")) { - return false; - } - double numberOfSpaces = this.numberOfSpacesBetweenTwoCharacters(previousTextPosition, textPosition); - return (numberOfSpaces <= 1); - } - - private double numberOfSpacesBetweenTwoCharacters(final TextPosition textPosition1, - final TextPosition textPosition2) { - double previousTextXPosition = textPosition1.getX(); - double previousTextWidth = textPosition1.getWidth(); - double previousTextEndXPosition = (previousTextXPosition + previousTextWidth); - double numberOfSpaces = Math.abs(Math.round(textPosition2.getX() - previousTextEndXPosition)); - return numberOfSpaces; - } - - private char getCharacterFromTextPosition(final TextPosition textPosition) { - String string = textPosition.getUnicode(); - char character = string.charAt(0); - return character; - } - - private TextPosition getPreviousTextPosition() { - return this.previousTextPosition; - } - - private void setPreviousTextPosition(final TextPosition previousTextPosition) { - this.previousTextPosition = previousTextPosition; - } - -} diff --git a/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/layout/TextLine.java b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/layout/TextLine.java new file mode 100644 index 00000000000..cd6e0002c43 --- /dev/null +++ b/document-readers/pdf-reader/src/main/java/org/springframework/ai/reader/pdf/layout/TextLine.java @@ -0,0 +1,127 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.reader.pdf.layout; + +class TextLine { + + private static final char SPACE_CHARACTER = ' '; + + private int lineLength; + + private String line; + + private int lastIndex; + + TextLine(int lineLength) { + this.line = ""; + this.lineLength = lineLength / ForkPDFLayoutTextStripper.OUTPUT_SPACE_CHARACTER_WIDTH_IN_PT; + this.completeLineWithSpaces(); + } + + public void writeCharacterAtIndex(final Character character) { + character.setIndex(this.computeIndexForCharacter(character)); + int index = character.getIndex(); + char characterValue = character.getCharacterValue(); + if (this.indexIsInBounds(index) && this.line.charAt(index) == SPACE_CHARACTER) { + this.line = this.line.substring(0, index) + characterValue + + this.line.substring(index + 1, this.getLineLength()); + } + } + + public int getLineLength() { + return this.lineLength; + } + + public String getLine() { + return this.line; + } + + private int computeIndexForCharacter(final Character character) { + int index = character.getIndex(); + boolean isCharacterPartOfPreviousWord = character.isCharacterPartOfPreviousWord(); + boolean isCharacterAtTheBeginningOfNewLine = character.isCharacterAtTheBeginningOfNewLine(); + boolean isCharacterCloseToPreviousWord = character.isCharacterCloseToPreviousWord(); + + if (!this.indexIsInBounds(index)) { + return -1; + } + else { + if (isCharacterPartOfPreviousWord && !isCharacterAtTheBeginningOfNewLine) { + index = this.findMinimumIndexWithSpaceCharacterFromIndex(index); + } + else if (isCharacterCloseToPreviousWord) { + if (this.line.charAt(index) != SPACE_CHARACTER) { + index = index + 1; + } + else { + index = this.findMinimumIndexWithSpaceCharacterFromIndex(index) + 1; + } + } + index = this.getNextValidIndex(index, isCharacterPartOfPreviousWord); + return index; + } + } + + private boolean isSpaceCharacterAtIndex(int index) { + return this.line.charAt(index) != SPACE_CHARACTER; + } + + private boolean isNewIndexGreaterThanLastIndex(int index) { + int lastIndex = this.getLastIndex(); + return (index > lastIndex); + } + + private int getNextValidIndex(int index, boolean isCharacterPartOfPreviousWord) { + int nextValidIndex = index; + int lastIndex = this.getLastIndex(); + if (!this.isNewIndexGreaterThanLastIndex(index)) { + nextValidIndex = lastIndex + 1; + } + if (!isCharacterPartOfPreviousWord && this.isSpaceCharacterAtIndex(index - 1)) { + nextValidIndex = nextValidIndex + 1; + } + this.setLastIndex(nextValidIndex); + return nextValidIndex; + } + + private int findMinimumIndexWithSpaceCharacterFromIndex(int index) { + int newIndex = index; + while (newIndex >= 0 && this.line.charAt(newIndex) == SPACE_CHARACTER) { + newIndex = newIndex - 1; + } + return newIndex + 1; + } + + private boolean indexIsInBounds(int index) { + return (index >= 0 && index < this.lineLength); + } + + private void completeLineWithSpaces() { + for (int i = 0; i < this.getLineLength(); ++i) { + this.line += SPACE_CHARACTER; + } + } + + private int getLastIndex() { + return this.lastIndex; + } + + private void setLastIndex(int lastIndex) { + this.lastIndex = lastIndex; + } + +} diff --git a/document-readers/pdf-reader/src/test/java/org/springframework/ai/reader/pdf/ParagraphPdfDocumentReaderTests.java b/document-readers/pdf-reader/src/test/java/org/springframework/ai/reader/pdf/ParagraphPdfDocumentReaderTests.java index 5b45f14de8a..b514f690e11 100644 --- a/document-readers/pdf-reader/src/test/java/org/springframework/ai/reader/pdf/ParagraphPdfDocumentReaderTests.java +++ b/document-readers/pdf-reader/src/test/java/org/springframework/ai/reader/pdf/ParagraphPdfDocumentReaderTests.java @@ -31,20 +31,20 @@ public class ParagraphPdfDocumentReaderTests { @Test public void testPdfWithoutToc() { - assertThatThrownBy(() -> { - - new ParagraphPdfDocumentReader("classpath:/sample1.pdf", - PdfDocumentReaderConfig.builder() - .withPageTopMargin(0) - .withPageBottomMargin(0) - .withPageExtractedTextFormatter(ExtractedTextFormatter.builder() - .withNumberOfTopTextLinesToDelete(0) - .withNumberOfBottomTextLinesToDelete(3) - .withNumberOfTopPagesToSkipBeforeDelete(0) - .build()) - .withPagesPerDocument(1) - .build()); - }).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> + + new ParagraphPdfDocumentReader("classpath:/sample1.pdf", + PdfDocumentReaderConfig.builder() + .withPageTopMargin(0) + .withPageBottomMargin(0) + .withPageExtractedTextFormatter(ExtractedTextFormatter.builder() + .withNumberOfTopTextLinesToDelete(0) + .withNumberOfBottomTextLinesToDelete(3) + .withNumberOfTopPagesToSkipBeforeDelete(0) + .build()) + .withPagesPerDocument(1) + .build())) + .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining( "Document outline (e.g. TOC) is null. Make sure the PDF document has a table of contents (TOC). If not, consider the PagePdfDocumentReader or the TikaDocumentReader instead."); diff --git a/models/spring-ai-anthropic/pom.xml b/models/spring-ai-anthropic/pom.xml index b2461539485..74663049a06 100644 --- a/models/spring-ai-anthropic/pom.xml +++ b/models/spring-ai-anthropic/pom.xml @@ -37,6 +37,10 @@ git@github.com:spring-projects/spring-ai.git + + false + + @@ -106,4 +110,4 @@ - \ No newline at end of file + diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index d67431a9854..fc4da3e8bd6 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -272,9 +272,7 @@ public Flux stream(Prompt prompt) { return Mono.just(chatResponse); }) .doOnError(observation::error) - .doFinally(s -> { - observation.stop(); - }) + .doFinally(s -> observation.stop()) .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); // @formatter:on @@ -292,10 +290,8 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion) { List generations = chatCompletion.content() .stream() .filter(content -> content.type() != ContentBlock.Type.TOOL_USE) - .map(content -> { - return new Generation(new AssistantMessage(content.text(), Map.of()), - ChatGenerationMetadata.from(chatCompletion.stopReason(), null)); - }) + .map(content -> new Generation(new AssistantMessage(content.text(), Map.of()), + ChatGenerationMetadata.from(chatCompletion.stopReason(), null))) .toList(); List allGenerations = new ArrayList<>(generations); diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/aot/AnthropicRuntimeHints.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/aot/AnthropicRuntimeHints.java index 71a47d1e0db..1f2da011942 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/aot/AnthropicRuntimeHints.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/aot/AnthropicRuntimeHints.java @@ -37,8 +37,9 @@ public class AnthropicRuntimeHints implements RuntimeHintsRegistrar { @Override public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) { var mcs = MemberCategory.values(); - for (var tr : findJsonAnnotatedClassesInPackage(AnthropicApi.class)) + for (var tr : findJsonAnnotatedClassesInPackage(AnthropicApi.class)) { hints.reflection().registerType(tr, mcs); + } } } diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java index 35fa4faf6fb..fb51c348e4a 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java @@ -256,8 +256,10 @@ public String getName() { public enum Role { // @formatter:off - @JsonProperty("user") USER, - @JsonProperty("assistant") ASSISTANT + @JsonProperty("user") + USER, + @JsonProperty("assistant") + ASSISTANT // @formatter:on } @@ -318,7 +320,7 @@ public enum EventType { /** * Artifically created event to aggregate tool use events. */ - TOOL_USE_AGGREATE; + TOOL_USE_AGGREATE } @@ -383,7 +385,8 @@ public interface StreamEvent { * optionally return results back to the model using tool_result content blocks. */ @JsonInclude(Include.NON_NULL) - public record ChatCompletionRequest( // @formatter:off + public record ChatCompletionRequest( + // @formatter:off @JsonProperty("model") String model, @JsonProperty("messages") List messages, @JsonProperty("system") String system, @@ -428,7 +431,7 @@ public record Metadata(@JsonProperty("user_id") String userId) { } - public static class ChatCompletionRequestBuilder { + public static final class ChatCompletionRequestBuilder { private String model; @@ -559,9 +562,10 @@ public ChatCompletionRequest build() { * types. */ @JsonInclude(Include.NON_NULL) - public record AnthropicMessage( // @formatter:off - @JsonProperty("content") List content, - @JsonProperty("role") Role role) { + public record AnthropicMessage( + // @formatter:off + @JsonProperty("content") List content, + @JsonProperty("role") Role role) { // @formatter:on } @@ -574,7 +578,8 @@ public record AnthropicMessage( // @formatter:off * responses. */ @JsonInclude(Include.NON_NULL) - public record ContentBlock( // @formatter:off + public record ContentBlock( + // @formatter:off @JsonProperty("type") Type type, @JsonProperty("source") Source source, @JsonProperty("text") String text, @@ -682,7 +687,8 @@ public String getValue() { * @param data The base64-encoded data of the content. */ @JsonInclude(Include.NON_NULL) - public record Source( // @formatter:off + public record Source( + // @formatter:off @JsonProperty("type") String type, @JsonProperty("media_type") String mediaType, @JsonProperty("data") String data) { @@ -701,7 +707,8 @@ public Source(String mediaType, String data) { /////////////////////////////////////// @JsonInclude(Include.NON_NULL) - public record Tool(// @formatter:off + public record Tool( + // @formatter:off @JsonProperty("name") String name, @JsonProperty("description") String description, @JsonProperty("input_schema") Map inputSchema) { @@ -724,7 +731,8 @@ public record Tool(// @formatter:off * @param usage Input and output token usage. */ @JsonInclude(Include.NON_NULL) - public record ChatCompletionResponse( // @formatter:off + public record ChatCompletionResponse( + // @formatter:off @JsonProperty("id") String id, @JsonProperty("type") String type, @JsonProperty("role") Role role, @@ -745,9 +753,10 @@ public record ChatCompletionResponse( // @formatter:off * @param outputTokens The number of output tokens which were used. completion). */ @JsonInclude(Include.NON_NULL) - public record Usage( // @formatter:off - @JsonProperty("input_tokens") Integer inputTokens, - @JsonProperty("output_tokens") Integer outputTokens) { + public record Usage( + // @formatter:off + @JsonProperty("input_tokens") Integer inputTokens, + @JsonProperty("output_tokens") Integer outputTokens) { // @formatter:off } @@ -828,10 +837,11 @@ public String toString() { // MESSAGE START EVENT @JsonInclude(Include.NON_NULL) - public record ContentBlockStartEvent(// @formatter:off - @JsonProperty("type") EventType type, - @JsonProperty("index") Integer index, - @JsonProperty("content_block") ContentBlockBody contentBlock) implements StreamEvent { + public record ContentBlockStartEvent( + // @formatter:off + @JsonProperty("type") EventType type, + @JsonProperty("index") Integer index, + @JsonProperty("content_block") ContentBlockBody contentBlock) implements StreamEvent { @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.EXISTING_PROPERTY, property = "type", visible = true) @@ -854,17 +864,18 @@ public record ContentBlockText( @JsonProperty("type") String type, @JsonProperty("text") String text) implements ContentBlockBody { } - }// @formatter:on + } + // @formatter:on // MESSAGE DELTA EVENT @JsonInclude(Include.NON_NULL) - public record ContentBlockDeltaEvent(// @formatter:off - @JsonProperty("type") EventType type, - @JsonProperty("index") Integer index, + public record ContentBlockDeltaEvent( + // @formatter:off + @JsonProperty("type") EventType type, + @JsonProperty("index") Integer index, @JsonProperty("delta") ContentBlockDeltaBody delta) implements StreamEvent { - @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.EXISTING_PROPERTY, property = "type", visible = true) @JsonSubTypes({ @JsonSubTypes.Type(value = ContentBlockDeltaText.class, name = "text_delta"), @@ -884,66 +895,78 @@ public record ContentBlockDeltaJson( @JsonProperty("type") String type, @JsonProperty("partial_json") String partialJson) implements ContentBlockDeltaBody { } - }// @formatter:on + } + // @formatter:on // MESSAGE STOP EVENT @JsonInclude(Include.NON_NULL) - public record ContentBlockStopEvent(// @formatter:off - @JsonProperty("type") EventType type, - @JsonProperty("index") Integer index) implements StreamEvent { - }// @formatter:on + public record ContentBlockStopEvent( + // @formatter:off + @JsonProperty("type") EventType type, + @JsonProperty("index") Integer index) implements StreamEvent { + } + // @formatter:on @JsonInclude(Include.NON_NULL) public record MessageStartEvent(// @formatter:off - @JsonProperty("type") EventType type, - @JsonProperty("message") ChatCompletionResponse message) implements StreamEvent { - }// @formatter:on + @JsonProperty("type") EventType type, + @JsonProperty("message") ChatCompletionResponse message) implements StreamEvent { + } + // @formatter:on @JsonInclude(Include.NON_NULL) - public record MessageDeltaEvent(// @formatter:off - @JsonProperty("type") EventType type, - @JsonProperty("delta") MessageDelta delta, - @JsonProperty("usage") MessageDeltaUsage usage) implements StreamEvent { + public record MessageDeltaEvent( + // @formatter:off + @JsonProperty("type") EventType type, + @JsonProperty("delta") MessageDelta delta, + @JsonProperty("usage") MessageDeltaUsage usage) implements StreamEvent { - @JsonInclude(Include.NON_NULL) - public record MessageDelta( - @JsonProperty("stop_reason") String stopReason, - @JsonProperty("stop_sequence") String stopSequence) { + @JsonInclude(Include.NON_NULL) + public record MessageDelta( + @JsonProperty("stop_reason") String stopReason, + @JsonProperty("stop_sequence") String stopSequence) { } @JsonInclude(Include.NON_NULL) public record MessageDeltaUsage( - @JsonProperty("output_tokens") Integer outputTokens) { - } - }// @formatter:on + @JsonProperty("output_tokens") Integer outputTokens) { + } + } + // @formatter:on @JsonInclude(Include.NON_NULL) - public record MessageStopEvent(// @formatter:off - @JsonProperty("type") EventType type) implements StreamEvent { - }// @formatter:on + public record MessageStopEvent( + // @formatter:off + @JsonProperty("type") EventType type) implements StreamEvent { + } + // @formatter:on /////////////////////////////////////// /// ERROR EVENT /////////////////////////////////////// @JsonInclude(Include.NON_NULL) - public record ErrorEvent(// @formatter:off - @JsonProperty("type") EventType type, - @JsonProperty("error") Error error) implements StreamEvent { + public record ErrorEvent( + // @formatter:off + @JsonProperty("type") EventType type, + @JsonProperty("error") Error error) implements StreamEvent { - @JsonInclude(Include.NON_NULL) - public record Error( - @JsonProperty("type") String type, - @JsonProperty("message") String message) { - } - }// @formatter:on + @JsonInclude(Include.NON_NULL) + public record Error( + @JsonProperty("type") String type, + @JsonProperty("message") String message) { + } + } + // @formatter:on /////////////////////////////////////// /// PING EVENT /////////////////////////////////////// @JsonInclude(Include.NON_NULL) - public record PingEvent(// @formatter:off - @JsonProperty("type") EventType type) implements StreamEvent { - }// @formatter:on + public record PingEvent( + // @formatter:off + @JsonProperty("type") EventType type) implements StreamEvent { + } + // @formatter:on } diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java index 677bdb2e49a..ec08d203089 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java @@ -178,6 +178,7 @@ else if (event.type().equals(EventType.MESSAGE_DELTA)) { } } else if (event.type().equals(EventType.MESSAGE_STOP)) { + // pass through } else { contentBlockReference.get().withType(event.type().name()).withContent(List.of()); diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/MockWeatherService.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/MockWeatherService.java index 8af45829870..151a656480b 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/MockWeatherService.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/MockWeatherService.java @@ -65,7 +65,7 @@ public enum Unit { */ public final String unitName; - private Unit(String text) { + Unit(String text) { this.unitName = text; } diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/XmlHelper.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/XmlHelper.java index 9ea40d1c800..b6c972fc423 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/XmlHelper.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/XmlHelper.java @@ -35,7 +35,7 @@ /** * @author Christian Tzolov */ -public class XmlHelper { +public final class XmlHelper { // Regular expression to match XML block between and // tags @@ -46,6 +46,10 @@ public class XmlHelper { private static final XmlMapper xmlMapper = new XmlMapper(); + private XmlHelper() { + + } + public static String extractFunctionCallsXmlBlock(String text) { if (!StringUtils.hasText(text)) { return ""; @@ -128,7 +132,8 @@ public record Parameter( @JsonProperty("description") String description) { } } - } // @formatter:on + } + // @formatter:on @JsonInclude(Include.NON_NULL) // @formatter:off @JacksonXmlRootElement(localName = "function_calls") diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java index d93a84e004c..b4038d6acb0 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java @@ -93,7 +93,7 @@ void listOutputConverterString() { .user(u -> u.text("List five {subject}") .param("subject", "ice cream flavors")) .call() - .entity(new ParameterizedTypeReference>() {}); + .entity(new ParameterizedTypeReference<>() { }); // @formatter:on logger.info(collection.toString()); diff --git a/models/spring-ai-azure-openai/pom.xml b/models/spring-ai-azure-openai/pom.xml index 101b9e508d7..35f8511d226 100644 --- a/models/spring-ai-azure-openai/pom.xml +++ b/models/spring-ai-azure-openai/pom.xml @@ -35,6 +35,10 @@ git@github.com:spring-projects/spring-ai.git + + false + + diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiAudioTranscriptionOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiAudioTranscriptionOptions.java index b79e2588518..2d5f481cfe9 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiAudioTranscriptionOptions.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiAudioTranscriptionOptions.java @@ -45,7 +45,7 @@ public class AzureOpenAiAudioTranscriptionOptions implements AudioTranscriptionO * The deployment name as defined in Azure Open AI Studio when creating a deployment * backed by an Azure OpenAI base model. */ - private @JsonProperty(value = "deployment_name") String deploymentName; + private @JsonProperty("deployment_name") String deploymentName; /** * The format of the transcript output, in one of these options: json, text, srt, verbose_json, or vtt. @@ -138,41 +138,53 @@ public int hashCode() { @Override public boolean equals(Object obj) { - if (this == obj) + if (this == obj) { return true; - if (obj == null) + } + if (obj == null) { return false; - if (getClass() != obj.getClass()) + } + if (getClass() != obj.getClass()) { return false; + } AzureOpenAiAudioTranscriptionOptions other = (AzureOpenAiAudioTranscriptionOptions) obj; if (this.model == null) { - if (other.model != null) + if (other.model != null) { return false; + } } - else if (!this.model.equals(other.model)) + else if (!this.model.equals(other.model)) { return false; + } if (this.prompt == null) { - if (other.prompt != null) + if (other.prompt != null) { return false; + } } - else if (!this.prompt.equals(other.prompt)) + else if (!this.prompt.equals(other.prompt)) { return false; + } if (this.language == null) { - if (other.language != null) + if (other.language != null) { return false; + } } - else if (!this.language.equals(other.language)) + else if (!this.language.equals(other.language)) { return false; + } if (this.responseFormat == null) { - return other.responseFormat==null; + return other.responseFormat == null; + } + else { + return this.responseFormat.equals(other.responseFormat); } - else return this.responseFormat.equals(other.responseFormat); } public enum WhisperModel { // @formatter:off - @JsonProperty("whisper") WHISPER("whisper"); + @JsonProperty("whisper") + WHISPER("whisper"); // @formatter:on public final String value; @@ -190,11 +202,16 @@ public String getValue() { public enum TranscriptResponseFormat { // @formatter:off - @JsonProperty("json") JSON(AudioTranscriptionFormat.JSON, StructuredResponse.class), - @JsonProperty("text") TEXT(AudioTranscriptionFormat.TEXT, String.class), - @JsonProperty("srt") SRT(AudioTranscriptionFormat.SRT, String.class), - @JsonProperty("verbose_json") VERBOSE_JSON(AudioTranscriptionFormat.VERBOSE_JSON, StructuredResponse.class), - @JsonProperty("vtt") VTT(AudioTranscriptionFormat.VTT, String.class); + @JsonProperty("json") + JSON(AudioTranscriptionFormat.JSON, StructuredResponse.class), + @JsonProperty("text") + TEXT(AudioTranscriptionFormat.TEXT, String.class), + @JsonProperty("srt") + SRT(AudioTranscriptionFormat.SRT, String.class), + @JsonProperty("verbose_json") + VERBOSE_JSON(AudioTranscriptionFormat.VERBOSE_JSON, StructuredResponse.class), + @JsonProperty("vtt") + VTT(AudioTranscriptionFormat.VTT, String.class); public final AudioTranscriptionFormat value; @@ -217,8 +234,10 @@ public Class getResponseType() { public enum GranularityType { // @formatter:off - @JsonProperty("word") WORD(AudioTranscriptionTimestampGranularity.WORD), - @JsonProperty("segment") SEGMENT(AudioTranscriptionTimestampGranularity.SEGMENT); + @JsonProperty("word") + WORD(AudioTranscriptionTimestampGranularity.WORD), + @JsonProperty("segment") + SEGMENT(AudioTranscriptionTimestampGranularity.SEGMENT); // @formatter:on public final AudioTranscriptionTimestampGranularity value; diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index 77c2ea0b245..203981d902a 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -290,9 +290,10 @@ public Flux stream(Prompt prompt) { return this.stream(new Prompt(toolCallConversation, prompt.getOptions())); } - Flux flux = Flux.just(chatResponse).doOnError(observation::error).doFinally(s -> { - observation.stop(); - }).contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); + Flux flux = Flux.just(chatResponse) + .doOnError(observation::error) + .doFinally(s -> observation.stop()) + .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); return new MessageAggregator().aggregate(flux, observationContext::setResponse); }); @@ -416,7 +417,7 @@ private List fromSpringAiMessage(Message message) { return List.of(new ChatRequestUserMessage(items)); case SYSTEM: return List.of(new ChatRequestSystemMessage(message.getContent())); - case ASSISTANT: { + case ASSISTANT: AssistantMessage assistantMessage = (AssistantMessage) message; List toolCalls = null; if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) { @@ -430,20 +431,17 @@ private List fromSpringAiMessage(Message message) { var azureAssistantMessage = new ChatRequestAssistantMessage(message.getContent()); azureAssistantMessage.setToolCalls(toolCalls); return List.of(azureAssistantMessage); - } - case TOOL: { + case TOOL: ToolResponseMessage toolMessage = (ToolResponseMessage) message; - toolMessage.getResponses().forEach(response -> { - Assert.isTrue(response.id() != null, "ToolResponseMessage must have an id"); - }); + toolMessage.getResponses() + .forEach(response -> Assert.isTrue(response.id() != null, "ToolResponseMessage must have an id")); return toolMessage.getResponses() .stream() .map(tr -> new ChatRequestToolMessage(tr.responseData(), tr.id())) .map(crtm -> ((ChatRequestMessage) crtm)) .toList(); - } default: throw new IllegalArgumentException("Unknown message type " + message.getMessageType()); } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java index f890f1266ab..89eb33556a4 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java @@ -49,7 +49,7 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio /** * The maximum number of tokens to generate. */ - @JsonProperty(value = "max_tokens") + @JsonProperty("max_tokens") private Integer maxTokens; /** @@ -59,7 +59,7 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio * temperature and top_p for the same completions request as the interaction of these * two settings is difficult to predict. */ - @JsonProperty(value = "temperature") + @JsonProperty("temperature") private Double temperature; /** @@ -70,7 +70,7 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio * temperature and top_p for the same completions request as the interaction of these * two settings is difficult to predict. */ - @JsonProperty(value = "top_p") + @JsonProperty("top_p") private Double topP; /** @@ -80,14 +80,14 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio * minimum and maximum values corresponding to a full ban or exclusive selection of a * token, respectively. The exact behavior of a given bias score varies by model. */ - @JsonProperty(value = "logit_bias") + @JsonProperty("logit_bias") private Map logitBias; /** * An identifier for the caller or end user of the operation. This may be used for * tracking or rate-limiting purposes. */ - @JsonProperty(value = "user") + @JsonProperty("user") private String user; /** @@ -96,13 +96,13 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio * quickly consume your token quota. Use carefully and ensure reasonable settings for * max_tokens and stop. */ - @JsonProperty(value = "n") + @JsonProperty("n") private Integer n; /** * A collection of textual sequences that will end completions generation. */ - @JsonProperty(value = "stop") + @JsonProperty("stop") private List stop; /** @@ -111,7 +111,7 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio * likely to appear when they already exist and increase the model's likelihood to * output new topics. */ - @JsonProperty(value = "presence_penalty") + @JsonProperty("presence_penalty") private Double presencePenalty; /** @@ -120,14 +120,14 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio * likely to appear as their frequency increases and decrease the likelihood of the * model repeating the same statements verbatim. */ - @JsonProperty(value = "frequency_penalty") + @JsonProperty("frequency_penalty") private Double frequencyPenalty; /** * The deployment name as defined in Azure Open AI Studio when creating a deployment * backed by an Azure OpenAI base model. */ - @JsonProperty(value = "deployment_name") + @JsonProperty("deployment_name") private String deploymentName; /** @@ -171,7 +171,7 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio * Seed value for deterministic sampling such that the same seed and parameters return * the same result. */ - @JsonProperty(value = "seed") + @JsonProperty("seed") private Long seed; /** @@ -179,7 +179,7 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio * the log probabilities of each output token returned in the `content` of `message`. * This option is currently not available on the `gpt-4-vision-preview` model. */ - @JsonProperty(value = "log_probs") + @JsonProperty("log_probs") private Boolean logprobs; /* @@ -187,7 +187,7 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio * each token position, each with an associated log probability. `logprobs` must be * set to `true` if this parameter is used. */ - @JsonProperty(value = "top_log_probs") + @JsonProperty("top_log_probs") private Integer topLogProbs; /* diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java index 9b1b466efac..5c772a3b826 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java @@ -44,8 +44,6 @@ import org.springframework.ai.util.JacksonUtils; import org.springframework.util.Assert; -import static java.lang.String.format; - /** * {@link ImageModel} implementation for {@literal Microsoft Azure AI} backed by * {@link OpenAIClient}. @@ -92,15 +90,15 @@ public AzureOpenAiImageOptions getDefaultOptions() { public ImageResponse call(ImagePrompt imagePrompt) { ImageGenerationOptions imageGenerationOptions = toOpenAiImageOptions(imagePrompt); String deploymentOrModelName = getDeploymentName(imagePrompt); - if (this.logger.isTraceEnabled()) { - this.logger.trace("Azure ImageGenerationOptions call {} with the following options : {} ", - deploymentOrModelName, toPrettyJson(imageGenerationOptions)); + if (logger.isTraceEnabled()) { + logger.trace("Azure ImageGenerationOptions call {} with the following options : {} ", deploymentOrModelName, + toPrettyJson(imageGenerationOptions)); } var images = this.openAIClient.getImageGenerations(deploymentOrModelName, imageGenerationOptions); - if (this.logger.isTraceEnabled()) { - this.logger.trace("Azure ImageGenerations: {}", toPrettyJson(images)); + if (logger.isTraceEnabled()) { + logger.trace("Azure ImageGenerations: {}", toPrettyJson(images)); } List imageGenerations = images.getData().stream().map(entry -> { @@ -154,8 +152,8 @@ private String getDeploymentName(ImagePrompt prompt) { private ImageGenerationOptions toOpenAiImageOptions(ImagePrompt prompt) { if (prompt.getInstructions().size() > 1) { - throw new RuntimeException(format("implementation support 1 image instruction only, found %s", - prompt.getInstructions().size())); + throw new RuntimeException(java.lang.String + .format("implementation support 1 image instruction only, found %s", prompt.getInstructions().size())); } if (prompt.getInstructions().isEmpty()) { throw new RuntimeException("please provide image instruction, current is empty"); diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageOptions.java index 2e6d13c572f..c0743141d68 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageOptions.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageOptions.java @@ -45,14 +45,14 @@ public class AzureOpenAiImageOptions implements ImageOptions { /** * The model dall-e-3 or dall-e-2 By default dall-e-3 */ - @JsonProperty(value = "model") + @JsonProperty("model") private String model = ImageModel.DALL_E_3.value; /** * The deployment name as defined in Azure Open AI Studio when creating a deployment * backed by an Azure OpenAI base model. */ - @JsonProperty(value = "deployment_name") + @JsonProperty("deployment_name") private String deploymentName; /** @@ -255,7 +255,7 @@ public String getValue() { } - public static class Builder { + public final static class Builder { private final AzureOpenAiImageOptions options; diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/MergeUtils.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/MergeUtils.java index 82c1f57b5d1..ad631f741d3 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/MergeUtils.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/MergeUtils.java @@ -48,7 +48,7 @@ * @author Christian Tzolov * @since 1.0.0 */ -public class MergeUtils { +public final class MergeUtils { private static final Class[] CHAT_COMPLETIONS_CONSTRUCTOR_ARG_TYPES = new Class[] { String.class, OffsetDateTime.class, List.class, CompletionsUsage.class }; @@ -59,6 +59,10 @@ public class MergeUtils { private static final Class[] chatResponseMessageConstructorArgumentTypes = new Class[] { ChatRole.class, String.class }; + private MergeUtils() { + + } + /** * Create a new instance of the given class using the constructor at the given index. * Can be used to create instances with private constructors. diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/aot/AzureOpenAiRuntimeHints.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/aot/AzureOpenAiRuntimeHints.java index 75ba720b02c..79662592cf6 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/aot/AzureOpenAiRuntimeHints.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/aot/AzureOpenAiRuntimeHints.java @@ -50,8 +50,9 @@ public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader cla try { var resolver = new PathMatchingResourcePatternResolver(); - for (var resourceMatch : resolver.getResources("/azure-ai-openai.properties")) + for (var resourceMatch : resolver.getResources("/azure-ai-openai.properties")) { hints.resources().registerResource(resourceMatch); + } } catch (Exception e) { throw new RuntimeException(e); diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java index 8babb2e0c8c..7d08d6a0a03 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java @@ -39,7 +39,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.core.io.Resource; -import static com.azure.core.http.policy.HttpLogDetailLevel.BODY_AND_HEADERS; import static org.assertj.core.api.Assertions.assertThat; /** @@ -158,7 +157,8 @@ public OpenAIClientBuilder openAIClient() { return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY"))) .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) .serviceVersion(OpenAIServiceVersion.V2024_02_15_PREVIEW) - .httpLogOptions(new HttpLogOptions().setLogLevel(BODY_AND_HEADERS)); + .httpLogOptions(new HttpLogOptions() + .setLogLevel(com.azure.core.http.policy.HttpLogDetailLevel.BODY_AND_HEADERS)); } @Bean diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java index ffa657aa9f0..fe7325ef848 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java @@ -54,7 +54,6 @@ import org.springframework.core.io.Resource; import org.springframework.util.MimeTypeUtils; -import static com.azure.core.http.policy.HttpLogDetailLevel.BODY_AND_HEADERS; import static org.assertj.core.api.Assertions.assertThat; @SpringBootTest(classes = AzureOpenAiChatModelIT.TestConfiguration.class) @@ -269,7 +268,8 @@ public OpenAIClientBuilder openAIClientBuilder() { return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY"))) .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) .serviceVersion(OpenAIServiceVersion.V2024_02_15_PREVIEW) - .httpLogOptions(new HttpLogOptions().setLogLevel(BODY_AND_HEADERS)); + .httpLogOptions(new HttpLogOptions() + .setLogLevel(com.azure.core.http.policy.HttpLogDetailLevel.BODY_AND_HEADERS)); } @Bean diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelObservationIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelObservationIT.java index 2e194ea2d10..d3dadbbaf17 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelObservationIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelObservationIT.java @@ -42,7 +42,6 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import static com.azure.core.http.policy.HttpLogDetailLevel.BODY_AND_HEADERS; import static org.assertj.core.api.Assertions.assertThat; /** @@ -190,7 +189,8 @@ public OpenAIClientBuilder openAIClient() { return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY"))) .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) .serviceVersion(OpenAIServiceVersion.V2024_02_15_PREVIEW) - .httpLogOptions(new HttpLogOptions().setLogLevel(BODY_AND_HEADERS)); + .httpLogOptions(new HttpLogOptions() + .setLogLevel(com.azure.core.http.policy.HttpLogDetailLevel.BODY_AND_HEADERS)); } @Bean diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAiTestConfiguration.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAiTestConfiguration.java index 48df6e123e8..08d6ae77f6b 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAiTestConfiguration.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAiTestConfiguration.java @@ -223,7 +223,7 @@ public void setDispatcher(@Nullable Dispatcher dispatcher) { } protected Logger getLogger() { - return this.logger; + return logger; } @Override diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java index 19bc2c7308f..4ac53bb5c86 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java @@ -73,7 +73,7 @@ void functionCallTest() { .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("getCurrentWeather") .withDescription("Get the current weather in a given location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build())) .build(); @@ -97,7 +97,7 @@ void functionCallSequentialTest() { .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("getCurrentWeather") .withDescription("Get the current weather in a given location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build())) .build(); @@ -119,7 +119,7 @@ void streamFunctionCallTest() { .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("getCurrentWeather") .withDescription("Get the current weather in a given location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build())) .build(); @@ -156,7 +156,7 @@ void functionCallSequentialAndStreamTest() { .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("getCurrentWeather") .withDescription("Get the current weather in a given location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build())) .build(); diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/MockWeatherService.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/MockWeatherService.java index e122e5f690d..223a89f5d5b 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/MockWeatherService.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/MockWeatherService.java @@ -65,7 +65,7 @@ public enum Unit { */ public final String unitName; - private Unit(String text) { + Unit(String text) { this.unitName = text; } diff --git a/models/spring-ai-bedrock/pom.xml b/models/spring-ai-bedrock/pom.xml index 29f48b7db01..a77188f9dae 100644 --- a/models/spring-ai-bedrock/pom.xml +++ b/models/spring-ai-bedrock/pom.xml @@ -36,6 +36,10 @@ git@github.com:spring-projects/spring-ai.git + + false + + diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/MessageToPromptConverter.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/MessageToPromptConverter.java index 001b2fd9896..fb42f6b6f43 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/MessageToPromptConverter.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/MessageToPromptConverter.java @@ -28,7 +28,7 @@ * @author Christian Tzolov * @since 1.0.0 */ -public class MessageToPromptConverter { +public final class MessageToPromptConverter { private static final String HUMAN_PROMPT = "Human:"; diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java index 074c036352b..8c9d33032d5 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java @@ -200,11 +200,11 @@ public static Builder builder(String prompt) { return new Builder(prompt); } - public static class Builder { + public static final class Builder { private final String prompt; - private Double temperature;// = 0.7; - private Integer maxTokensToSample;// = 500; - private Integer topK;// = 10; + private Double temperature; // = 0.7; + private Integer maxTokensToSample; // = 500; + private Integer topK; // = 10; private Double topP; private List stopSequences; private String anthropicVersion; @@ -275,4 +275,4 @@ public record AnthropicChatResponse( } } -// @formatter:on \ No newline at end of file +// @formatter:on diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java index deaa01f1304..b6f32ab2e05 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java @@ -85,10 +85,11 @@ public ChatResponse call(Prompt prompt) { AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(request); - List generations = response.content().stream().map(content -> { - return new Generation(new AssistantMessage(content.text()), - ChatGenerationMetadata.from(response.stopReason(), null)); - }).toList(); + List generations = response.content() + .stream() + .map(content -> new Generation(new AssistantMessage(content.text()), + ChatGenerationMetadata.from(response.stopReason(), null))) + .toList(); ChatResponseMetadata metadata = ChatResponseMetadata.builder() .withId(response.id()) diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java index e6e8b96113f..39bb8899f0d 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java @@ -230,12 +230,12 @@ public static Builder builder(List messages) { return new Builder(messages); } - public static class Builder { + public static final class Builder { private final List messages; private String system; - private Double temperature;// = 0.7; - private Integer maxTokens;// = 500; - private Integer topK;// = 10; + private Double temperature; // = 0.7; + private Integer maxTokens; // = 500; + private Integer topK; // = 10; private Double topP; private List stopSequences; private String anthropicVersion; @@ -301,7 +301,8 @@ public AnthropicChatRequest build() { * responses. */ @JsonInclude(Include.NON_NULL) - public record MediaContent( // @formatter:off + public record MediaContent( + // @formatter:off @JsonProperty("type") Type type, @JsonProperty("source") Source source, @JsonProperty("text") String text, @@ -349,7 +350,8 @@ public enum Type { * @param data The base64-encoded data of the content. */ @JsonInclude(Include.NON_NULL) - public record Source( // @formatter:off + public record Source( + // @formatter:off @JsonProperty("type") String type, @JsonProperty("media_type") String mediaType, @JsonProperty("data") String data) { diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHints.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHints.java index b6f93d3b819..0ec06b678e9 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHints.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHints.java @@ -52,43 +52,59 @@ public class BedrockRuntimeHints implements RuntimeHintsRegistrar { @Override public void registerHints(RuntimeHints hints, ClassLoader classLoader) { var mcs = MemberCategory.values(); - for (var tr : findJsonAnnotatedClassesInPackage(AbstractBedrockApi.class)) + for (var tr : findJsonAnnotatedClassesInPackage(AbstractBedrockApi.class)) { hints.reflection().registerType(tr, mcs); - for (var tr : findJsonAnnotatedClassesInPackage(Ai21Jurassic2ChatBedrockApi.class)) + } + for (var tr : findJsonAnnotatedClassesInPackage(Ai21Jurassic2ChatBedrockApi.class)) { hints.reflection().registerType(tr, mcs); + } - for (var tr : findJsonAnnotatedClassesInPackage(CohereChatBedrockApi.class)) + for (var tr : findJsonAnnotatedClassesInPackage(CohereChatBedrockApi.class)) { hints.reflection().registerType(tr, mcs); - for (var tr : findJsonAnnotatedClassesInPackage(BedrockCohereChatOptions.class)) + } + for (var tr : findJsonAnnotatedClassesInPackage(BedrockCohereChatOptions.class)) { hints.reflection().registerType(tr, mcs); - for (var tr : findJsonAnnotatedClassesInPackage(CohereEmbeddingBedrockApi.class)) + } + for (var tr : findJsonAnnotatedClassesInPackage(CohereEmbeddingBedrockApi.class)) { hints.reflection().registerType(tr, mcs); - for (var tr : findJsonAnnotatedClassesInPackage(BedrockCohereEmbeddingOptions.class)) + } + for (var tr : findJsonAnnotatedClassesInPackage(BedrockCohereEmbeddingOptions.class)) { hints.reflection().registerType(tr, mcs); + } - for (var tr : findJsonAnnotatedClassesInPackage(LlamaChatBedrockApi.class)) + for (var tr : findJsonAnnotatedClassesInPackage(LlamaChatBedrockApi.class)) { hints.reflection().registerType(tr, mcs); - for (var tr : findJsonAnnotatedClassesInPackage(BedrockLlamaChatOptions.class)) + } + for (var tr : findJsonAnnotatedClassesInPackage(BedrockLlamaChatOptions.class)) { hints.reflection().registerType(tr, mcs); + } - for (var tr : findJsonAnnotatedClassesInPackage(TitanChatBedrockApi.class)) + for (var tr : findJsonAnnotatedClassesInPackage(TitanChatBedrockApi.class)) { hints.reflection().registerType(tr, mcs); - for (var tr : findJsonAnnotatedClassesInPackage(BedrockTitanChatOptions.class)) + } + for (var tr : findJsonAnnotatedClassesInPackage(BedrockTitanChatOptions.class)) { hints.reflection().registerType(tr, mcs); - for (var tr : findJsonAnnotatedClassesInPackage(BedrockTitanEmbeddingOptions.class)) + } + for (var tr : findJsonAnnotatedClassesInPackage(BedrockTitanEmbeddingOptions.class)) { hints.reflection().registerType(tr, mcs); - for (var tr : findJsonAnnotatedClassesInPackage(TitanEmbeddingBedrockApi.class)) + } + for (var tr : findJsonAnnotatedClassesInPackage(TitanEmbeddingBedrockApi.class)) { hints.reflection().registerType(tr, mcs); + } - for (var tr : findJsonAnnotatedClassesInPackage(AnthropicChatBedrockApi.class)) + for (var tr : findJsonAnnotatedClassesInPackage(AnthropicChatBedrockApi.class)) { hints.reflection().registerType(tr, mcs); - for (var tr : findJsonAnnotatedClassesInPackage(AnthropicChatOptions.class)) + } + for (var tr : findJsonAnnotatedClassesInPackage(AnthropicChatOptions.class)) { hints.reflection().registerType(tr, mcs); + } - for (var tr : findJsonAnnotatedClassesInPackage(Anthropic3ChatBedrockApi.class)) + for (var tr : findJsonAnnotatedClassesInPackage(Anthropic3ChatBedrockApi.class)) { hints.reflection().registerType(tr, mcs); - for (var tr : findJsonAnnotatedClassesInPackage(Anthropic3ChatOptions.class)) + } + for (var tr : findJsonAnnotatedClassesInPackage(Anthropic3ChatOptions.class)) { hints.reflection().registerType(tr, mcs); + } } } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java index 11897200c28..e671e01bcb2 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -// @formatter:off + package org.springframework.ai.bedrock.api; +// @formatter:off + import java.io.UncheckedIOException; import java.nio.charset.StandardCharsets; import java.time.Duration; @@ -103,7 +105,7 @@ public AbstractBedrockApi(String modelId, String region, Duration timeout) { * @param objectMapper The object mapper to use for JSON serialization and deserialization. */ public AbstractBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, String region, - ObjectMapper objectMapper) { + ObjectMapper objectMapper) { this(modelId, credentialsProvider, region, objectMapper, Duration.ofMinutes(5)); } @@ -273,7 +275,7 @@ protected Flux internalInvocationStream(I request, Class clazz) { InvokeModelWithResponseStreamResponseHandler.Visitor visitor = InvokeModelWithResponseStreamResponseHandler.Visitor .builder() - .onChunk((chunk) -> { + .onChunk(chunk -> { try { logger.debug("Received chunk: " + chunk.bytes().asString(StandardCharsets.UTF_8)); SO response = this.objectMapper.readValue(chunk.bytes().asByteArray(), clazz); @@ -284,7 +286,7 @@ protected Flux internalInvocationStream(I request, Class clazz) { eventSink.tryEmitError(e); } }) - .onDefault((event) -> { + .onDefault(event -> { logger.error("Unknown or unhandled event: " + event.toString()); eventSink.tryEmitError(new Throwable("Unknown or unhandled event: " + event.toString())); }) @@ -295,24 +297,20 @@ protected Flux internalInvocationStream(I request, Class clazz) { .onComplete( () -> { EmitResult emitResult = eventSink.tryEmitComplete(); - while(!emitResult.isSuccess()){ + while (!emitResult.isSuccess()) { System.out.println("Emitting complete:" + emitResult); emitResult = eventSink.tryEmitComplete(); - }; + } eventSink.emitComplete(EmitFailureHandler.busyLooping(Duration.ofSeconds(3))); // EmitResult emitResult = eventSink.tryEmitComplete(); logger.debug("\nCompleted streaming response."); }) - .onError((error) -> { + .onError(error -> { logger.error("\n\nError streaming response: " + error.getMessage()); eventSink.tryEmitError(error); }) - .onEventStream((stream) -> { - stream.subscribe( - (ResponseStream e) -> { - e.accept(visitor); - }); - }) + .onEventStream(stream -> stream.subscribe( + (ResponseStream e) -> e.accept(visitor))) .build(); this.clientStreaming.invokeModelWithResponseStream(invokeRequest, responseHandler); @@ -338,4 +336,4 @@ public record AmazonBedrockInvocationMetrics( @JsonProperty("invocationLatency") Long invocationLatency) { } } -// @formatter:on \ No newline at end of file +// @formatter:on diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModel.java index 4d235a2b264..f4d4c5e61e7 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModel.java @@ -61,9 +61,7 @@ public BedrockCohereChatModel(CohereChatBedrockApi chatApi, BedrockCohereChatOpt @Override public ChatResponse call(Prompt prompt) { CohereChatResponse response = this.chatApi.chatCompletion(this.createRequest(prompt, false)); - List generations = response.generations().stream().map(g -> { - return new Generation(g.text()); - }).toList(); + List generations = response.generations().stream().map(g -> new Generation(g.text())).toList(); return new ChatResponse(generations); } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatOptions.java index 8e5e0a6898e..3349545e1cc 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatOptions.java @@ -41,45 +41,54 @@ public class BedrockCohereChatOptions implements ChatOptions { * (optional) Use a lower value to decrease randomness in the response. Defaults to * 0.7. */ - @JsonProperty("temperature") Double temperature; + @JsonProperty("temperature") + Double temperature; /** * (optional) The maximum cumulative probability of tokens to consider when sampling. * The generative uses combined Top-k and nucleus sampling. Nucleus sampling considers * the smallest set of tokens whose probability sum is at least topP. */ - @JsonProperty("p") Double topP; + @JsonProperty("p") + Double topP; /** * (optional) Specify the number of token choices the generative uses to generate the * next token. */ - @JsonProperty("k") Integer topK; + @JsonProperty("k") + Integer topK; /** * (optional) Specify the maximum number of tokens to use in the generated response. */ - @JsonProperty("max_tokens") Integer maxTokens; + @JsonProperty("max_tokens") + Integer maxTokens; /** * (optional) Configure up to four sequences that the generative recognizes. After a * stop sequence, the generative stops generating further tokens. The returned text * doesn't contain the stop sequence. */ - @JsonProperty("stop_sequences") List stopSequences; + @JsonProperty("stop_sequences") + List stopSequences; /** * (optional) Specify how and if the token likelihoods are returned with the response. */ - @JsonProperty("return_likelihoods") ReturnLikelihoods returnLikelihoods; + @JsonProperty("return_likelihoods") + ReturnLikelihoods returnLikelihoods; /** * (optional) The maximum number of generations that the generative should return. */ - @JsonProperty("num_generations") Integer numGenerations; + @JsonProperty("num_generations") + Integer numGenerations; /** * Prevents the model from generating unwanted tokens or incentivize the model to include desired tokens. */ - @JsonProperty("logit_bias") LogitBias logitBias; + @JsonProperty("logit_bias") + LogitBias logitBias; /** * (optional) Specifies how the API handles inputs longer than the maximum token * length. */ - @JsonProperty("truncate") Truncate truncate; + @JsonProperty("truncate") + Truncate truncate; // @formatter:on public static Builder builder() { diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java index 454bd8aed4f..1d3fe8b5380 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -// @formatter:off + package org.springframework.ai.bedrock.cohere.api; +// @formatter:off + import java.time.Duration; import java.util.List; @@ -412,4 +414,4 @@ public record TokenLikelihood( } } } -// @formatter:on \ No newline at end of file +// @formatter:on diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java index e69f229d6d8..8fdf3114b69 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -// @formatter:off + package org.springframework.ai.bedrock.cohere.api; +// @formatter:off + import java.time.Duration; import java.util.List; @@ -169,19 +171,23 @@ public enum InputType { * In search use-cases, use search_document when you encode documents for embeddings that you store in a * vector database. */ - @JsonProperty("search_document") SEARCH_DOCUMENT, + @JsonProperty("search_document") + SEARCH_DOCUMENT, /** * Use search_query when querying your vector DB to find relevant documents. */ - @JsonProperty("search_query") SEARCH_QUERY, + @JsonProperty("search_query") + SEARCH_QUERY, /** * Use classification when using embeddings as an input to a text classifier. */ - @JsonProperty("classification") CLASSIFICATION, + @JsonProperty("classification") + CLASSIFICATION, /** * Use clustering to cluster the embeddings. */ - @JsonProperty("clustering") CLUSTERING + @JsonProperty("clustering") + CLUSTERING } /** @@ -224,4 +230,4 @@ public record CohereEmbeddingResponse( } } -// @formatter:on \ No newline at end of file +// @formatter:on diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java index 9fa7104cf37..e57a4ec5317 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -// @formatter:off + package org.springframework.ai.bedrock.jurassic2.api; +// @formatter:off + import java.time.Duration; import java.util.List; @@ -409,4 +411,4 @@ public record FinishReason( } -// @formatter:on \ No newline at end of file +// @formatter:on diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java index 4a76ee485e3..0224defad21 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama/api/LlamaChatBedrockApi.java @@ -287,12 +287,14 @@ public enum StopReason { /** * The model has finished generating text for the input prompt. */ - @JsonProperty("stop") STOP, + @JsonProperty("stop") + STOP, /** * The response was truncated because of the response length you set. */ - @JsonProperty("length") LENGTH + @JsonProperty("length") + LENGTH } } } -// @formatter:on \ No newline at end of file +// @formatter:on diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java index 1003ef0443b..43e20163939 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModel.java @@ -60,9 +60,10 @@ public BedrockTitanChatModel(TitanChatBedrockApi chatApi, BedrockTitanChatOption @Override public ChatResponse call(Prompt prompt) { TitanChatResponse response = this.chatApi.chatCompletion(this.createRequest(prompt)); - List generations = response.results().stream().map(result -> { - return new Generation(result.outputText()); - }).toList(); + List generations = response.results() + .stream() + .map(result -> new Generation(result.outputText())) + .toList(); return new ChatResponse(generations); } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModel.java index c07527b0eb4..609a6244124 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModel.java @@ -78,7 +78,7 @@ public float[] embed(Document document) { public EmbeddingResponse call(EmbeddingRequest request) { Assert.notEmpty(request.getInstructions(), "At least one text is required!"); if (request.getInstructions().size() != 1) { - this.logger.warn( + logger.warn( "Titan Embedding does not support batch embedding. Will make multiple API calls to embed(Document)"); } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java index 19e76729de0..b948a0666e7 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java @@ -236,7 +236,8 @@ public TitanChatRequest build() { if (this.temperature == null && this.topP == null && this.maxTokenCount == null && this.stopSequences == null) { return new TitanChatRequest(this.inputText, null); - } else { + } + else { return new TitanChatRequest(this.inputText, new TextGenerationConfig( this.temperature, diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java index b94ccff9a26..f322627a45c 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java @@ -102,7 +102,7 @@ public enum TitanEmbeddingModel { /** * amazon.titan-embed-text-v2 */ - TITAN_EMBED_TEXT_V2("amazon.titan-embed-text-v2:0");; + TITAN_EMBED_TEXT_V2("amazon.titan-embed-text-v2:0"); private final String id; @@ -182,4 +182,4 @@ public record TitanEmbeddingResponse( } } -// @formatter:on \ No newline at end of file +// @formatter:on diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApiIT.java index 3638104cb63..179b7519113 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApiIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApiIT.java @@ -35,8 +35,6 @@ import static org.assertj.core.api.Assertions.assertThat; -; - /** * @author Christian Tzolov */ @@ -70,7 +68,7 @@ public void chatCompletion() { assertThat(response.stop()).isEqualTo("\n\nHuman:"); assertThat(response.amazonBedrockInvocationMetrics()).isNull(); - this.logger.info("" + response); + logger.info("" + response); } @Test diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApiIT.java index 55b054889c1..c69930419ed 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApiIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApiIT.java @@ -38,7 +38,6 @@ import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.MediaContent; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.DEFAULT_ANTHROPIC_VERSION; /** * @author Ben Middleton @@ -62,12 +61,13 @@ public void chatCompletion() { .withTemperature(0.8) .withMaxTokens(300) .withTopK(10) - .withAnthropicVersion(DEFAULT_ANTHROPIC_VERSION) + .withAnthropicVersion( + org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.DEFAULT_ANTHROPIC_VERSION) .build(); AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(request); - this.logger.info("" + response.content()); + logger.info("" + response.content()); assertThat(response).isNotNull(); assertThat(response.content().get(0).text()).isNotEmpty(); @@ -77,7 +77,7 @@ public void chatCompletion() { assertThat(response.usage().inputTokens()).isGreaterThan(10); assertThat(response.usage().outputTokens()).isGreaterThan(100); - this.logger.info("" + response); + logger.info("" + response); } @Test @@ -102,12 +102,13 @@ public void chatMultiCompletion() { .withTemperature(0.8) .withMaxTokens(400) .withTopK(10) - .withAnthropicVersion(DEFAULT_ANTHROPIC_VERSION) + .withAnthropicVersion( + org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.DEFAULT_ANTHROPIC_VERSION) .build(); AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(request); - this.logger.info("" + response.content()); + logger.info("" + response.content()); assertThat(response).isNotNull(); assertThat(response.content().get(0).text()).isNotEmpty(); assertThat(response.content().get(0).text()).contains("Blackbeard"); @@ -116,7 +117,7 @@ public void chatMultiCompletion() { assertThat(response.usage().inputTokens()).isGreaterThan(30); assertThat(response.usage().outputTokens()).isGreaterThan(200); - this.logger.info("" + response); + logger.info("" + response); } @Test @@ -128,7 +129,8 @@ public void chatCompletionStream() { .withTemperature(0.8) .withMaxTokens(300) .withTopK(10) - .withAnthropicVersion(DEFAULT_ANTHROPIC_VERSION) + .withAnthropicVersion( + org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.DEFAULT_ANTHROPIC_VERSION) .build(); Flux responseStream = this.anthropicChatApi diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHintsTests.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHintsTests.java index f3f33bbfc20..aa3ff8bdf3a 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHintsTests.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHintsTests.java @@ -57,4 +57,4 @@ void registerHints() { } -} \ No newline at end of file +} diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApiIT.java index 27c11af673a..41861e9ecac 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApiIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApiIT.java @@ -35,8 +35,6 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -; - /** * @author Christian Tzolov */ diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModelIT.java index b96991c38cb..40f674783e2 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModelIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatModelIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.bedrock.titan; import java.time.Duration; @@ -66,8 +67,8 @@ class BedrockTitanChatModelIT { @Test void multipleStreamAttempts() { - Flux joke1Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a joke?"))); - Flux joke2Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a toy joke?"))); + Flux joke1Stream = this.chatModel.stream(new Prompt(new UserMessage("Tell me a joke?"))); + Flux joke2Stream = this.chatModel.stream(new Prompt(new UserMessage("Tell me a toy joke?"))); String joke1 = joke1Stream.collectList() .block() @@ -96,10 +97,10 @@ void roleTest() { String name = "Bob"; String voice = "pirate"; UserMessage userMessage = new UserMessage(request); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice)); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - ChatResponse response = chatModel.call(prompt); + ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); } @@ -136,16 +137,13 @@ void mapOutputConverter() { Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); Map result = outputConverter.convert(generation.getOutput().getContent()); assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); } - record ActorsFilmsRecord(String actor, List movies) { - } - @Disabled("TODO: Fix the converter instructions to return the correct format") @Test void beanOutputConverterRecords() { @@ -160,7 +158,7 @@ void beanOutputConverterRecords() { """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = chatModel.call(prompt).getResult(); + Generation generation = this.chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); @@ -182,7 +180,7 @@ void beanStreamOutputConverterRecords() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = chatModel.stream(prompt) + String generationTextFromStream = this.chatModel.stream(prompt) .collectList() .block() .stream() @@ -215,4 +213,7 @@ public BedrockTitanChatModel titanChatModel(TitanChatBedrockApi titanApi) { } + record ActorsFilmsRecord(String actor, List movies) { + } + } diff --git a/models/spring-ai-huggingface/pom.xml b/models/spring-ai-huggingface/pom.xml index 9c74fb326b0..d0fe5583636 100644 --- a/models/spring-ai-huggingface/pom.xml +++ b/models/spring-ai-huggingface/pom.xml @@ -35,6 +35,10 @@ git@github.com:spring-projects/spring-ai.git + + false + + diff --git a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/ClientIT.java b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/ClientIT.java index 9106ae98d37..c83978f3e83 100644 --- a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/ClientIT.java +++ b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/ClientIT.java @@ -44,15 +44,15 @@ void helloWorldCompletion() { address: #1 Samuel St. Just generate the JSON object without explanations: [/INST] - """; + """; Prompt prompt = new Prompt(mistral7bInstruct); ChatResponse chatResponse = this.huggingfaceChatModel.call(prompt); assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty(); String expectedResponse = """ { - "name": "John", - "lastname": "Smith", - "address": "#1 Samuel St." + "name": "John", + "lastname": "Smith", + "address": "#1 Samuel St." }"""; assertThat(chatResponse.getResult().getOutput().getContent()).isEqualTo(expectedResponse); assertThat(chatResponse.getResult().getOutput().getMetadata()).containsKey("generated_tokens"); diff --git a/models/spring-ai-minimax/pom.xml b/models/spring-ai-minimax/pom.xml index 85824eb4996..198f47112e5 100644 --- a/models/spring-ai-minimax/pom.xml +++ b/models/spring-ai-minimax/pom.xml @@ -36,6 +36,10 @@ git@github.com:spring-projects/spring-ai.git + + false + + diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java index 7d677db41e7..6f7913efd04 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java @@ -73,8 +73,6 @@ import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import static org.springframework.ai.minimax.api.MiniMaxApiConstants.TOOL_CALL_FUNCTION_TYPE; - /** * {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal MiniMax} * backed by {@link MiniMaxApi}. @@ -246,9 +244,10 @@ public ChatResponse call(Prompt prompt) { // @formatter:off // if the choice is a web search tool call, return last message of choice.messages ChatCompletionMessage message = null; - if(choice.message() != null) { + if (choice.message() != null) { message = choice.message(); - } else if(!CollectionUtils.isEmpty(choice.messages())){ + } + else if (!CollectionUtils.isEmpty(choice.messages())) { // the MiniMax web search messages result is ['user message','assistant tool call', 'tool call', 'assistant message'] // so the last message is the assistant message message = choice.messages().get(choice.messages().size() - 1); @@ -328,7 +327,8 @@ public Flux stream(Prompt prompt) { return buildGeneration(choice, metadata); }).toList(); return new ChatResponse(generations, from(chatCompletion2)); - } catch (Exception e) { + } + catch (Exception e) { logger.error("Error processing chat completion", e); return new ChatResponse(List.of()); } @@ -368,7 +368,8 @@ protected boolean isToolCall(Generation generation, Set toolCallFinishRe return generation.getOutput() .getToolCalls() .stream() - .anyMatch(toolCall -> TOOL_CALL_FUNCTION_TYPE.equals(toolCall.type())); + .anyMatch(toolCall -> org.springframework.ai.minimax.api.MiniMaxApiConstants.TOOL_CALL_FUNCTION_TYPE + .equals(toolCall.type())); } private ChatOptions buildRequestOptions(ChatCompletionRequest request) { @@ -456,9 +457,8 @@ else if (message.getMessageType() == MessageType.ASSISTANT) { else if (message.getMessageType() == MessageType.TOOL) { ToolResponseMessage toolMessage = (ToolResponseMessage) message; - toolMessage.getResponses().forEach(response -> { - Assert.isTrue(response.id() != null, "ToolResponseMessage must have an id"); - }); + toolMessage.getResponses() + .forEach(response -> Assert.isTrue(response.id() != null, "ToolResponseMessage must have an id")); return toolMessage.getResponses() .stream() diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/aot/MiniMaxRuntimeHints.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/aot/MiniMaxRuntimeHints.java index 01d7fb6206e..5a370bb6816 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/aot/MiniMaxRuntimeHints.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/aot/MiniMaxRuntimeHints.java @@ -37,8 +37,9 @@ public class MiniMaxRuntimeHints implements RuntimeHintsRegistrar { @Override public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) { var mcs = MemberCategory.values(); - for (var tr : findJsonAnnotatedClassesInPackage(MiniMaxApi.class)) + for (var tr : findJsonAnnotatedClassesInPackage(MiniMaxApi.class)) { hints.reflection().registerType(tr, mcs); + } } } diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApi.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApi.java index 3216f694056..ec81faf712b 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApi.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApi.java @@ -104,7 +104,7 @@ public MiniMaxApi(String baseUrl, String miniMaxToken, RestClient.Builder restCl */ public MiniMaxApi(String baseUrl, String miniMaxToken, RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) { - Consumer authHeaders = (headers) -> { + Consumer authHeaders = headers -> { headers.setBearerAuth(miniMaxToken); headers.setContentType(MediaType.APPLICATION_JSON); }; @@ -252,27 +252,33 @@ public enum ChatCompletionFinishReason { /** * The model hit a natural stop point or a provided stop sequence. */ - @JsonProperty("stop") STOP, + @JsonProperty("stop") + STOP, /** * The maximum number of tokens specified in the request was reached. */ - @JsonProperty("length") LENGTH, + @JsonProperty("length") + LENGTH, /** * The content was omitted due to a flag from our content filters. */ - @JsonProperty("content_filter") CONTENT_FILTER, + @JsonProperty("content_filter") + CONTENT_FILTER, /** * The model called a tool. */ - @JsonProperty("tool_calls") TOOL_CALLS, + @JsonProperty("tool_calls") + TOOL_CALLS, /** * (deprecated) The model called a function. */ - @JsonProperty("function_call") FUNCTION_CALL, + @JsonProperty("function_call") + FUNCTION_CALL, /** * Only for compatibility with Mistral AI API. */ - @JsonProperty("tool_call") TOOL_CALL + @JsonProperty("tool_call") + TOOL_CALL } /** @@ -355,8 +361,10 @@ public enum Type { /** * Function tool type. */ - @JsonProperty("function") FUNCTION, - @JsonProperty("web_search") WEB_SEARCH + @JsonProperty("function") + FUNCTION, + @JsonProperty("web_search") + WEB_SEARCH } /** @@ -388,7 +396,7 @@ public Function(String description, String name, Map parameters) } } - /** + /** * Creates a model response for the given chat conversation. * * @param messages A list of messages comprising the conversation so far. @@ -428,7 +436,7 @@ public Function(String description, String name, Map parameters) * functions are present. Use the {@link ToolChoiceBuilder} to create the tool choice value. */ @JsonInclude(Include.NON_NULL) - public record ChatCompletionRequest ( + public record ChatCompletionRequest( @JsonProperty("messages") List messages, @JsonProperty("model") String model, @JsonProperty("frequency_penalty") Double frequencyPenalty, @@ -454,7 +462,7 @@ public record ChatCompletionRequest ( */ public ChatCompletionRequest(List messages, String model, Double temperature) { this(messages, model, null, null, null, null, - null, null, null, false, temperature, null,null, + null, null, null, false, temperature, null, null, null, null); } @@ -469,7 +477,7 @@ public ChatCompletionRequest(List messages, String model, */ public ChatCompletionRequest(List messages, String model, Double temperature, boolean stream) { this(messages, model, null, null, null, null, - null, null, null, stream, temperature, null,null, + null, null, null, stream, temperature, null, null, null, null); } @@ -485,11 +493,11 @@ public ChatCompletionRequest(List messages, String model, public ChatCompletionRequest(List messages, String model, List tools, Object toolChoice) { this(messages, model, null, null, null, null, - null, null, null, false, 0.8, null,null, + null, null, null, false, 0.8, null, null, tools, toolChoice); } - /** + /** * Shortcut constructor for a chat completion request with the given messages, model, tools and tool choice. * Streaming is set to false, temperature to 0.8 and all other parameters are null. * @@ -499,7 +507,7 @@ public ChatCompletionRequest(List messages, String model, */ public ChatCompletionRequest(List messages, Boolean stream) { this(messages, null, null, null, null, null, - null, null, null, stream, null, null,null, + null, null, null, stream, null, null, null, null, null); } @@ -585,19 +593,23 @@ public enum Role { /** * System message. */ - @JsonProperty("system") SYSTEM, + @JsonProperty("system") + SYSTEM, /** * User message. */ - @JsonProperty("user") USER, + @JsonProperty("user") + USER, /** * Assistant message. */ - @JsonProperty("assistant") ASSISTANT, + @JsonProperty("assistant") + ASSISTANT, /** * Tool message. */ - @JsonProperty("tool") TOOL + @JsonProperty("tool") + TOOL } /** @@ -717,11 +729,10 @@ public record Choice( @JsonProperty("logprobs") LogProbs logprobs) { } - public record BaseResponse( @JsonProperty("status_code") Long statusCode, @JsonProperty("status_msg") String message - ){} + ) { } } /** diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApiConstants.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApiConstants.java index c83d1a4486b..0f4152d40db 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApiConstants.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApiConstants.java @@ -32,4 +32,8 @@ public final class MiniMaxApiConstants { public static final String PROVIDER_NAME = AiProvider.MINIMAX.value(); + private MiniMaxApiConstants() { + + } + } diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/ChatCompletionRequestTests.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/ChatCompletionRequestTests.java index 5c50ddf3095..e2ee4fb061d 100644 --- a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/ChatCompletionRequestTests.java +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/ChatCompletionRequestTests.java @@ -70,7 +70,7 @@ public void promptOptionsTools() { .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName(TOOL_FUNCTION_NAME) .withDescription("Get the weather in location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build())) .build()), false); @@ -97,7 +97,7 @@ public void defaultOptionsTools() { .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName(TOOL_FUNCTION_NAME) .withDescription("Get the weather in location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build())) .build()); diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxApiToolFunctionCallIT.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxApiToolFunctionCallIT.java index fbbf900667d..2dad6668d44 100644 --- a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxApiToolFunctionCallIT.java +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxApiToolFunctionCallIT.java @@ -131,7 +131,7 @@ public void toolFunctionCall() { ResponseEntity chatCompletion2 = this.miniMaxApi.chatCompletionEntity(functionResponseRequest); - this.logger.info("Final response: " + chatCompletion2.getBody()); + logger.info("Final response: " + chatCompletion2.getBody()); assertThat(Objects.requireNonNull(chatCompletion2.getBody()).choices()).isNotEmpty(); diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxRetryTests.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxRetryTests.java index d2099ef8457..610f97c08ce 100644 --- a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxRetryTests.java +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxRetryTests.java @@ -51,7 +51,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.Mockito.when; +import static org.mockito.BDDMockito.given; /** * @author Geng Rong @@ -87,13 +87,13 @@ public void miniMaxChatTransientError() { var choice = new ChatCompletion.Choice(ChatCompletionFinishReason.STOP, 0, new ChatCompletionMessage("Response", Role.ASSISTANT), null, null); - ChatCompletion expectedChatCompletion = new ChatCompletion("id", List.of(choice), 666l, "model", null, null, + ChatCompletion expectedChatCompletion = new ChatCompletion("id", List.of(choice), 666L, "model", null, null, null, new MiniMaxApi.Usage(10, 10, 10)); - when(this.miniMaxApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) - .thenThrow(new TransientAiException("Transient Error 1")) - .thenThrow(new TransientAiException("Transient Error 2")) - .thenReturn(ResponseEntity.of(Optional.of(expectedChatCompletion))); + given(this.miniMaxApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + .willThrow(new TransientAiException("Transient Error 1")) + .willThrow(new TransientAiException("Transient Error 2")) + .willReturn(ResponseEntity.of(Optional.of(expectedChatCompletion))); var result = this.chatModel.call(new Prompt("text")); @@ -105,8 +105,8 @@ public void miniMaxChatTransientError() { @Test public void miniMaxChatNonTransientError() { - when(this.miniMaxApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) - .thenThrow(new RuntimeException("Non Transient Error")); + given(this.miniMaxApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + .willThrow(new RuntimeException("Non Transient Error")); assertThrows(RuntimeException.class, () -> this.chatModel.call(new Prompt("text"))); } @@ -115,13 +115,13 @@ public void miniMaxChatStreamTransientError() { var choice = new ChatCompletionChunk.ChunkChoice(ChatCompletionFinishReason.STOP, 0, new ChatCompletionMessage("Response", Role.ASSISTANT), null); - ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", List.of(choice), 666l, "model", null, + ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", List.of(choice), 666L, "model", null, null); - when(this.miniMaxApi.chatCompletionStream(isA(ChatCompletionRequest.class))) - .thenThrow(new TransientAiException("Transient Error 1")) - .thenThrow(new TransientAiException("Transient Error 2")) - .thenReturn(Flux.just(expectedChatCompletion)); + given(this.miniMaxApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + .willThrow(new TransientAiException("Transient Error 1")) + .willThrow(new TransientAiException("Transient Error 2")) + .willReturn(Flux.just(expectedChatCompletion)); var result = this.chatModel.stream(new Prompt("text")); @@ -133,8 +133,8 @@ public void miniMaxChatStreamTransientError() { @Test public void miniMaxChatStreamNonTransientError() { - when(this.miniMaxApi.chatCompletionStream(isA(ChatCompletionRequest.class))) - .thenThrow(new RuntimeException("Non Transient Error")); + given(this.miniMaxApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + .willThrow(new RuntimeException("Non Transient Error")); assertThrows(RuntimeException.class, () -> this.chatModel.stream(new Prompt("text")).collectList().block()); } @@ -143,10 +143,10 @@ public void miniMaxEmbeddingTransientError() { EmbeddingList expectedEmbeddings = new EmbeddingList(List.of(new float[] { 9.9f, 8.8f }), "model", 10); - when(this.miniMaxApi.embeddings(isA(EmbeddingRequest.class))) - .thenThrow(new TransientAiException("Transient Error 1")) - .thenThrow(new TransientAiException("Transient Error 2")) - .thenReturn(ResponseEntity.of(Optional.of(expectedEmbeddings))); + given(this.miniMaxApi.embeddings(isA(EmbeddingRequest.class))) + .willThrow(new TransientAiException("Transient Error 1")) + .willThrow(new TransientAiException("Transient Error 2")) + .willReturn(ResponseEntity.of(Optional.of(expectedEmbeddings))); var result = this.embeddingModel .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null)); @@ -159,8 +159,8 @@ public void miniMaxEmbeddingTransientError() { @Test public void miniMaxEmbeddingNonTransientError() { - when(this.miniMaxApi.embeddings(isA(EmbeddingRequest.class))) - .thenThrow(new RuntimeException("Non Transient Error")); + given(this.miniMaxApi.embeddings(isA(EmbeddingRequest.class))) + .willThrow(new RuntimeException("Non Transient Error")); assertThrows(RuntimeException.class, () -> this.embeddingModel .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null))); } diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MockWeatherService.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MockWeatherService.java index 0d3b164524c..49b973a2b72 100644 --- a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MockWeatherService.java +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MockWeatherService.java @@ -65,7 +65,7 @@ public enum Unit { */ public final String unitName; - private Unit(String text) { + Unit(String text) { this.unitName = text; } diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/chat/MiniMaxChatOptionsTests.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/chat/MiniMaxChatOptionsTests.java index 024de6f262a..fe083cd9de6 100644 --- a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/chat/MiniMaxChatOptionsTests.java +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/chat/MiniMaxChatOptionsTests.java @@ -38,7 +38,6 @@ import org.springframework.ai.minimax.api.MiniMaxApi; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.minimax.api.MiniMaxApi.ChatModel.ABAB_6_5_S_Chat; /** * @author Geng Rong @@ -93,7 +92,7 @@ void testWebSearch() { List functionTool = List.of(MiniMaxApi.FunctionTool.webSearchFunctionTool()); MiniMaxChatOptions options = MiniMaxChatOptions.builder() - .withModel(ABAB_6_5_S_Chat.value) + .withModel(org.springframework.ai.minimax.api.MiniMaxApi.ChatModel.ABAB_6_5_S_Chat.value) .withTools(functionTool) .build(); @@ -123,7 +122,7 @@ void testWebSearchStream() { List functionTool = List.of(MiniMaxApi.FunctionTool.webSearchFunctionTool()); MiniMaxChatOptions options = MiniMaxChatOptions.builder() - .withModel(ABAB_6_5_S_Chat.value) + .withModel(org.springframework.ai.minimax.api.MiniMaxApi.ChatModel.ABAB_6_5_S_Chat.value) .withTools(functionTool) .build(); diff --git a/models/spring-ai-mistral-ai/pom.xml b/models/spring-ai-mistral-ai/pom.xml index 5cae1c54039..b95c4a23feb 100644 --- a/models/spring-ai-mistral-ai/pom.xml +++ b/models/spring-ai-mistral-ai/pom.xml @@ -37,6 +37,10 @@ git@github.com:spring-projects/spring-ai.git + + false + + diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java index c699962f933..25657ec39b1 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java @@ -178,7 +178,7 @@ public ChatResponse call(Prompt prompt) { ChatCompletion chatCompletion = completionEntity.getBody(); if (chatCompletion == null) { - this.logger.warn("No chat completion returned for prompt: {}", prompt); + logger.warn("No chat completion returned for prompt: {}", prompt); return new ChatResponse(List.of()); } @@ -266,7 +266,7 @@ public Flux stream(Prompt prompt) { } } catch (Exception e) { - this.logger.error("Error processing chat completion", e); + logger.error("Error processing chat completion", e); return new ChatResponse(List.of()); } })); @@ -284,9 +284,7 @@ public Flux stream(Prompt prompt) { } }) .doOnError(observation::error) - .doFinally(s -> { - observation.stop(); - }) + .doFinally(s -> observation.stop()) .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); // @formatter:on; @@ -349,9 +347,8 @@ else if (message instanceof AssistantMessage assistantMessage) { } else if (message instanceof ToolResponseMessage toolResponseMessage) { - toolResponseMessage.getResponses().forEach(response -> { - Assert.isTrue(response.id() != null, "ToolResponseMessage must have an id"); - }); + toolResponseMessage.getResponses() + .forEach(response -> Assert.isTrue(response.id() != null, "ToolResponseMessage must have an id")); return toolResponseMessage.getResponses() .stream() diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/aot/MistralAiRuntimeHints.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/aot/MistralAiRuntimeHints.java index 6ad65d426c9..3bb55073807 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/aot/MistralAiRuntimeHints.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/aot/MistralAiRuntimeHints.java @@ -23,8 +23,6 @@ import org.springframework.lang.NonNull; import org.springframework.lang.Nullable; -import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; - /** * The MistralAiRuntimeHints class is responsible for registering runtime hints for * Mistral AI API classes. @@ -37,8 +35,9 @@ public class MistralAiRuntimeHints implements RuntimeHintsRegistrar { @Override public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) { var mcs = MemberCategory.values(); - for (var tr : findJsonAnnotatedClassesInPackage(MistralAiApi.class)) + for (var tr : org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage(MistralAiApi.class)) { hints.reflection().registerType(tr, mcs); + } } } diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java index 41277a0c07f..ea225be4d8a 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java @@ -222,26 +222,32 @@ public Flux chatCompletionStream(ChatCompletionRequest chat public enum ChatCompletionFinishReason { // @formatter:off - /** - * The model hit a natural stop point or a provided stop sequence. - */ - @JsonProperty("stop") STOP, - /** - * The maximum number of tokens specified in the request was reached. - */ - @JsonProperty("length") LENGTH, - /** - * The content was omitted due to a flag from our content filters. - */ - @JsonProperty("model_length") MODEL_LENGTH, - /** - * - */ - @JsonProperty("error") ERROR, - /** - * The model requested a tool call. - */ - @JsonProperty("tool_calls") TOOL_CALLS + /** + * The model hit a natural stop point or a provided stop sequence. + */ + @JsonProperty("stop") + STOP, + + /** + * The maximum number of tokens specified in the request was reached. + */ + @JsonProperty("length") + LENGTH, + + /** + * The content was omitted due to a flag from our content filters. + */ + @JsonProperty("model_length") + MODEL_LENGTH, + + @JsonProperty("error") + ERROR, + + /** + * The model requested a tool call. + */ + @JsonProperty("tool_calls") + TOOL_CALLS // @formatter:on } @@ -258,18 +264,18 @@ public enum ChatCompletionFinishReason { public enum ChatModel implements ChatModelDescription { // @formatter:off - @Deprecated(since = "1.0.0-M1", forRemoval = true) // Replaced by OPEN_MISTRAL_7B - TINY("open-mistral-7b"), - @Deprecated(since = "1.0.0-M1", forRemoval = true) // Replaced by OPEN_MIXTRAL_7B - MIXTRAL("open-mixtral-8x7b"), - OPEN_MISTRAL_7B("open-mistral-7b"), - OPEN_MIXTRAL_7B("open-mixtral-8x7b"), - OPEN_MIXTRAL_22B("open-mixtral-8x22b"), - SMALL("mistral-small-latest"), - @Deprecated(since = "1.0.0-M1", forRemoval = true) // Mistral is removing this model - MEDIUM("mistral-medium-latest"), - LARGE("mistral-large-latest"); - // @formatter:on + @Deprecated(since = "1.0.0-M1", forRemoval = true) // Replaced by OPEN_MISTRAL_7B + TINY("open-mistral-7b"), + @Deprecated(since = "1.0.0-M1", forRemoval = true) // Replaced by OPEN_MIXTRAL_7B + MIXTRAL("open-mixtral-8x7b"), + OPEN_MISTRAL_7B("open-mistral-7b"), + OPEN_MIXTRAL_7B("open-mixtral-8x7b"), + OPEN_MIXTRAL_22B("open-mixtral-8x22b"), + SMALL("mistral-small-latest"), + @Deprecated(since = "1.0.0-M1", forRemoval = true) // Mistral is removing this model + MEDIUM("mistral-medium-latest"), + LARGE("mistral-large-latest"); + // @formatter:on private final String value; @@ -295,7 +301,7 @@ public String getName() { public enum EmbeddingModel { // @formatter:off - EMBED("mistral-embed"); + EMBED("mistral-embed"); // @formatter:on private final String value; @@ -383,9 +389,9 @@ public Function(String description, String name, String jsonSchema) { @JsonInclude(Include.NON_NULL) public record Usage( // @formatter:off - @JsonProperty("prompt_tokens") Integer promptTokens, - @JsonProperty("total_tokens") Integer totalTokens, - @JsonProperty("completion_tokens") Integer completionTokens) { + @JsonProperty("prompt_tokens") Integer promptTokens, + @JsonProperty("total_tokens") Integer totalTokens, + @JsonProperty("completion_tokens") Integer completionTokens) { // @formatter:on } @@ -400,9 +406,9 @@ public record Usage( @JsonInclude(Include.NON_NULL) public record Embedding( // @formatter:off - @JsonProperty("index") Integer index, - @JsonProperty("embedding") float[] embedding, - @JsonProperty("object") String object) { + @JsonProperty("index") Integer index, + @JsonProperty("embedding") float[] embedding, + @JsonProperty("object") String object) { // @formatter:on /** @@ -454,9 +460,9 @@ public String toString() { @JsonInclude(Include.NON_NULL) public record EmbeddingRequest( // @formatter:off - @JsonProperty("input") T input, - @JsonProperty("model") String model, - @JsonProperty("encoding_format") String encodingFormat) { + @JsonProperty("input") T input, + @JsonProperty("model") String model, + @JsonProperty("encoding_format") String encodingFormat) { // @formatter:on /** @@ -492,10 +498,10 @@ public EmbeddingRequest(T input) { @JsonInclude(Include.NON_NULL) public record EmbeddingList( // @formatter:off - @JsonProperty("object") String object, - @JsonProperty("data") List data, - @JsonProperty("model") String model, - @JsonProperty("usage") Usage usage) { + @JsonProperty("object") String object, + @JsonProperty("data") List data, + @JsonProperty("model") String model, + @JsonProperty("usage") Usage usage) { // @formatter:on } @@ -538,18 +544,18 @@ public record EmbeddingList( @JsonInclude(Include.NON_NULL) public record ChatCompletionRequest( // @formatter:off - @JsonProperty("model") String model, - @JsonProperty("messages") List messages, - @JsonProperty("tools") List tools, - @JsonProperty("tool_choice") ToolChoice toolChoice, - @JsonProperty("temperature") Double temperature, - @JsonProperty("top_p") Double topP, - @JsonProperty("max_tokens") Integer maxTokens, - @JsonProperty("stream") Boolean stream, - @JsonProperty("safe_prompt") Boolean safePrompt, - @JsonProperty("stop") List stop, - @JsonProperty("random_seed") Integer randomSeed, - @JsonProperty("response_format") ResponseFormat responseFormat) { + @JsonProperty("model") String model, + @JsonProperty("messages") List messages, + @JsonProperty("tools") List tools, + @JsonProperty("tool_choice") ToolChoice toolChoice, + @JsonProperty("temperature") Double temperature, + @JsonProperty("top_p") Double topP, + @JsonProperty("max_tokens") Integer maxTokens, + @JsonProperty("stream") Boolean stream, + @JsonProperty("safe_prompt") Boolean safePrompt, + @JsonProperty("stop") List stop, + @JsonProperty("random_seed") Integer randomSeed, + @JsonProperty("response_format") ResponseFormat responseFormat) { // @formatter:on /** @@ -622,9 +628,12 @@ public ChatCompletionRequest(List messages, Boolean strea public enum ToolChoice { // @formatter:off - @JsonProperty("auto") AUTO, - @JsonProperty("any") ANY, - @JsonProperty("none") NONE + @JsonProperty("auto") + AUTO, + @JsonProperty("any") + ANY, + @JsonProperty("none") + NONE // @formatter:on } @@ -655,12 +664,17 @@ public record ResponseFormat(@JsonProperty("type") String type) { @JsonInclude(Include.NON_NULL) public record ChatCompletionMessage( // @formatter:off - @JsonProperty("content") String content, - @JsonProperty("role") Role role, - @JsonProperty("name") String name, - @JsonProperty("tool_calls") List toolCalls, - @JsonProperty("tool_call_id") String toolCallId) { - // @formatter:on + @JsonProperty("content") + String content, + @JsonProperty("role") + Role role, + @JsonProperty("name") + String name, + @JsonProperty("tool_calls") + List toolCalls, + @JsonProperty("tool_call_id") + String toolCallId) { + // @formatter:on /** * Message comprising the conversation. @@ -693,10 +707,14 @@ public ChatCompletionMessage(String content, Role role) { public enum Role { // @formatter:off - @JsonProperty("system") SYSTEM, - @JsonProperty("user") USER, - @JsonProperty("assistant") ASSISTANT, - @JsonProperty("tool") TOOL + @JsonProperty("system") + SYSTEM, + @JsonProperty("user") + USER, + @JsonProperty("assistant") + ASSISTANT, + @JsonProperty("tool") + TOOL // @formatter:on } @@ -746,12 +764,12 @@ public record ChatCompletionFunction(@JsonProperty("name") String name, @JsonInclude(Include.NON_NULL) public record ChatCompletion( // @formatter:off - @JsonProperty("id") String id, - @JsonProperty("object") String object, - @JsonProperty("created") Long created, - @JsonProperty("model") String model, - @JsonProperty("choices") List choices, - @JsonProperty("usage") Usage usage) { + @JsonProperty("id") String id, + @JsonProperty("object") String object, + @JsonProperty("created") Long created, + @JsonProperty("model") String model, + @JsonProperty("choices") List choices, + @JsonProperty("usage") Usage usage) { // @formatter:on /** @@ -765,9 +783,9 @@ public record ChatCompletion( @JsonInclude(Include.NON_NULL) public record Choice( // @formatter:off - @JsonProperty("index") Integer index, - @JsonProperty("message") ChatCompletionMessage message, - @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason, + @JsonProperty("index") Integer index, + @JsonProperty("message") ChatCompletionMessage message, + @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason, @JsonProperty("logprobs") LogProbs logprobs) { // @formatter:on } @@ -838,11 +856,11 @@ public record TopLogProbs(@JsonProperty("token") String token, @JsonProperty("lo @JsonInclude(Include.NON_NULL) public record ChatCompletionChunk( // @formatter:off - @JsonProperty("id") String id, - @JsonProperty("object") String object, - @JsonProperty("created") Long created, - @JsonProperty("model") String model, - @JsonProperty("choices") List choices) { + @JsonProperty("id") String id, + @JsonProperty("object") String object, + @JsonProperty("created") Long created, + @JsonProperty("model") String model, + @JsonProperty("choices") List choices) { // @formatter:on /** @@ -856,9 +874,9 @@ public record ChatCompletionChunk( @JsonInclude(Include.NON_NULL) public record ChunkChoice( // @formatter:off - @JsonProperty("index") Integer index, - @JsonProperty("delta") ChatCompletionMessage delta, - @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason, + @JsonProperty("index") Integer index, + @JsonProperty("delta") ChatCompletionMessage delta, + @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason, @JsonProperty("logprobs") LogProbs logprobs) { // @formatter:on } diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java index 26b329137f4..6298c777393 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java @@ -106,7 +106,7 @@ void listOutputConverterString() { .user(u -> u.text("List five {subject}") .param("subject", "ice cream flavors")) .call() - .entity(new ParameterizedTypeReference>() {}); + .entity(new ParameterizedTypeReference<>() { }); // @formatter:on logger.info(collection.toString()); @@ -298,4 +298,4 @@ record ActorsFilms(String actor, List movies) { } -} \ No newline at end of file +} diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java index 9e3b2823020..59416998555 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java @@ -196,7 +196,7 @@ void functionCallTest() { .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("getCurrentWeather") .withDescription("Get the weather in location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build())) .build(); @@ -219,7 +219,7 @@ void streamFunctionCallTest() { .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("getCurrentWeather") .withDescription("Get the weather in location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build())) .build(); diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java index 1818d2e43c7..a547fcb18cb 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java @@ -50,7 +50,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.Mockito.when; +import static org.mockito.BDDMockito.given; /** * @author Christian Tzolov @@ -97,10 +97,10 @@ public void mistralAiChatTransientError() { ChatCompletion expectedChatCompletion = new ChatCompletion("id", "chat.completion", 789L, "model", List.of(choice), new MistralAiApi.Usage(10, 10, 10)); - when(this.mistralAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) - .thenThrow(new TransientAiException("Transient Error 1")) - .thenThrow(new TransientAiException("Transient Error 2")) - .thenReturn(ResponseEntity.of(Optional.of(expectedChatCompletion))); + given(this.mistralAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + .willThrow(new TransientAiException("Transient Error 1")) + .willThrow(new TransientAiException("Transient Error 2")) + .willReturn(ResponseEntity.of(Optional.of(expectedChatCompletion))); var result = this.chatModel.call(new Prompt("text")); @@ -112,8 +112,8 @@ public void mistralAiChatTransientError() { @Test public void mistralAiChatNonTransientError() { - when(this.mistralAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) - .thenThrow(new RuntimeException("Non Transient Error")); + given(this.mistralAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + .willThrow(new RuntimeException("Non Transient Error")); assertThrows(RuntimeException.class, () -> this.chatModel.call(new Prompt("text"))); } @@ -126,10 +126,10 @@ public void mistralAiChatStreamTransientError() { ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", "chat.completion.chunk", 789L, "model", List.of(choice)); - when(this.mistralAiApi.chatCompletionStream(isA(ChatCompletionRequest.class))) - .thenThrow(new TransientAiException("Transient Error 1")) - .thenThrow(new TransientAiException("Transient Error 2")) - .thenReturn(Flux.just(expectedChatCompletion)); + given(this.mistralAiApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + .willThrow(new TransientAiException("Transient Error 1")) + .willThrow(new TransientAiException("Transient Error 2")) + .willReturn(Flux.just(expectedChatCompletion)); var result = this.chatModel.stream(new Prompt("text")); @@ -142,8 +142,8 @@ public void mistralAiChatStreamTransientError() { @Test @Disabled("Currently stream() does not implement retry") public void mistralAiChatStreamNonTransientError() { - when(this.mistralAiApi.chatCompletionStream(isA(ChatCompletionRequest.class))) - .thenThrow(new RuntimeException("Non Transient Error")); + given(this.mistralAiApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + .willThrow(new RuntimeException("Non Transient Error")); assertThrows(RuntimeException.class, () -> this.chatModel.stream(new Prompt("text"))); } @@ -153,10 +153,10 @@ public void mistralAiEmbeddingTransientError() { EmbeddingList expectedEmbeddings = new EmbeddingList<>("list", List.of(new Embedding(0, new float[] { 9.9f, 8.8f })), "model", new MistralAiApi.Usage(10, 10, 10)); - when(this.mistralAiApi.embeddings(isA(EmbeddingRequest.class))) - .thenThrow(new TransientAiException("Transient Error 1")) - .thenThrow(new TransientAiException("Transient Error 2")) - .thenReturn(ResponseEntity.of(Optional.of(expectedEmbeddings))); + given(this.mistralAiApi.embeddings(isA(EmbeddingRequest.class))) + .willThrow(new TransientAiException("Transient Error 1")) + .willThrow(new TransientAiException("Transient Error 2")) + .willReturn(ResponseEntity.of(Optional.of(expectedEmbeddings))); var result = this.embeddingModel .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null)); @@ -169,8 +169,8 @@ public void mistralAiEmbeddingTransientError() { @Test public void mistralAiEmbeddingNonTransientError() { - when(this.mistralAiApi.embeddings(isA(EmbeddingRequest.class))) - .thenThrow(new RuntimeException("Non Transient Error")); + given(this.mistralAiApi.embeddings(isA(EmbeddingRequest.class))) + .willThrow(new RuntimeException("Non Transient Error")); assertThrows(RuntimeException.class, () -> this.embeddingModel .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null))); } diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/aot/MistralAiRuntimeHintsTests.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/aot/MistralAiRuntimeHintsTests.java index 1e16b590cf4..ac20dbccb96 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/aot/MistralAiRuntimeHintsTests.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/aot/MistralAiRuntimeHintsTests.java @@ -25,7 +25,6 @@ import org.springframework.aot.hint.TypeReference; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; -import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage; import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection; class MistralAiRuntimeHintsTests { @@ -36,7 +35,8 @@ void registerHints() { MistralAiRuntimeHints mistralAiRuntimeHints = new MistralAiRuntimeHints(); mistralAiRuntimeHints.registerHints(runtimeHints, null); - Set jsonAnnotatedClasses = findJsonAnnotatedClassesInPackage(MistralAiApi.class); + Set jsonAnnotatedClasses = org.springframework.ai.aot.AiRuntimeHints + .findJsonAnnotatedClassesInPackage(MistralAiApi.class); for (TypeReference jsonAnnotatedClass : jsonAnnotatedClasses) { assertThat(runtimeHints).matches(reflection().onType(jsonAnnotatedClass)); } diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/tool/MistralAiApiToolFunctionCallIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/tool/MistralAiApiToolFunctionCallIT.java index b4ea08ed0d4..2dcc5ca269e 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/tool/MistralAiApiToolFunctionCallIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/tool/MistralAiApiToolFunctionCallIT.java @@ -146,7 +146,7 @@ public void toolFunctionCall() throws JsonProcessingException { ResponseEntity chatCompletion2 = this.completionApi .chatCompletionEntity(functionResponseRequest); - this.logger.info("Final response: " + chatCompletion2.getBody()); + logger.info("Final response: " + chatCompletion2.getBody()); assertThat(chatCompletion2.getBody().choices()).isNotEmpty(); diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/tool/PaymentStatusFunctionCallingIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/tool/PaymentStatusFunctionCallingIT.java index dc144870183..c542df6fca5 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/tool/PaymentStatusFunctionCallingIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/tool/PaymentStatusFunctionCallingIT.java @@ -136,7 +136,7 @@ public void toolFunctionCall() throws JsonProcessingException { .chatCompletionEntity(new ChatCompletionRequest(messages, MistralAiApi.ChatModel.LARGE.getValue())); var responseContent = response.getBody().choices().get(0).message().content(); - this.logger.info("Final response: " + responseContent); + logger.info("Final response: " + responseContent); assertThat(responseContent).containsIgnoringCase("T1001"); assertThat(responseContent).containsIgnoringCase("Paid"); diff --git a/models/spring-ai-moonshot/pom.xml b/models/spring-ai-moonshot/pom.xml index 84a154850f7..5d72d858348 100644 --- a/models/spring-ai-moonshot/pom.xml +++ b/models/spring-ai-moonshot/pom.xml @@ -36,6 +36,11 @@ git@github.com:spring-projects/spring-ai.git + + + false + + diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java index 2751a9d6498..aa76d5fa9cc 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java @@ -365,9 +365,8 @@ else if (message.getMessageType() == MessageType.ASSISTANT) { else if (message.getMessageType() == MessageType.TOOL) { ToolResponseMessage toolMessage = (ToolResponseMessage) message; - toolMessage.getResponses().forEach(response -> { - Assert.isTrue(response.id() != null, "ToolResponseMessage must have an id"); - }); + toolMessage.getResponses() + .forEach(response -> Assert.isTrue(response.id() != null, "ToolResponseMessage must have an id")); return toolMessage.getResponses() .stream() diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/aot/MoonshotRuntimeHints.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/aot/MoonshotRuntimeHints.java index 7f8a3a27bd9..58c32304705 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/aot/MoonshotRuntimeHints.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/aot/MoonshotRuntimeHints.java @@ -36,8 +36,9 @@ public class MoonshotRuntimeHints implements RuntimeHintsRegistrar { @Override public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) { var mcs = MemberCategory.values(); - for (var tr : findJsonAnnotatedClassesInPackage(MoonshotApi.class)) + for (var tr : findJsonAnnotatedClassesInPackage(MoonshotApi.class)) { hints.reflection().registerType(tr, mcs); + } } } diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotApi.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotApi.java index f6eb1c476c7..792f35b2dd8 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotApi.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotApi.java @@ -40,8 +40,6 @@ import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; -import static org.springframework.ai.moonshot.api.MoonshotConstants.DEFAULT_BASE_URL; - /** * Single-class, Java Client library for Moonshot platform. Provides implementation for * the Chat Completion APIs. @@ -69,7 +67,7 @@ public class MoonshotApi { * @param moonshotApiKey Moonshot api Key. */ public MoonshotApi(String moonshotApiKey) { - this(DEFAULT_BASE_URL, moonshotApiKey); + this(org.springframework.ai.moonshot.api.MoonshotConstants.DEFAULT_BASE_URL, moonshotApiKey); } /** @@ -223,9 +221,9 @@ public enum ChatCompletionFinishReason { public enum ChatModel implements ChatModelDescription { // @formatter:off - MOONSHOT_V1_8K("moonshot-v1-8k"), - MOONSHOT_V1_32K("moonshot-v1-32k"), - MOONSHOT_V1_128K("moonshot-v1-128k"); + MOONSHOT_V1_8K("moonshot-v1-8k"), + MOONSHOT_V1_32K("moonshot-v1-32k"), + MOONSHOT_V1_128K("moonshot-v1-128k"); // @formatter:on private final String value; @@ -257,10 +255,10 @@ public String getName() { @JsonInclude(Include.NON_NULL) public record Usage( // @formatter:off - @JsonProperty("prompt_tokens") Integer promptTokens, - @JsonProperty("total_tokens") Integer totalTokens, - @JsonProperty("completion_tokens") Integer completionTokens) { - // @formatter:on + @JsonProperty("prompt_tokens") Integer promptTokens, + @JsonProperty("total_tokens") Integer totalTokens, + @JsonProperty("completion_tokens") Integer completionTokens) { + // @formatter:on } /** @@ -296,16 +294,16 @@ public record Usage( @JsonInclude(Include.NON_NULL) public record ChatCompletionRequest( // @formatter:off - @JsonProperty("messages") List messages, - @JsonProperty("model") String model, - @JsonProperty("max_tokens") Integer maxTokens, - @JsonProperty("temperature") Double temperature, - @JsonProperty("top_p") Double topP, - @JsonProperty("n") Integer n, - @JsonProperty("frequency_penalty") Double frequencyPenalty, - @JsonProperty("presence_penalty") Double presencePenalty, - @JsonProperty("stop") List stop, - @JsonProperty("stream") Boolean stream, + @JsonProperty("messages") List messages, + @JsonProperty("model") String model, + @JsonProperty("max_tokens") Integer maxTokens, + @JsonProperty("temperature") Double temperature, + @JsonProperty("top_p") Double topP, + @JsonProperty("n") Integer n, + @JsonProperty("frequency_penalty") Double frequencyPenalty, + @JsonProperty("presence_penalty") Double presencePenalty, + @JsonProperty("stop") List stop, + @JsonProperty("stream") Boolean stream, @JsonProperty("tools") List tools, @JsonProperty("tool_choice") Object toolChoice) { // @formatter:on @@ -517,12 +515,12 @@ public record ChatCompletionFunction(@JsonProperty("name") String name, @JsonInclude(Include.NON_NULL) public record ChatCompletion( // @formatter:off - @JsonProperty("id") String id, - @JsonProperty("object") String object, - @JsonProperty("created") Long created, - @JsonProperty("model") String model, - @JsonProperty("choices") List choices, - @JsonProperty("usage") Usage usage) { + @JsonProperty("id") String id, + @JsonProperty("object") String object, + @JsonProperty("created") Long created, + @JsonProperty("model") String model, + @JsonProperty("choices") List choices, + @JsonProperty("usage") Usage usage) { // @formatter:on /** @@ -535,9 +533,9 @@ public record ChatCompletion( @JsonInclude(Include.NON_NULL) public record Choice( // @formatter:off - @JsonProperty("index") Integer index, - @JsonProperty("message") ChatCompletionMessage message, - @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason) { + @JsonProperty("index") Integer index, + @JsonProperty("message") ChatCompletionMessage message, + @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason) { // @formatter:on } @@ -558,11 +556,11 @@ public record Choice( @JsonInclude(Include.NON_NULL) public record ChatCompletionChunk( // @formatter:off - @JsonProperty("id") String id, - @JsonProperty("object") String object, - @JsonProperty("created") Long created, - @JsonProperty("model") String model, - @JsonProperty("choices") List choices) { + @JsonProperty("id") String id, + @JsonProperty("object") String object, + @JsonProperty("created") Long created, + @JsonProperty("model") String model, + @JsonProperty("choices") List choices) { // @formatter:on /** diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotConstants.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotConstants.java index 3d6bdd4b272..daefef59bbe 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotConstants.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotConstants.java @@ -27,4 +27,8 @@ public final class MoonshotConstants { public static final String PROVIDER_NAME = AiProvider.MOONSHOT.value(); + private MoonshotConstants() { + + } + } diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotRetryTests.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotRetryTests.java index e87a1122796..83e0f5d13f8 100644 --- a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotRetryTests.java +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotRetryTests.java @@ -45,7 +45,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.Mockito.when; +import static org.mockito.BDDMockito.given; /** * @author Geng Rong @@ -80,13 +80,13 @@ public void moonshotChatTransientError() { var choice = new ChatCompletion.Choice(0, new ChatCompletionMessage("Response", Role.ASSISTANT), ChatCompletionFinishReason.STOP); - ChatCompletion expectedChatCompletion = new ChatCompletion("id", "chat.completion", 789l, "model", + ChatCompletion expectedChatCompletion = new ChatCompletion("id", "chat.completion", 789L, "model", List.of(choice), new MoonshotApi.Usage(10, 10, 10)); - when(this.moonshotApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) - .thenThrow(new TransientAiException("Transient Error 1")) - .thenThrow(new TransientAiException("Transient Error 2")) - .thenReturn(ResponseEntity.of(Optional.of(expectedChatCompletion))); + given(this.moonshotApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + .willThrow(new TransientAiException("Transient Error 1")) + .willThrow(new TransientAiException("Transient Error 2")) + .willReturn(ResponseEntity.of(Optional.of(expectedChatCompletion))); var result = this.chatModel.call(new Prompt("text")); @@ -98,8 +98,8 @@ public void moonshotChatTransientError() { @Test public void moonshotChatNonTransientError() { - when(this.moonshotApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) - .thenThrow(new RuntimeException("Non Transient Error")); + given(this.moonshotApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + .willThrow(new RuntimeException("Non Transient Error")); assertThrows(RuntimeException.class, () -> this.chatModel.call(new Prompt("text"))); } @@ -108,13 +108,13 @@ public void moonshotChatStreamTransientError() { var choice = new ChatCompletionChunk.ChunkChoice(0, new ChatCompletionMessage("Response", Role.ASSISTANT), ChatCompletionFinishReason.STOP, null); - ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", "chat.completion.chunk", 789l, + ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", "chat.completion.chunk", 789L, "model", List.of(choice)); - when(this.moonshotApi.chatCompletionStream(isA(ChatCompletionRequest.class))) - .thenThrow(new TransientAiException("Transient Error 1")) - .thenThrow(new TransientAiException("Transient Error 2")) - .thenReturn(Flux.just(expectedChatCompletion)); + given(this.moonshotApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + .willThrow(new TransientAiException("Transient Error 1")) + .willThrow(new TransientAiException("Transient Error 2")) + .willReturn(Flux.just(expectedChatCompletion)); var result = this.chatModel.stream(new Prompt("text")); @@ -126,8 +126,8 @@ public void moonshotChatStreamTransientError() { @Test public void moonshotChatStreamNonTransientError() { - when(this.moonshotApi.chatCompletionStream(isA(ChatCompletionRequest.class))) - .thenThrow(new RuntimeException("Non Transient Error")); + given(this.moonshotApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + .willThrow(new RuntimeException("Non Transient Error")); assertThrows(RuntimeException.class, () -> this.chatModel.stream(new Prompt("text")).collectList().block()); } diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MockWeatherService.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MockWeatherService.java index 402409649ea..0c2685b42b3 100644 --- a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MockWeatherService.java +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MockWeatherService.java @@ -65,7 +65,7 @@ public enum Unit { */ public final String unitName; - private Unit(String text) { + Unit(String text) { this.unitName = text; } @@ -78,8 +78,8 @@ private Unit(String text) { @JsonClassDescription("Weather API request") public record Request(@JsonProperty(required = true, value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(value = "lat") @JsonPropertyDescription("The city latitude") double lat, - @JsonProperty(value = "lon") @JsonPropertyDescription("The city longitude") double lon, + @JsonProperty("lat") @JsonPropertyDescription("The city latitude") double lat, + @JsonProperty("lon") @JsonPropertyDescription("The city longitude") double lon, @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { } diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MoonshotApiToolFunctionCallIT.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MoonshotApiToolFunctionCallIT.java index fa2cc486346..c2597fc1da5 100644 --- a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MoonshotApiToolFunctionCallIT.java +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MoonshotApiToolFunctionCallIT.java @@ -139,7 +139,7 @@ private void toolFunctionCall(String userMessage, String cityName) { ResponseEntity chatCompletion2 = this.moonshotApi.chatCompletionEntity(functionResponseRequest); - this.logger.info("Final response: " + chatCompletion2.getBody()); + logger.info("Final response: " + chatCompletion2.getBody()); assertThat(Objects.requireNonNull(chatCompletion2.getBody()).choices()).isNotEmpty(); diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelFunctionCallingIT.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelFunctionCallingIT.java index 8fa54687b2f..d136686d1b5 100644 --- a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelFunctionCallingIT.java +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelFunctionCallingIT.java @@ -66,7 +66,7 @@ void functionCallTest() { .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("getCurrentWeather") .withDescription("Get the weather in location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build())) .build(); @@ -90,7 +90,7 @@ void streamFunctionCallTest() { .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("getCurrentWeather") .withDescription("Get the weather in location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build())) .build(); diff --git a/models/spring-ai-oci-genai/pom.xml b/models/spring-ai-oci-genai/pom.xml index 64a474969fd..82c34d0442d 100644 --- a/models/spring-ai-oci-genai/pom.xml +++ b/models/spring-ai-oci-genai/pom.xml @@ -36,6 +36,10 @@ git@github.com:spring-projects/spring-ai.git + + false + + diff --git a/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/BaseEmbeddingModelTest.java b/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/BaseEmbeddingModelTest.java index b1f6da89b35..6ca87f2b114 100644 --- a/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/BaseEmbeddingModelTest.java +++ b/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/BaseEmbeddingModelTest.java @@ -43,7 +43,7 @@ public class BaseEmbeddingModelTest { * Create an OCIEmbeddingModel instance using a config file authentication provider. * @return OCIEmbeddingModel instance */ - public static OCIEmbeddingModel get() { + public OCIEmbeddingModel get() { try { ConfigFileAuthenticationDetailsProvider authProvider = new ConfigFileAuthenticationDetailsProvider( CONFIG_FILE, PROFILE); diff --git a/models/spring-ai-ollama/pom.xml b/models/spring-ai-ollama/pom.xml index 3b9a4428a7b..0e9fb5b942a 100644 --- a/models/spring-ai-ollama/pom.xml +++ b/models/spring-ai-ollama/pom.xml @@ -30,10 +30,11 @@ Spring AI Model - Ollama Ollama models support - + 17 17 UTF-8 + false diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index c60523c00ee..f4fcd722f15 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -236,9 +236,9 @@ public Flux stream(Prompt prompt) { } }) .doOnError(observation::error) - .doFinally(s -> { - observation.stop(); - }) + .doFinally(s -> + observation.stop() + ) .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); // @formatter:on @@ -392,7 +392,7 @@ public void setObservationConvention(ChatModelObservationConvention observationC this.observationConvention = observationConvention; } - public static class Builder { + public static final class Builder { private OllamaApi ollamaApi; diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java index f44c9c6ea40..1809100321c 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java @@ -211,7 +211,7 @@ public static Duration parse(String input) { } - public static class Builder { + public static final class Builder { private OllamaApi ollamaApi; diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/aot/OllamaRuntimeHints.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/aot/OllamaRuntimeHints.java index bd8799c9b8b..057df2376a2 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/aot/OllamaRuntimeHints.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/aot/OllamaRuntimeHints.java @@ -37,10 +37,12 @@ public class OllamaRuntimeHints implements RuntimeHintsRegistrar { @Override public void registerHints(RuntimeHints hints, ClassLoader classLoader) { var mcs = MemberCategory.values(); - for (var tr : findJsonAnnotatedClassesInPackage(OllamaApi.class)) + for (var tr : findJsonAnnotatedClassesInPackage(OllamaApi.class)) { hints.reflection().registerType(tr, mcs); - for (var tr : findJsonAnnotatedClassesInPackage(OllamaOptions.class)) + } + for (var tr : findJsonAnnotatedClassesInPackage(OllamaOptions.class)) { hints.reflection().registerType(tr, mcs); + } } } diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java index bbd32c5117b..3b6c2afced0 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java @@ -114,7 +114,7 @@ public OllamaApi(String baseUrl, RestClient.Builder restClientBuilder, WebClient @Deprecated(since = "1.0.0-M2", forRemoval = true) public GenerateResponse generate(GenerateRequest completionRequest) { Assert.notNull(completionRequest, REQUEST_BODY_NULL_ERROR); - Assert.isTrue(completionRequest.stream() == false, "Stream mode must be disabled."); + Assert.isTrue(!completionRequest.stream(), "Stream mode must be disabled."); return this.restClient.post() .uri("/api/generate") @@ -535,19 +535,23 @@ public enum Role { /** * System message type used as instructions to the model. */ - @JsonProperty("system") SYSTEM, + @JsonProperty("system") + SYSTEM, /** * User message type. */ - @JsonProperty("user") USER, + @JsonProperty("user") + USER, /** * Assistant message type. Usually the response from the model. */ - @JsonProperty("assistant") ASSISTANT, + @JsonProperty("assistant") + ASSISTANT, /** * Tool message. */ - @JsonProperty("tool") TOOL + @JsonProperty("tool") + TOOL } @@ -666,7 +670,8 @@ public enum Type { /** * Function tool type. */ - @JsonProperty("function") FUNCTION + @JsonProperty("function") + FUNCTION } /** @@ -900,13 +905,13 @@ public record Details( @JsonProperty("families") List families, @JsonProperty("parameter_size") String parameterSize, @JsonProperty("quantization_level") String quantizationLevel - ) {} + ) { } } @JsonInclude(Include.NON_NULL) public record ListModelResponse( @JsonProperty("models") List models - ) {} + ) { } @JsonInclude(Include.NON_NULL) public record ShowModelRequest( @@ -932,18 +937,18 @@ public record ShowModelResponse( @JsonProperty("model_info") Map modelInfo, @JsonProperty("projector_info") Map projectorInfo, @JsonProperty("modified_at") Instant modifiedAt - ) {} + ) { } @JsonInclude(Include.NON_NULL) public record CopyModelRequest( @JsonProperty("source") String source, @JsonProperty("destination") String destination - ) {} + ) { } @JsonInclude(Include.NON_NULL) public record DeleteModelRequest( @JsonProperty("model") String model - ) {} + ) { } @JsonInclude(Include.NON_NULL) public record PullModelRequest( @@ -971,7 +976,7 @@ public record ProgressResponse( @JsonProperty("digest") String digest, @JsonProperty("total") Long total, @JsonProperty("completed") Long completed - ) {} + ) { } } -// @formatter:on \ No newline at end of file +// @formatter:on diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java index 034a4b75c66..16f06833173 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java @@ -62,24 +62,28 @@ public class OllamaOptions implements FunctionCallingOptions, ChatOptions, Embed /** * Whether to use NUMA. (Default: false) */ - @JsonProperty("numa") private Boolean useNUMA; + @JsonProperty("numa") + private Boolean useNUMA; /** * Sets the size of the context window used to generate the next token. (Default: 2048) */ - @JsonProperty("num_ctx") private Integer numCtx; + @JsonProperty("num_ctx") + private Integer numCtx; /** * Prompt processing maximum batch size. (Default: 512) */ - @JsonProperty("num_batch") private Integer numBatch; + @JsonProperty("num_batch") + private Integer numBatch; /** * The number of layers to send to the GPU(s). On macOS, it defaults to 1 * to enable metal support, 0 to disable. * (Default: -1, which indicates that numGPU should be set dynamically) */ - @JsonProperty("num_gpu") private Integer numGPU; + @JsonProperty("num_gpu") + private Integer numGPU; /** * When using multiple GPUs this option controls which GPU is used @@ -88,28 +92,33 @@ public class OllamaOptions implements FunctionCallingOptions, ChatOptions, Embed * more VRAM to store a scratch buffer for temporary results. * By default, GPU 0 is used. */ - @JsonProperty("main_gpu")private Integer mainGPU; + @JsonProperty("main_gpu") + private Integer mainGPU; /** * (Default: false) */ - @JsonProperty("low_vram") private Boolean lowVRAM; + @JsonProperty("low_vram") + private Boolean lowVRAM; /** * (Default: true) */ - @JsonProperty("f16_kv") private Boolean f16KV; + @JsonProperty("f16_kv") + private Boolean f16KV; /** * Return logits for all the tokens, not just the last one. * To enable completions to return logprobs, this must be true. */ - @JsonProperty("logits_all") private Boolean logitsAll; + @JsonProperty("logits_all") + private Boolean logitsAll; /** * Load only the vocabulary, not the weights. */ - @JsonProperty("vocab_only") private Boolean vocabOnly; + @JsonProperty("vocab_only") + private Boolean vocabOnly; /** * By default, models are mapped into memory, which allows the system to load only the necessary parts @@ -120,7 +129,8 @@ public class OllamaOptions implements FunctionCallingOptions, ChatOptions, Embed * the model from loading at all. * (Default: null) */ - @JsonProperty("use_mmap") private Boolean useMMap; + @JsonProperty("use_mmap") + private Boolean useMMap; /** * Lock the model in memory, preventing it from being swapped out when memory-mapped. @@ -128,7 +138,8 @@ public class OllamaOptions implements FunctionCallingOptions, ChatOptions, Embed * by requiring more RAM to run and potentially slowing down load times as the model loads into RAM. * (Default: false) */ - @JsonProperty("use_mlock") private Boolean useMLock; + @JsonProperty("use_mlock") + private Boolean useMLock; /** * Set the number of threads to use during generation. For optimal performance, it is recommended to set this value @@ -136,113 +147,131 @@ public class OllamaOptions implements FunctionCallingOptions, ChatOptions, Embed * Using the correct number of threads can greatly improve performance. * By default, Ollama will detect this value for optimal performance. */ - @JsonProperty("num_thread") private Integer numThread; + @JsonProperty("num_thread") + private Integer numThread; // Following fields are predict options used at runtime. /** * (Default: 4) */ - @JsonProperty("num_keep") private Integer numKeep; + @JsonProperty("num_keep") + private Integer numKeep; /** * Sets the random number seed to use for generation. Setting this to a * specific number will make the model generate the same text for the same prompt. * (Default: -1) */ - @JsonProperty("seed") private Integer seed; + @JsonProperty("seed") + private Integer seed; /** * Maximum number of tokens to predict when generating text. * (Default: 128, -1 = infinite generation, -2 = fill context) */ - @JsonProperty("num_predict") private Integer numPredict; + @JsonProperty("num_predict") + private Integer numPredict; /** * Reduces the probability of generating nonsense. A higher value (e.g. * 100) will give more diverse answers, while a lower value (e.g. 10) will be more * conservative. (Default: 40) */ - @JsonProperty("top_k") private Integer topK; + @JsonProperty("top_k") + private Integer topK; /** * Works together with top-k. A higher value (e.g., 0.95) will lead to * more diverse text, while a lower value (e.g., 0.5) will generate more focused and * conservative text. (Default: 0.9) */ - @JsonProperty("top_p") private Double topP; + @JsonProperty("top_p") + private Double topP; /** * Tail free sampling is used to reduce the impact of less probable tokens * from the output. A higher value (e.g., 2.0) will reduce the impact more, while a * value of 1.0 disables this setting. (default: 1) */ - @JsonProperty("tfs_z") private Float tfsZ; + @JsonProperty("tfs_z") + private Float tfsZ; /** * (Default: 1.0) */ - @JsonProperty("typical_p") private Float typicalP; + @JsonProperty("typical_p") + private Float typicalP; /** * Sets how far back for the model to look back to prevent * repetition. (Default: 64, 0 = disabled, -1 = num_ctx) */ - @JsonProperty("repeat_last_n") private Integer repeatLastN; + @JsonProperty("repeat_last_n") + private Integer repeatLastN; /** * The temperature of the model. Increasing the temperature will * make the model answer more creatively. (Default: 0.8) */ - @JsonProperty("temperature") private Double temperature; + @JsonProperty("temperature") + private Double temperature; /** * Sets how strongly to penalize repetitions. A higher value * (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., * 0.9) will be more lenient. (Default: 1.1) */ - @JsonProperty("repeat_penalty") private Double repeatPenalty; + @JsonProperty("repeat_penalty") + private Double repeatPenalty; /** * (Default: 0.0) */ - @JsonProperty("presence_penalty") private Double presencePenalty; + @JsonProperty("presence_penalty") + private Double presencePenalty; /** * (Default: 0.0) */ - @JsonProperty("frequency_penalty") private Double frequencyPenalty; + @JsonProperty("frequency_penalty") + private Double frequencyPenalty; /** * Enable Mirostat sampling for controlling perplexity. (default: 0, 0 * = disabled, 1 = Mirostat, 2 = Mirostat 2.0) */ - @JsonProperty("mirostat") private Integer mirostat; + @JsonProperty("mirostat") + private Integer mirostat; /** * Controls the balance between coherence and diversity of the output. * A lower value will result in more focused and coherent text. (Default: 5.0) */ - @JsonProperty("mirostat_tau") private Float mirostatTau; + @JsonProperty("mirostat_tau") + private Float mirostatTau; /** * Influences how quickly the algorithm responds to feedback from the generated text. * A lower learning rate will result in slower adjustments, while a higher learning rate * will make the algorithm more responsive. (Default: 0.1) */ - @JsonProperty("mirostat_eta") private Float mirostatEta; + @JsonProperty("mirostat_eta") + private Float mirostatEta; /** * (Default: true) */ - @JsonProperty("penalize_newline") private Boolean penalizeNewline; + @JsonProperty("penalize_newline") + private Boolean penalizeNewline; /** * Sets the stop sequences to use. When this pattern is encountered the * LLM will stop generating text and return. Multiple stop patterns may be set by * specifying multiple separate stop parameters in a modelfile. */ - @JsonProperty("stop") private List stop; + @JsonProperty("stop") + private List stop; // Following fields are not part of the Ollama Options API but part of the Request. @@ -252,27 +281,30 @@ public class OllamaOptions implements FunctionCallingOptions, ChatOptions, Embed * Used to allow overriding the model name with prompt options. * Part of Chat completion parameters. */ - @JsonProperty("model") private String model; + @JsonProperty("model") + private String model; /** * Sets the desired format of output from the LLM. The only valid values are null or "json". * Part of Chat completion advanced parameters. */ - @JsonProperty("format") private String format; + @JsonProperty("format") + private String format; /** * Sets the length of time for Ollama to keep the model loaded. Valid values for this * setting are parsed by ParseDuration in Go. * Part of Chat completion advanced parameters. */ - @JsonProperty("keep_alive") private String keepAlive; - - + @JsonProperty("keep_alive") + private String keepAlive; + /** * Truncates the end of each input to fit within context length. Returns error if false and context length is exceeded. * Defaults to true. */ - @JsonProperty("truncate") private Boolean truncate; + @JsonProperty("truncate") + private Boolean truncate; /** * Tool Function Callbacks to register with the ChatModel. @@ -314,7 +346,7 @@ public static OllamaOptions builder() { public static OllamaOptions create() { return new OllamaOptions(); } - + /** * Filter out the non-supported fields from the options. * @param options The options to filter. @@ -718,8 +750,8 @@ public void setSeed(Integer seed) { @Override @JsonIgnore public Integer getMaxTokens() { - return getNumPredict(); - } + return getNumPredict(); + } @JsonIgnore public void setMaxTokens(Integer maxTokens) { diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/OllamaModelManager.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/OllamaModelManager.java index 572ec5896c3..ee06f934530 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/OllamaModelManager.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/OllamaModelManager.java @@ -81,13 +81,13 @@ private String normalizeModelName(String modelName) { } public void deleteModel(String modelName) { - this.logger.info("Start deletion of model: {}", modelName); + logger.info("Start deletion of model: {}", modelName); if (!isModelAvailable(modelName)) { - this.logger.info("Model {} not found", modelName); + logger.info("Model {} not found", modelName); return; } this.ollamaApi.deleteModel(new DeleteModelRequest(modelName)); - this.logger.info("Completed deletion of model: {}", modelName); + logger.info("Completed deletion of model: {}", modelName); } public void pullModel(String modelName) { @@ -101,27 +101,27 @@ public void pullModel(String modelName, PullModelStrategy pullModelStrategy) { if (PullModelStrategy.WHEN_MISSING.equals(pullModelStrategy)) { if (isModelAvailable(modelName)) { - this.logger.debug("Model '{}' already available. Skipping pull operation.", modelName); + logger.debug("Model '{}' already available. Skipping pull operation.", modelName); return; } } // @formatter:off - this.logger.info("Start pulling model: {}", modelName); + logger.info("Start pulling model: {}", modelName); this.ollamaApi.pullModel(new PullModelRequest(modelName)) .bufferUntilChanged(OllamaApi.ProgressResponse::status) .doOnEach(signal -> { var progressResponses = signal.get(); if (!CollectionUtils.isEmpty(progressResponses) && progressResponses.get(progressResponses.size() - 1) != null) { - this.logger.info("Pulling the '{}' model - Status: {}", modelName, progressResponses.get(progressResponses.size() - 1).status()); + logger.info("Pulling the '{}' model - Status: {}", modelName, progressResponses.get(progressResponses.size() - 1).status()); } }) .takeUntil(progressResponses -> progressResponses.get(0) != null && progressResponses.get(0).status().equals("success")) .timeout(this.options.timeout()) .retryWhen(Retry.backoff(this.options.maxRetries(), Duration.ofSeconds(5))) .blockLast(); - this.logger.info("Completed pulling the '{}' model", modelName); + logger.info("Completed pulling the '{}' model", modelName); // @formatter:on } diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/PullModelStrategy.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/PullModelStrategy.java index e6f021008a0..c2c75700a88 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/PullModelStrategy.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/PullModelStrategy.java @@ -39,6 +39,6 @@ public enum PullModelStrategy { /** * Never pull the model. */ - NEVER; + NEVER } diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/BaseOllamaIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/BaseOllamaIT.java index f58413f2640..1944621f6d8 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/BaseOllamaIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/BaseOllamaIT.java @@ -1,3 +1,19 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.ollama; import java.time.Duration; @@ -34,7 +50,7 @@ public class BaseOllamaIT { * * to the file ".testcontainers.properties" located in your home directory */ - public static boolean isDisabled() { + public boolean isDisabled() { return false; } diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java index 5738c337f8b..c88e83db17c 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java @@ -71,7 +71,7 @@ void functionCallTest() { .withName("getCurrentWeather") .withDescription( "Find the weather conditions, forecasts, and temperatures for a location, like a city or state.") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build())) .build(); @@ -96,7 +96,7 @@ void streamFunctionCallTest() { .withName("getCurrentWeather") .withDescription( "Find the weather conditions, forecasts, and temperatures for a location, like a city or state.") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build())) .build(); diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java index 0cb22784c05..735857ef3a1 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java @@ -37,7 +37,7 @@ import org.springframework.util.MimeTypeUtils; import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.Assert.assertThrows; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; @SpringBootTest @Testcontainers @@ -58,7 +58,8 @@ void unsupportedMediaType() { var userMessage = new UserMessage("Explain what do you see in this picture?", List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData))); - assertThrows(RuntimeException.class, () -> this.chatModel.call(new Prompt(List.of(userMessage)))); + assertThatThrownBy(() -> this.chatModel.call(new Prompt(List.of(userMessage)))) + .isInstanceOf(RuntimeException.class); } @Test diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelTests.java index 0afb8c24755..f8b9179e106 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelTests.java @@ -37,7 +37,7 @@ import org.springframework.ai.ollama.api.OllamaOptions; import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.when; +import static org.mockito.BDDMockito.given; /** * @author Christian Tzolov @@ -56,10 +56,10 @@ public class OllamaEmbeddingModelTests { @Test public void options() { - when(this.ollamaApi.embed(this.embeddingsRequestCaptor.capture())) - .thenReturn(new EmbeddingsResponse("RESPONSE_MODEL_NAME", + given(this.ollamaApi.embed(this.embeddingsRequestCaptor.capture())) + .willReturn(new EmbeddingsResponse("RESPONSE_MODEL_NAME", List.of(new float[] { 1f, 2f, 3f }, new float[] { 4f, 5f, 6f }), 0L, 0L, 0)) - .thenReturn(new EmbeddingsResponse("RESPONSE_MODEL_NAME2", + .willReturn(new EmbeddingsResponse("RESPONSE_MODEL_NAME2", List.of(new float[] { 7f, 8f, 9f }, new float[] { 10f, 11f, 12f }), 0L, 0L, 0)); // Tests default options diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java index 1e2bf625fc5..950d24d3b42 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java @@ -21,8 +21,12 @@ /** * @author Thomas Vitale */ -public class OllamaImage { +public final class OllamaImage { public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ollama/ollama:0.3.14"); + private OllamaImage() { + + } + } diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/tool/MockWeatherService.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/tool/MockWeatherService.java index c732a8e5ed1..8b2a7a74dd6 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/tool/MockWeatherService.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/tool/MockWeatherService.java @@ -65,7 +65,7 @@ public enum Unit { */ public final String unitName; - private Unit(String text) { + Unit(String text) { this.unitName = text; } diff --git a/models/spring-ai-postgresml/pom.xml b/models/spring-ai-postgresml/pom.xml index 0312ebd4f4c..d205720f1fa 100644 --- a/models/spring-ai-postgresml/pom.xml +++ b/models/spring-ai-postgresml/pom.xml @@ -36,6 +36,10 @@ git@github.com:spring-projects/spring-ai.git + + false + + org.springframework.ai diff --git a/models/spring-ai-qianfan/pom.xml b/models/spring-ai-qianfan/pom.xml index 379d29eb26c..750fe69fc56 100644 --- a/models/spring-ai-qianfan/pom.xml +++ b/models/spring-ai-qianfan/pom.xml @@ -36,6 +36,10 @@ git@github.com:spring-projects/spring-ai.git + + false + + diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatModel.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatModel.java index aaf68884c92..944a1c4e1f0 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatModel.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatModel.java @@ -223,9 +223,7 @@ public Flux stream(Prompt prompt) { return new ChatResponse(generations, from(chatCompletion, request.model())); })) .doOnError(observation::error) - .doFinally(s -> { - observation.stop(); - }) + .doFinally(s -> observation.stop()) .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); return new MessageAggregator().aggregate(chatResponse, observationContext::setResponse); diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanImageOptions.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanImageOptions.java index d102d34feb3..9d49ab908ec 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanImageOptions.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanImageOptions.java @@ -31,7 +31,7 @@ * @since 1.0 */ @JsonInclude(JsonInclude.Include.NON_NULL) -public class QianFanImageOptions implements ImageOptions { +public final class QianFanImageOptions implements ImageOptions { /** * The number of images to generate. Must be between 1 and 4. @@ -188,7 +188,7 @@ public String toString() { + ", user='" + this.user + '\'' + '}'; } - public static class Builder { + public static final class Builder { private final QianFanImageOptions options; diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/aot/QianFanRuntimeHints.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/aot/QianFanRuntimeHints.java index 2538e4f8b20..1de207fbf48 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/aot/QianFanRuntimeHints.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/aot/QianFanRuntimeHints.java @@ -37,10 +37,12 @@ public class QianFanRuntimeHints implements RuntimeHintsRegistrar { @Override public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) { var mcs = MemberCategory.values(); - for (var tr : findJsonAnnotatedClassesInPackage(QianFanApi.class)) + for (var tr : findJsonAnnotatedClassesInPackage(QianFanApi.class)) { hints.reflection().registerType(tr, mcs); - for (var tr : findJsonAnnotatedClassesInPackage(QianFanImageApi.class)) + } + for (var tr : findJsonAnnotatedClassesInPackage(QianFanImageApi.class)) { hints.reflection().registerType(tr, mcs); + } } } diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanApi.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanApi.java index da93b16b67e..4139abd0f45 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanApi.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanApi.java @@ -112,7 +112,7 @@ public QianFanApi(String baseUrl, String apiKey, String secretKey, RestClient.Bu * @param responseErrorHandler Response error handler. */ public QianFanApi(String baseUrl, String apiKey, String secretKey, RestClient.Builder restClientBuilder, - WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) { + WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) { super(apiKey, secretKey); this.restClient = restClientBuilder @@ -139,7 +139,7 @@ public ResponseEntity chatCompletionEntity(ChatCompletionRequest Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); return this.restClient.post() - .uri("/v1/wenxinworkshop/chat/{model}?access_token={token}",chatRequest.model, getAccessToken()) + .uri("/v1/wenxinworkshop/chat/{model}?access_token={token}", chatRequest.model, getAccessToken()) .body(chatRequest) .retrieve() .toEntity(ChatCompletion.class); @@ -156,7 +156,7 @@ public Flux chatCompletionStream(ChatCompletionRequest chat Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true."); return this.webClient.post() - .uri("/v1/wenxinworkshop/chat/{model}?access_token={token}",chatRequest.model, getAccessToken()) + .uri("/v1/wenxinworkshop/chat/{model}?access_token={token}", chatRequest.model, getAccessToken()) .body(Mono.just(chatRequest), ChatCompletionRequest.class) .retrieve() .bodyToFlux(ChatCompletionChunk.class) @@ -287,7 +287,7 @@ public String getValue() { * probability mass are considered. We generally recommend altering this or temperature but not both. */ @JsonInclude(Include.NON_NULL) - public record ChatCompletionRequest ( + public record ChatCompletionRequest( @JsonProperty("messages") List messages, @JsonProperty("system") String system, @JsonProperty("model") String model, @@ -308,7 +308,7 @@ public record ChatCompletionRequest ( * @param temperature What sampling temperature to use, between 0 and 1. */ public ChatCompletionRequest(List messages, String system, String model, Double temperature) { - this(messages, system, model, null,null, + this(messages, system, model, null, null, null, null, null, false, temperature, null); } @@ -322,7 +322,7 @@ public ChatCompletionRequest(List messages, String system * as they become available, with the stream terminated by a data: [DONE] message. */ public ChatCompletionRequest(List messages, String system, String model, Double temperature, boolean stream) { - this(messages, system, model, null,null, + this(messages, system, model, null, null, null, null, null, stream, temperature, null); } @@ -336,7 +336,7 @@ public ChatCompletionRequest(List messages, String system * as they become available, with the stream terminated by a data: [DONE] message. */ public ChatCompletionRequest(List messages, String system, Boolean stream) { - this(messages, system, DEFAULT_CHAT_MODEL, null,null, + this(messages, system, DEFAULT_CHAT_MODEL, null, null, null, null, null, stream, 0.8, null); } @@ -382,15 +382,18 @@ public enum Role { /** * System message. */ - @JsonProperty("system") SYSTEM, + @JsonProperty("system") + SYSTEM, /** * User message. */ - @JsonProperty("user") USER, + @JsonProperty("user") + USER, /** * Assistant message. */ - @JsonProperty("assistant") ASSISTANT + @JsonProperty("assistant") + ASSISTANT } } @@ -483,7 +486,7 @@ public EmbeddingRequest(String text) { * @param userId A unique identifier representing your end-user, which can help QianFan to * monitor and detect abuse. */ - public EmbeddingRequest(String text,String model,String userId) { + public EmbeddingRequest(String text, String model, String userId) { this(List.of(text), model, userId); } diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanConstants.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanConstants.java index 5dd2744f768..e5e19d656c2 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanConstants.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanConstants.java @@ -25,10 +25,14 @@ * @author Geng Rong * @since 1.0 */ -public class QianFanConstants { +public final class QianFanConstants { public static final String DEFAULT_BASE_URL = "https://aip.baidubce.com/rpc/2.0/ai_custom"; public static final String PROVIDER_NAME = AiProvider.QIANFAN.value(); + private QianFanConstants() { + + } + } diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanImageApi.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanImageApi.java index 2532e52df0e..c094b96b790 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanImageApi.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanImageApi.java @@ -114,7 +114,7 @@ public String getValue() { // @formatter:off @JsonInclude(JsonInclude.Include.NON_NULL) - public record QianFanImageRequest ( + public record QianFanImageRequest( @JsonProperty("model") String model, @JsonProperty("prompt") String prompt, @JsonProperty("negative_prompt") String negativePrompt, diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanUtils.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanUtils.java index fb1e9723b00..33c91dd5bf1 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanUtils.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanUtils.java @@ -21,10 +21,14 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; -public class QianFanUtils { +public final class QianFanUtils { public static Consumer defaultHeaders() { return headers -> headers.setContentType(MediaType.APPLICATION_JSON); } + private QianFanUtils() { + + } + } diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/QianFanAccessToken.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/QianFanAccessToken.java index ec29676eb69..bc5193e94b7 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/QianFanAccessToken.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/auth/QianFanAccessToken.java @@ -86,4 +86,4 @@ private long getCurrentTimeInSeconds() { return System.currentTimeMillis() / 1000L; } -} \ No newline at end of file +} diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanApiIT.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanApiIT.java index f8dae1f2092..b4df7b8fa31 100644 --- a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanApiIT.java +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanApiIT.java @@ -39,7 +39,7 @@ /** * @author Geng Rong */ -@EnabledIfEnvironmentVariables(value = { @EnabledIfEnvironmentVariable(named = "QIANFAN_API_KEY", matches = ".+"), +@EnabledIfEnvironmentVariables({ @EnabledIfEnvironmentVariable(named = "QIANFAN_API_KEY", matches = ".+"), @EnabledIfEnvironmentVariable(named = "QIANFAN_SECRET_KEY", matches = ".+") }) public class QianFanApiIT { diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanRetryTests.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanRetryTests.java index 978eb7216cc..36d15fd6edf 100644 --- a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanRetryTests.java +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/api/QianFanRetryTests.java @@ -58,7 +58,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.Mockito.when; +import static org.mockito.BDDMockito.given; /** * @author Geng Rong @@ -96,10 +96,10 @@ public void qianFanChatTransientError() { ChatCompletion expectedChatCompletion = new ChatCompletion("id", "chat.completion", 666L, "Response", "STOP", new Usage(10, 10, 10)); - when(this.qianFanApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) - .thenThrow(new TransientAiException("Transient Error 1")) - .thenThrow(new TransientAiException("Transient Error 2")) - .thenReturn(ResponseEntity.of(Optional.of(expectedChatCompletion))); + given(this.qianFanApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + .willThrow(new TransientAiException("Transient Error 1")) + .willThrow(new TransientAiException("Transient Error 2")) + .willReturn(ResponseEntity.of(Optional.of(expectedChatCompletion))); var result = this.chatClient.call(new Prompt("text")); @@ -111,8 +111,8 @@ public void qianFanChatTransientError() { @Test public void qianFanChatNonTransientError() { - when(this.qianFanApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) - .thenThrow(new RuntimeException("Non Transient Error")); + given(this.qianFanApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + .willThrow(new RuntimeException("Non Transient Error")); assertThrows(RuntimeException.class, () -> this.chatClient.call(new Prompt("text"))); } @@ -122,10 +122,10 @@ public void qianFanChatStreamTransientError() { ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", "chat.completion", 666L, "Response", "", true, null); - when(this.qianFanApi.chatCompletionStream(isA(ChatCompletionRequest.class))) - .thenThrow(new TransientAiException("Transient Error 1")) - .thenThrow(new TransientAiException("Transient Error 2")) - .thenReturn(Flux.just(expectedChatCompletion)); + given(this.qianFanApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + .willThrow(new TransientAiException("Transient Error 1")) + .willThrow(new TransientAiException("Transient Error 2")) + .willReturn(Flux.just(expectedChatCompletion)); var result = this.chatClient.stream(new Prompt("text")); @@ -138,8 +138,8 @@ public void qianFanChatStreamTransientError() { @Test public void qianFanChatStreamNonTransientError() { - when(this.qianFanApi.chatCompletionStream(isA(ChatCompletionRequest.class))) - .thenThrow(new RuntimeException("Non Transient Error")); + given(this.qianFanApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + .willThrow(new RuntimeException("Non Transient Error")); assertThrows(RuntimeException.class, () -> this.chatClient.stream(new Prompt("text")).collectList().block()); } @@ -149,10 +149,10 @@ public void qianFanEmbeddingTransientError() { EmbeddingList expectedEmbeddings = new EmbeddingList("embedding_list", List.of(embedding), "model", null, null, new Usage(10, 10, 10)); - when(this.qianFanApi.embeddings(isA(EmbeddingRequest.class))) - .thenThrow(new TransientAiException("Transient Error 1")) - .thenThrow(new TransientAiException("Transient Error 2")) - .thenReturn(ResponseEntity.of(Optional.of(expectedEmbeddings))); + given(this.qianFanApi.embeddings(isA(EmbeddingRequest.class))) + .willThrow(new TransientAiException("Transient Error 1")) + .willThrow(new TransientAiException("Transient Error 2")) + .willReturn(ResponseEntity.of(Optional.of(expectedEmbeddings))); var result = this.embeddingClient .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null)); @@ -165,8 +165,8 @@ public void qianFanEmbeddingTransientError() { @Test public void qianFanEmbeddingNonTransientError() { - when(this.qianFanApi.embeddings(isA(EmbeddingRequest.class))) - .thenThrow(new RuntimeException("Non Transient Error")); + given(this.qianFanApi.embeddings(isA(EmbeddingRequest.class))) + .willThrow(new RuntimeException("Non Transient Error")); assertThrows(RuntimeException.class, () -> this.embeddingClient .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null))); } @@ -176,10 +176,10 @@ public void qianFanImageTransientError() { var expectedResponse = new QianFanImageResponse("1", 678L, List.of(new Data(1, "b64"))); - when(this.qianFanImageApi.createImage(isA(QianFanImageRequest.class))) - .thenThrow(new TransientAiException("Transient Error 1")) - .thenThrow(new TransientAiException("Transient Error 2")) - .thenReturn(ResponseEntity.of(Optional.of(expectedResponse))); + given(this.qianFanImageApi.createImage(isA(QianFanImageRequest.class))) + .willThrow(new TransientAiException("Transient Error 1")) + .willThrow(new TransientAiException("Transient Error 2")) + .willReturn(ResponseEntity.of(Optional.of(expectedResponse))); var result = this.imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message")))); @@ -191,8 +191,8 @@ public void qianFanImageTransientError() { @Test public void qianFanImageNonTransientError() { - when(this.qianFanImageApi.createImage(isA(QianFanImageRequest.class))) - .thenThrow(new RuntimeException("Transient Error 1")); + given(this.qianFanImageApi.createImage(isA(QianFanImageRequest.class))) + .willThrow(new RuntimeException("Transient Error 1")); assertThrows(RuntimeException.class, () -> this.imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message"))))); } diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/chat/QianFanChatModelIT.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/chat/QianFanChatModelIT.java index 46c4c67e5da..f8bb6ef6995 100644 --- a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/chat/QianFanChatModelIT.java +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/chat/QianFanChatModelIT.java @@ -46,7 +46,7 @@ * @author Geng Rong */ @SpringBootTest(classes = QianFanTestConfiguration.class) -@EnabledIfEnvironmentVariables(value = { @EnabledIfEnvironmentVariable(named = "QIANFAN_API_KEY", matches = ".+"), +@EnabledIfEnvironmentVariables({ @EnabledIfEnvironmentVariable(named = "QIANFAN_API_KEY", matches = ".+"), @EnabledIfEnvironmentVariable(named = "QIANFAN_SECRET_KEY", matches = ".+") }) class QianFanChatModelIT { diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/chat/QianFanChatModelObservationIT.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/chat/QianFanChatModelObservationIT.java index 4b447ffa36a..9d0ebc6d98d 100644 --- a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/chat/QianFanChatModelObservationIT.java +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/chat/QianFanChatModelObservationIT.java @@ -52,7 +52,7 @@ * @author Geng Rong */ @SpringBootTest(classes = QianFanChatModelObservationIT.Config.class) -@EnabledIfEnvironmentVariables(value = { @EnabledIfEnvironmentVariable(named = "QIANFAN_API_KEY", matches = ".+"), +@EnabledIfEnvironmentVariables({ @EnabledIfEnvironmentVariable(named = "QIANFAN_API_KEY", matches = ".+"), @EnabledIfEnvironmentVariable(named = "QIANFAN_SECRET_KEY", matches = ".+") }) public class QianFanChatModelObservationIT { diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/embedding/EmbeddingIT.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/embedding/EmbeddingIT.java index 371d5c2a833..4fc1f825619 100644 --- a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/embedding/EmbeddingIT.java +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/embedding/EmbeddingIT.java @@ -35,7 +35,7 @@ * @author Geng Rong */ @SpringBootTest(classes = QianFanTestConfiguration.class) -@EnabledIfEnvironmentVariables(value = { @EnabledIfEnvironmentVariable(named = "QIANFAN_API_KEY", matches = ".+"), +@EnabledIfEnvironmentVariables({ @EnabledIfEnvironmentVariable(named = "QIANFAN_API_KEY", matches = ".+"), @EnabledIfEnvironmentVariable(named = "QIANFAN_SECRET_KEY", matches = ".+") }) class EmbeddingIT { diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/embedding/QianFanEmbeddingModelObservationIT.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/embedding/QianFanEmbeddingModelObservationIT.java index 5061626a751..83bc0748fe8 100644 --- a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/embedding/QianFanEmbeddingModelObservationIT.java +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/embedding/QianFanEmbeddingModelObservationIT.java @@ -50,7 +50,7 @@ * @author Geng Rong */ @SpringBootTest(classes = QianFanEmbeddingModelObservationIT.Config.class) -@EnabledIfEnvironmentVariables(value = { @EnabledIfEnvironmentVariable(named = "QIANFAN_API_KEY", matches = ".+"), +@EnabledIfEnvironmentVariables({ @EnabledIfEnvironmentVariable(named = "QIANFAN_API_KEY", matches = ".+"), @EnabledIfEnvironmentVariable(named = "QIANFAN_SECRET_KEY", matches = ".+") }) public class QianFanEmbeddingModelObservationIT { diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/image/QianFanImageModelIT.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/image/QianFanImageModelIT.java index 6e4be44c737..640bad07a09 100644 --- a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/image/QianFanImageModelIT.java +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/image/QianFanImageModelIT.java @@ -36,7 +36,7 @@ * @author Geng Rong */ @SpringBootTest(classes = QianFanTestConfiguration.class) -@EnabledIfEnvironmentVariables(value = { @EnabledIfEnvironmentVariable(named = "QIANFAN_API_KEY", matches = ".+"), +@EnabledIfEnvironmentVariables({ @EnabledIfEnvironmentVariable(named = "QIANFAN_API_KEY", matches = ".+"), @EnabledIfEnvironmentVariable(named = "QIANFAN_SECRET_KEY", matches = ".+") }) public class QianFanImageModelIT { diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/image/QianFanImageModelObservationIT.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/image/QianFanImageModelObservationIT.java index 3ddaf41e586..885a51b9e69 100644 --- a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/image/QianFanImageModelObservationIT.java +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/image/QianFanImageModelObservationIT.java @@ -46,7 +46,7 @@ * @author Geng Rong */ @SpringBootTest(classes = QianFanImageModelObservationIT.Config.class) -@EnabledIfEnvironmentVariables(value = { @EnabledIfEnvironmentVariable(named = "QIANFAN_API_KEY", matches = ".+"), +@EnabledIfEnvironmentVariables({ @EnabledIfEnvironmentVariable(named = "QIANFAN_API_KEY", matches = ".+"), @EnabledIfEnvironmentVariable(named = "QIANFAN_SECRET_KEY", matches = ".+") }) public class QianFanImageModelObservationIT { diff --git a/models/spring-ai-stability-ai/pom.xml b/models/spring-ai-stability-ai/pom.xml index d60bc443303..b1a14b4c93f 100644 --- a/models/spring-ai-stability-ai/pom.xml +++ b/models/spring-ai-stability-ai/pom.xml @@ -36,6 +36,10 @@ git@github.com:spring-projects/spring-ai.git + + false + + diff --git a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StyleEnum.java b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StyleEnum.java index f3d76b3faf9..8e548f3fe5a 100644 --- a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StyleEnum.java +++ b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StyleEnum.java @@ -52,4 +52,4 @@ public String toString() { return this.text; } -} \ No newline at end of file +} diff --git a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java index 645e13f1ab3..46d4f720f8e 100644 --- a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java +++ b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiImageOptions.java @@ -426,7 +426,7 @@ public String toString() { + this.stylePreset + '\'' + '}'; } - public static class Builder { + public static final class Builder { private final StabilityAiImageOptions options; diff --git a/models/spring-ai-transformers/pom.xml b/models/spring-ai-transformers/pom.xml index f32266818f4..3351485ed25 100644 --- a/models/spring-ai-transformers/pom.xml +++ b/models/spring-ai-transformers/pom.xml @@ -36,6 +36,10 @@ git@github.com:spring-projects/spring-ai.git + + false + + diff --git a/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/ResourceCacheService.java b/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/ResourceCacheService.java index a074571827f..6ad3aa290ff 100644 --- a/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/ResourceCacheService.java +++ b/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/ResourceCacheService.java @@ -146,4 +146,4 @@ public void deleteCacheFolder() { } } -} \ No newline at end of file +} diff --git a/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/samples/ONNXSample.java b/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/samples/ONNXSample.java index 8119bbca5bf..90380481749 100644 --- a/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/samples/ONNXSample.java +++ b/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/samples/ONNXSample.java @@ -35,7 +35,11 @@ // https://www.sbert.net/examples/applications/computing-embeddings/README.html#sentence-embeddings-with-transformers -public class ONNXSample { +public final class ONNXSample { + + private ONNXSample() { + + } public static NDArray meanPooling(NDArray tokenEmbeddings, NDArray attentionMask) { diff --git a/models/spring-ai-vertex-ai-embedding/pom.xml b/models/spring-ai-vertex-ai-embedding/pom.xml index 0ce34354e8f..dc61e8dfb7e 100644 --- a/models/spring-ai-vertex-ai-embedding/pom.xml +++ b/models/spring-ai-vertex-ai-embedding/pom.xml @@ -36,6 +36,10 @@ git@github.com:spring-projects/spring-ai.git + + false + + diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelName.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelName.java index 750d9816a0b..b5d9ae0b5be 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelName.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingModelName.java @@ -70,4 +70,4 @@ public String getDescription() { return this.description; } -} \ No newline at end of file +} diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingOptions.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingOptions.java index 89762581c52..98bc4a73548 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingOptions.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/multimodal/VertexAiMultimodalEmbeddingOptions.java @@ -85,13 +85,12 @@ public class VertexAiMultimodalEmbeddingOptions implements EmbeddingOptions { */ private @JsonProperty("videoStartOffsetSec") Integer videoStartOffsetSec; - /** * The end offset of the video segment in seconds. If not specified, it's calculated with min(video length, startOffSec + 120). * If both startOffSec and endOffSec are specified, endOffsetSec is adjusted to min(startOffsetSec+120, endOffsetSec). */ private @JsonProperty("videoEndOffsetSec") Integer videoEndOffsetSec; - + /** * The interval of the video the embedding will be generated. The minimum value for interval_sec is 4. * If the interval is less than 4, an InvalidArgumentError is returned. There are no limitations on the maximum value diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelName.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelName.java index 327d7950c27..357b84e4ab3 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelName.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelName.java @@ -75,4 +75,4 @@ public String getDescription() { return this.description; } -} \ No newline at end of file +} diff --git a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/TestVertexAiTextEmbeddingModel.java b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/TestVertexAiTextEmbeddingModel.java index baaeedd8acb..e8627f3d625 100644 --- a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/TestVertexAiTextEmbeddingModel.java +++ b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/TestVertexAiTextEmbeddingModel.java @@ -69,4 +69,4 @@ protected PredictRequest.Builder getPredictRequestBuilder(EmbeddingRequest reque return super.getPredictRequestBuilder(request, endpointName, finalOptions); } -} \ No newline at end of file +} diff --git a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingRetryTests.java b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingRetryTests.java index 5757fe5a4fe..9d2a2bd07b5 100644 --- a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingRetryTests.java +++ b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingRetryTests.java @@ -41,11 +41,11 @@ import org.springframework.retry.support.RetryTemplate; import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.Assert.assertThrows; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; /** * @author Mark Pollack @@ -81,7 +81,7 @@ public void setUp() { VertexAiTextEmbeddingOptions.builder().build(), this.retryTemplate); this.embeddingModel.setMockPredictionServiceClient(this.mockPredictionServiceClient); this.embeddingModel.setMockPredictRequestBuilder(this.mockPredictRequestBuilder); - when(this.mockPredictRequestBuilder.build()).thenReturn(PredictRequest.getDefaultInstance()); + given(this.mockPredictRequestBuilder.build()).willReturn(PredictRequest.getDefaultInstance()); } @Test @@ -112,9 +112,9 @@ public void vertexAiEmbeddingTransientError() { .build(); // Setup the mock PredictionServiceClient - when(this.mockPredictionServiceClient.predict(any())).thenThrow(new TransientAiException("Transient Error 1")) - .thenThrow(new TransientAiException("Transient Error 2")) - .thenReturn(mockResponse); + given(this.mockPredictionServiceClient.predict(any())).willThrow(new TransientAiException("Transient Error 1")) + .willThrow(new TransientAiException("Transient Error 2")) + .willReturn(mockResponse); EmbeddingResponse result = this.embeddingModel.call(new EmbeddingRequest(List.of("text1", "text2"), null)); @@ -130,11 +130,11 @@ public void vertexAiEmbeddingTransientError() { @Test public void vertexAiEmbeddingNonTransientError() { // Setup the mock PredictionServiceClient to throw a non-transient error - when(this.mockPredictionServiceClient.predict(any())).thenThrow(new RuntimeException("Non Transient Error")); + given(this.mockPredictionServiceClient.predict(any())).willThrow(new RuntimeException("Non Transient Error")); // Assert that a RuntimeException is thrown and not retried - assertThrows(RuntimeException.class, - () -> this.embeddingModel.call(new EmbeddingRequest(List.of("text1", "text2"), null))); + assertThatThrownBy(() -> this.embeddingModel.call(new EmbeddingRequest(List.of("text1", "text2"), null))) + .isInstanceOf(RuntimeException.class); // Verify that predict was called only once (no retries for non-transient errors) verify(this.mockPredictionServiceClient, times(1)).predict(any()); diff --git a/models/spring-ai-vertex-ai-gemini/pom.xml b/models/spring-ai-vertex-ai-gemini/pom.xml index 230c5cd6796..0f503e4e781 100644 --- a/models/spring-ai-vertex-ai-gemini/pom.xml +++ b/models/spring-ai-vertex-ai-gemini/pom.xml @@ -36,6 +36,10 @@ git@github.com:spring-projects/spring-ai.git + + false + + diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java index 67d6bed36af..956225dbd93 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java @@ -366,9 +366,7 @@ public Flux stream(Prompt prompt) { Flux chatResponseFlux = Flux.just(chatResponse) .doOnError(observation::error) - .doFinally(s -> { - observation.stop(); - }) + .doFinally(s -> observation.stop()) .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); return new MessageAggregator().aggregate(chatResponseFlux, observationContext::setResponse); diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java index 574574a75eb..aec4649f303 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java @@ -232,8 +232,8 @@ public String getResponseMimeType() { return this.responseMimeType; } - public String setResponseMimeType(String mimeType) { - return this.responseMimeType = mimeType; + public void setResponseMimeType(String mimeType) { + this.responseMimeType = mimeType; } public List getFunctionCallbacks() { diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/aot/VertexAiGeminiRuntimeHints.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/aot/VertexAiGeminiRuntimeHints.java index fd3d04106c4..03209da5118 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/aot/VertexAiGeminiRuntimeHints.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/aot/VertexAiGeminiRuntimeHints.java @@ -35,8 +35,9 @@ public class VertexAiGeminiRuntimeHints implements RuntimeHintsRegistrar { @Override public void registerHints(RuntimeHints hints, ClassLoader classLoader) { var mcs = MemberCategory.values(); - for (var tr : findJsonAnnotatedClassesInPackage(VertexAiGeminiChatModel.class)) + for (var tr : findJsonAnnotatedClassesInPackage(VertexAiGeminiChatModel.class)) { hints.reflection().registerType(tr, mcs); + } } } diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/common/VertexAiGeminiConstants.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/common/VertexAiGeminiConstants.java index 2d8b69f9861..b341fac134a 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/common/VertexAiGeminiConstants.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/common/VertexAiGeminiConstants.java @@ -21,8 +21,12 @@ /** * @author Soby Chacko */ -public class VertexAiGeminiConstants { +public final class VertexAiGeminiConstants { public static final String PROVIDER_NAME = AiProvider.VERTEX_AI.value(); + private VertexAiGeminiConstants() { + + } + } diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java index 3925fe12506..d1244b97dff 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java @@ -120,7 +120,7 @@ public void promptOptionsTools() { .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName(TOOL_FUNCTION_NAME) .withDescription("Get the weather in location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build())) .build()), null); @@ -148,7 +148,7 @@ public void defaultOptionsTools() { .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName(TOOL_FUNCTION_NAME) .withDescription("Get the weather in location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build())) .build()); diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/TestVertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/TestVertexAiGeminiChatModel.java index 9ab82aa64b7..d304d44105b 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/TestVertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/TestVertexAiGeminiChatModel.java @@ -62,4 +62,4 @@ public void setMockGenerativeModel(GenerativeModel mockGenerativeModel) { this.mockGenerativeModel = mockGenerativeModel; } -} \ No newline at end of file +} diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiRetryTests.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiRetryTests.java index bddb9328a28..288be75e3cf 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiRetryTests.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiRetryTests.java @@ -43,8 +43,8 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.any; -import static org.mockito.Mockito.when; /** * @author Mark Pollack @@ -91,10 +91,10 @@ public void vertexAiGeminiChatTransientError() throws IOException { .build()) .build(); - when(this.mockGenerativeModel.generateContent(any(List.class))) - .thenThrow(new TransientAiException("Transient Error 1")) - .thenThrow(new TransientAiException("Transient Error 2")) - .thenReturn(mockedResponse); + given(this.mockGenerativeModel.generateContent(any(List.class))) + .willThrow(new TransientAiException("Transient Error 1")) + .willThrow(new TransientAiException("Transient Error 2")) + .willReturn(mockedResponse); // Call the chat model ChatResponse result = this.chatModel.call(new Prompt("test prompt")); @@ -109,8 +109,8 @@ public void vertexAiGeminiChatTransientError() throws IOException { @Test public void vertexAiGeminiChatNonTransientError() throws Exception { // Set up the mock GenerativeModel to throw a non-transient RuntimeException - when(this.mockGenerativeModel.generateContent(any(List.class))) - .thenThrow(new RuntimeException("Non Transient Error")); + given(this.mockGenerativeModel.generateContent(any(List.class))) + .willThrow(new RuntimeException("Non Transient Error")); // Assert that a RuntimeException is thrown when calling the chat model assertThrows(RuntimeException.class, () -> this.chatModel.call(new Prompt("test prompt"))); diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/MockWeatherService.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/MockWeatherService.java index a7f7521df5a..d79b317cda1 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/MockWeatherService.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/MockWeatherService.java @@ -47,7 +47,7 @@ else if (request.location().contains("San Francisco")) { temperature = 30; } - this.logger.info("Request is {}, response temperature is {}", request, temperature); + logger.info("Request is {}, response temperature is {}", request, temperature); return new Response(temperature, Unit.C); } @@ -70,7 +70,7 @@ public enum Unit { */ public final String unitName; - private Unit(String text) { + Unit(String text) { this.unitName = text; } diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiChatModelFunctionCallingIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiChatModelFunctionCallingIT.java index d52f184457a..06c0bbd3809 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiChatModelFunctionCallingIT.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiChatModelFunctionCallingIT.java @@ -68,18 +68,18 @@ public void functionCallExplicitOpenApiSchema() { { "type": "OBJECT", "properties": { - "location": { + "location": { "type": "STRING", "description": "The city and state e.g. San Francisco, CA" - }, - "unit" : { + }, + "unit" : { "type" : "STRING", "enum" : [ "C", "F" ], "description" : "Temperature unit" - } + } }, "required": ["location", "unit"] - } + } """; var promptOptions = VertexAiGeminiChatOptions.builder() @@ -93,7 +93,7 @@ public void functionCallExplicitOpenApiSchema() { ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); } @@ -123,14 +123,14 @@ public void functionCallTestInferredOpenApiSchema() { ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("15.0", "15"); ChatResponse response2 = this.chatModel .call(new Prompt("What is the payment status for transaction 696?", promptOptions)); - this.logger.info("Response: {}", response2); + logger.info("Response: {}", response2); assertThat(response2.getResult().getOutput().getContent()).containsIgnoringCase("transaction 696 is PAYED"); @@ -162,14 +162,14 @@ public void functionCallTestInferredOpenApiSchema2() { ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); ChatResponse response2 = this.chatModel .call(new Prompt("What is the payment status for transaction 696?", promptOptions)); - this.logger.info("Response: {}", response2); + logger.info("Response: {}", response2); assertThat(response2.getResult().getOutput().getContent()).containsIgnoringCase("transaction 696 is PAYED"); @@ -203,7 +203,7 @@ public void functionCallTestInferredOpenApiSchemaStream() { .map(AssistantMessage::getContent) .collect(Collectors.joining()); - this.logger.info("Response: {}", responseString); + logger.info("Response: {}", responseString); assertThat(responseString).contains("30", "10", "15"); diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java index 4d674c8bd5b..12a8bb30dc7 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java @@ -105,7 +105,8 @@ public void streamingPaymentStatuses() { // Quota rate try { Thread.sleep(1000); - } catch (InterruptedException e) { + } + catch (InterruptedException e) { } } @@ -135,19 +136,19 @@ public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvis } private AdvisedRequest before(AdvisedRequest request) { - this.logger.info("System text: \n" + request.systemText()); - this.logger.info("System params: " + request.systemParams()); - this.logger.info("User text: \n" + request.userText()); - this.logger.info("User params:" + request.userParams()); - this.logger.info("Function names: " + request.functionNames()); + logger.info("System text: \n" + request.systemText()); + logger.info("System params: " + request.systemParams()); + logger.info("User text: \n" + request.userText()); + logger.info("User params:" + request.userParams()); + logger.info("Function names: " + request.functionNames()); - this.logger.info("Options: " + request.chatOptions().toString()); + logger.info("Options: " + request.chatOptions().toString()); return request; } private void observeAfter(AdvisedResponse advisedResponse) { - this.logger.info("Response: " + advisedResponse.response()); + logger.info("Response: " + advisedResponse.response()); } } diff --git a/models/spring-ai-vertex-ai-palm2/pom.xml b/models/spring-ai-vertex-ai-palm2/pom.xml index 07455d51dd4..2bf40081458 100644 --- a/models/spring-ai-vertex-ai-palm2/pom.xml +++ b/models/spring-ai-vertex-ai-palm2/pom.xml @@ -36,6 +36,10 @@ git@github.com:spring-projects/spring-ai.git + + true + + diff --git a/models/spring-ai-watsonx-ai/pom.xml b/models/spring-ai-watsonx-ai/pom.xml index 3fc195eb470..ca14a148d5c 100644 --- a/models/spring-ai-watsonx-ai/pom.xml +++ b/models/spring-ai-watsonx-ai/pom.xml @@ -15,39 +15,41 @@ ~ limitations under the License. --> - - 4.0.0 - - org.springframework.ai - spring-ai - 1.0.0-SNAPSHOT - ../../pom.xml - + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-SNAPSHOT + ../../pom.xml + - spring-ai-watsonx-ai + spring-ai-watsonx-ai - Spring AI Model - WatsonX AI + Spring AI Model - WatsonX AI - - 17 - 17 - UTF-8 - + + 17 + 17 + UTF-8 + false + - + - - - org.springframework.boot - spring-boot - + + + org.springframework.boot + spring-boot + - - org.springframework.ai - spring-ai-core - ${project.parent.version} - + + org.springframework.ai + spring-ai-core + ${project.parent.version} + org.springframework.ai @@ -55,38 +57,38 @@ ${project.parent.version} - - org.springframework.boot - spring-boot-starter-logging - + + org.springframework.boot + spring-boot-starter-logging + - - com.ibm.cloud - sdk-core - ${ibm.sdk.version} - + + com.ibm.cloud + sdk-core + ${ibm.sdk.version} + - - - org.springframework.boot - spring-boot-starter-test - test - + + + org.springframework.boot + spring-boot-starter-test + test + - - io.projectreactor - reactor-test - test - 3.6.2 - + + io.projectreactor + reactor-test + test + 3.6.2 + - - org.springframework.ai - spring-ai-test - ${project.version} - test - + + org.springframework.ai + spring-ai-test + ${project.version} + test + - + diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java index 9a113da5795..5273f6de317 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java @@ -46,153 +46,163 @@ public class WatsonxAiChatOptions implements ChatOptions { - @JsonIgnore - private final ObjectMapper mapper = new ObjectMapper(); - - /** - * The temperature of the model. Increasing the temperature will - * make the model answer more creatively. (Default: 0.7) - */ - @JsonProperty("temperature") private Double temperature; - - /** - * Works together with top-k. A higher value (e.g., 0.95) will lead to - * more diverse text, while a lower value (e.g., 0.2) will generate more focused and - * conservative text. (Default: 1.0) - */ - @JsonProperty("top_p") private Double topP; - - /** - * Reduces the probability of generating nonsense. A higher value (e.g. - * 100) will give more diverse answers, while a lower value (e.g. 10) will be more - * conservative. (Default: 50) - */ - @JsonProperty("top_k") private Integer topK; - - /** - * Decoding is the process that a model uses to choose the tokens in the generated output. - * Choose one of the following decoding options: - * - * Greedy: Selects the token with the highest probability at each step of the decoding process. - * Greedy decoding produces output that closely matches the most common language in the model's pretraining - * data and in your prompt text, which is desirable in less creative or fact-based use cases. A weakness of - * greedy decoding is that it can cause repetitive loops in the generated output. - * - * Sampling decoding: Offers more variability in how tokens are selected. - * With sampling decoding, the model samples tokens, meaning the model chooses a subset of tokens, - * and then one token is chosen randomly from this subset to be added to the output text. Sampling adds - * variability and randomness to the decoding process, which can be desirable in creative use cases. - * However, with greater variability comes a greater risk of incorrect or nonsensical output. - * (Default: greedy) - */ - @JsonProperty("decoding_method") private String decodingMethod; - - /** - * Sets the limit of tokens that the LLM follow. (Default: 20) - */ - @JsonProperty("max_new_tokens") private Integer maxNewTokens; - - /** - * Sets how many tokens must the LLM generate. (Default: 0) - */ - @JsonProperty("min_new_tokens") private Integer minNewTokens; - - /** - * Sets when the LLM should stop. - * (e.g., ["\n\n\n"]) then when the LLM generates three consecutive line breaks it will terminate. - * Stop sequences are ignored until after the number of tokens that are specified in the Min tokens parameter are generated. - */ - @JsonProperty("stop_sequences") private List stopSequences; - - /** - * Sets how strongly to penalize repetitions. A higher value - * (e.g., 1.8) will penalize repetitions more strongly, while a lower value (e.g., - * 1.1) will be more lenient. (Default: 1.0) - */ - @JsonProperty("repetition_penalty") private Double repetitionPenalty; - - /** - * Produce repeatable results, set the same random seed value every time. (Default: randomly generated) - */ - @JsonProperty("random_seed") private Integer randomSeed; - - /** - * Model is the identifier of the LLM Model to be used - */ - @JsonProperty("model") private String model; - - /** - * Set additional request params (some model have non-predefined options) - */ - @JsonProperty("additional") - private Map additional = new HashMap<>(); + @JsonIgnore + private final ObjectMapper mapper = new ObjectMapper(); + + /** + * The temperature of the model. Increasing the temperature will + * make the model answer more creatively. (Default: 0.7) + */ + @JsonProperty("temperature") + private Double temperature; + + /** + * Works together with top-k. A higher value (e.g., 0.95) will lead to + * more diverse text, while a lower value (e.g., 0.2) will generate more focused and + * conservative text. (Default: 1.0) + */ + @JsonProperty("top_p") + private Double topP; + + /** + * Reduces the probability of generating nonsense. A higher value (e.g. + * 100) will give more diverse answers, while a lower value (e.g. 10) will be more + * conservative. (Default: 50) + */ + @JsonProperty("top_k") + private Integer topK; + + /** + * Decoding is the process that a model uses to choose the tokens in the generated output. + * Choose one of the following decoding options: + * + * Greedy: Selects the token with the highest probability at each step of the decoding process. + * Greedy decoding produces output that closely matches the most common language in the model's pretraining + * data and in your prompt text, which is desirable in less creative or fact-based use cases. A weakness of + * greedy decoding is that it can cause repetitive loops in the generated output. + * + * Sampling decoding: Offers more variability in how tokens are selected. + * With sampling decoding, the model samples tokens, meaning the model chooses a subset of tokens, + * and then one token is chosen randomly from this subset to be added to the output text. Sampling adds + * variability and randomness to the decoding process, which can be desirable in creative use cases. + * However, with greater variability comes a greater risk of incorrect or nonsensical output. + * (Default: greedy) + */ + @JsonProperty("decoding_method") + private String decodingMethod; + + /** + * Sets the limit of tokens that the LLM follow. (Default: 20) + */ + @JsonProperty("max_new_tokens") + private Integer maxNewTokens; + + /** + * Sets how many tokens must the LLM generate. (Default: 0) + */ + @JsonProperty("min_new_tokens") + private Integer minNewTokens; + + /** + * Sets when the LLM should stop. + * (e.g., ["\n\n\n"]) then when the LLM generates three consecutive line breaks it will terminate. + * Stop sequences are ignored until after the number of tokens that are specified in the Min tokens parameter are generated. + */ + @JsonProperty("stop_sequences") + private List stopSequences; + + /** + * Sets how strongly to penalize repetitions. A higher value + * (e.g., 1.8) will penalize repetitions more strongly, while a lower value (e.g., + * 1.1) will be more lenient. (Default: 1.0) + */ + @JsonProperty("repetition_penalty") + private Double repetitionPenalty; + + /** + * Produce repeatable results, set the same random seed value every time. (Default: randomly generated) + */ + @JsonProperty("random_seed") + private Integer randomSeed; + + /** + * Model is the identifier of the LLM Model to be used + */ + @JsonProperty("model") + private String model; + + /** + * Set additional request params (some model have non-predefined options) + */ + @JsonProperty("additional") + private Map additional = new HashMap<>(); public static Builder builder() { return new Builder(); } - /** - * Filter out the non-supported fields from the options. - * @param options The options to filter. - * @return The filtered options. - */ - public static Map filterNonSupportedFields(Map options) { - return options.entrySet().stream() - .filter(e -> !e.getKey().equals("model")) - .filter(e -> e.getValue() != null) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - } - - public static WatsonxAiChatOptions fromOptions(WatsonxAiChatOptions fromOptions) { - return WatsonxAiChatOptions.builder() - .withTemperature(fromOptions.getTemperature()) - .withTopP(fromOptions.getTopP()) - .withTopK(fromOptions.getTopK()) - .withDecodingMethod(fromOptions.getDecodingMethod()) - .withMaxNewTokens(fromOptions.getMaxNewTokens()) - .withMinNewTokens(fromOptions.getMinNewTokens()) - .withStopSequences(fromOptions.getStopSequences()) - .withRepetitionPenalty(fromOptions.getRepetitionPenalty()) - .withRandomSeed(fromOptions.getRandomSeed()) - .withModel(fromOptions.getModel()) - .withAdditionalProperties(fromOptions.getAdditionalProperties()) - .build(); - } + /** + * Filter out the non-supported fields from the options. + * @param options The options to filter. + * @return The filtered options. + */ + public static Map filterNonSupportedFields(Map options) { + return options.entrySet().stream() + .filter(e -> !e.getKey().equals("model")) + .filter(e -> e.getValue() != null) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + } + + public static WatsonxAiChatOptions fromOptions(WatsonxAiChatOptions fromOptions) { + return WatsonxAiChatOptions.builder() + .withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .withTopK(fromOptions.getTopK()) + .withDecodingMethod(fromOptions.getDecodingMethod()) + .withMaxNewTokens(fromOptions.getMaxNewTokens()) + .withMinNewTokens(fromOptions.getMinNewTokens()) + .withStopSequences(fromOptions.getStopSequences()) + .withRepetitionPenalty(fromOptions.getRepetitionPenalty()) + .withRandomSeed(fromOptions.getRandomSeed()) + .withModel(fromOptions.getModel()) + .withAdditionalProperties(fromOptions.getAdditionalProperties()) + .build(); + } @Override - public Double getTemperature() { - return this.temperature; - } + public Double getTemperature() { + return this.temperature; + } - public void setTemperature(Double temperature) { - this.temperature = temperature; - } + public void setTemperature(Double temperature) { + this.temperature = temperature; + } @Override - public Double getTopP() { - return this.topP; - } + public Double getTopP() { + return this.topP; + } - public void setTopP(Double topP) { - this.topP = topP; - } + public void setTopP(Double topP) { + this.topP = topP; + } @Override - public Integer getTopK() { - return this.topK; - } + public Integer getTopK() { + return this.topK; + } - public void setTopK(Integer topK) { - this.topK = topK; - } + public void setTopK(Integer topK) { + this.topK = topK; + } - public String getDecodingMethod() { - return this.decodingMethod; - } + public String getDecodingMethod() { + return this.decodingMethod; + } - public void setDecodingMethod(String decodingMethod) { - this.decodingMethod = decodingMethod; - } + public void setDecodingMethod(String decodingMethod) { + this.decodingMethod = decodingMethod; + } @Override @JsonIgnore @@ -205,36 +215,36 @@ public void setMaxTokens(Integer maxTokens) { setMaxNewTokens(maxTokens); } - public Integer getMaxNewTokens() { - return this.maxNewTokens; - } + public Integer getMaxNewTokens() { + return this.maxNewTokens; + } - public void setMaxNewTokens(Integer maxNewTokens) { - this.maxNewTokens = maxNewTokens; - } + public void setMaxNewTokens(Integer maxNewTokens) { + this.maxNewTokens = maxNewTokens; + } - public Integer getMinNewTokens() { - return this.minNewTokens; - } + public Integer getMinNewTokens() { + return this.minNewTokens; + } - public void setMinNewTokens(Integer minNewTokens) { - this.minNewTokens = minNewTokens; - } + public void setMinNewTokens(Integer minNewTokens) { + this.minNewTokens = minNewTokens; + } @Override public List getStopSequences() { - return this.stopSequences; - } + return this.stopSequences; + } - public void setStopSequences(List stopSequences) { - this.stopSequences = stopSequences; - } + public void setStopSequences(List stopSequences) { + this.stopSequences = stopSequences; + } @Override @JsonIgnore public Double getPresencePenalty() { - return getRepetitionPenalty(); - } + return getRepetitionPenalty(); + } @JsonIgnore public void setPresencePenalty(Double presencePenalty) { @@ -242,144 +252,144 @@ public void setPresencePenalty(Double presencePenalty) { } public Double getRepetitionPenalty() { - return this.repetitionPenalty; - } + return this.repetitionPenalty; + } - public void setRepetitionPenalty(Double repetitionPenalty) { - this.repetitionPenalty = repetitionPenalty; - } + public void setRepetitionPenalty(Double repetitionPenalty) { + this.repetitionPenalty = repetitionPenalty; + } - public Integer getRandomSeed() { - return this.randomSeed; - } + public Integer getRandomSeed() { + return this.randomSeed; + } - public void setRandomSeed(Integer randomSeed) { - this.randomSeed = randomSeed; - } + public void setRandomSeed(Integer randomSeed) { + this.randomSeed = randomSeed; + } @Override - public String getModel() { - return this.model; - } - - public void setModel(String model) { - this.model = model; - } - - @JsonAnyGetter - public Map getAdditionalProperties() { - return this.additional.entrySet().stream() - .collect(Collectors.toMap( - entry -> toSnakeCase(entry.getKey()), - Map.Entry::getValue - )); - } - - @JsonAnySetter - public void addAdditionalProperty(String key, Object value) { - this.additional.put(key, value); - } + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + @JsonAnyGetter + public Map getAdditionalProperties() { + return this.additional.entrySet().stream() + .collect(Collectors.toMap( + entry -> toSnakeCase(entry.getKey()), + Map.Entry::getValue + )); + } + + @JsonAnySetter + public void addAdditionalProperty(String key, Object value) { + this.additional.put(key, value); + } @Override @JsonIgnore public Double getFrequencyPenalty() { - return null; - } - - /** - * Convert the {@link WatsonxAiChatOptions} object to a {@link Map} of key/value pairs. - * @return The {@link Map} of key/value pairs. - */ - public Map toMap() { - try { - var json = this.mapper.writeValueAsString(this); - var map = this.mapper.readValue(json, new TypeReference>() {}); - map.remove("additional"); - - return map; - } - catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } - - private String toSnakeCase(String input) { - return input != null ? input.replaceAll("([a-z])([A-Z]+)", "$1_$2").toLowerCase() : null; - } - - @Override - public WatsonxAiChatOptions copy() { - return fromOptions(this); - } - - public static class Builder { - - WatsonxAiChatOptions options = new WatsonxAiChatOptions(); - - public Builder withTemperature(Double temperature) { - this.options.temperature = temperature; - return this; - } - - public Builder withTopP(Double topP) { - this.options.topP = topP; - return this; - } - - public Builder withTopK(Integer topK) { - this.options.topK = topK; - return this; - } - - public Builder withDecodingMethod(String decodingMethod) { - this.options.decodingMethod = decodingMethod; - return this; - } - - public Builder withMaxNewTokens(Integer maxNewTokens) { - this.options.maxNewTokens = maxNewTokens; - return this; - } - - public Builder withMinNewTokens(Integer minNewTokens) { - this.options.minNewTokens = minNewTokens; - return this; - } - - public Builder withStopSequences(List stopSequences) { - this.options.stopSequences = stopSequences; - return this; - } - - public Builder withRepetitionPenalty(Double repetitionPenalty) { - this.options.repetitionPenalty = repetitionPenalty; - return this; - } - - public Builder withRandomSeed(Integer randomSeed) { - this.options.randomSeed = randomSeed; - return this; - } - - public Builder withModel(String model) { - this.options.model = model; - return this; - } - - public Builder withAdditionalProperty(String key, Object value) { - this.options.additional.put(key, value); - return this; - } - - public Builder withAdditionalProperties(Map properties) { - this.options.additional.putAll(properties); - return this; - } - - public WatsonxAiChatOptions build() { - return this.options; - } - } + return null; + } + + /** + * Convert the {@link WatsonxAiChatOptions} object to a {@link Map} of key/value pairs. + * @return The {@link Map} of key/value pairs. + */ + public Map toMap() { + try { + var json = this.mapper.writeValueAsString(this); + var map = this.mapper.readValue(json, new TypeReference>() { }); + map.remove("additional"); + + return map; + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + private String toSnakeCase(String input) { + return input != null ? input.replaceAll("([a-z])([A-Z]+)", "$1_$2").toLowerCase() : null; + } + + @Override + public WatsonxAiChatOptions copy() { + return fromOptions(this); + } + + public static class Builder { + + WatsonxAiChatOptions options = new WatsonxAiChatOptions(); + + public Builder withTemperature(Double temperature) { + this.options.temperature = temperature; + return this; + } + + public Builder withTopP(Double topP) { + this.options.topP = topP; + return this; + } + + public Builder withTopK(Integer topK) { + this.options.topK = topK; + return this; + } + + public Builder withDecodingMethod(String decodingMethod) { + this.options.decodingMethod = decodingMethod; + return this; + } + + public Builder withMaxNewTokens(Integer maxNewTokens) { + this.options.maxNewTokens = maxNewTokens; + return this; + } + + public Builder withMinNewTokens(Integer minNewTokens) { + this.options.minNewTokens = minNewTokens; + return this; + } + + public Builder withStopSequences(List stopSequences) { + this.options.stopSequences = stopSequences; + return this; + } + + public Builder withRepetitionPenalty(Double repetitionPenalty) { + this.options.repetitionPenalty = repetitionPenalty; + return this; + } + + public Builder withRandomSeed(Integer randomSeed) { + this.options.randomSeed = randomSeed; + return this; + } + + public Builder withModel(String model) { + this.options.model = model; + return this; + } + + public Builder withAdditionalProperty(String key, Object value) { + this.options.additional.put(key, value); + return this; + } + + public Builder withAdditionalProperties(Map properties) { + this.options.additional.putAll(properties); + return this; + } + + public WatsonxAiChatOptions build() { + return this.options; + } + } } -// @formatter:on \ No newline at end of file +// @formatter:on diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingModel.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingModel.java index 18e3ae3617a..a6df8fff27e 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingModel.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingModel.java @@ -98,7 +98,7 @@ WatsonxAiEmbeddingRequest watsonxAiEmbeddingRequest(List inputs, Embeddi ? (WatsonxAiEmbeddingOptions) options : this.defaultOptions; if (!StringUtils.hasText(runtimeOptions.getModel())) { - this.logger.warn("The model cannot be null, using default model instead"); + logger.warn("The model cannot be null, using default model instead"); runtimeOptions = this.defaultOptions; } diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/aot/WatsonxAiRuntimeHints.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/aot/WatsonxAiRuntimeHints.java index b78266e7aac..34799ecf8b8 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/aot/WatsonxAiRuntimeHints.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/aot/WatsonxAiRuntimeHints.java @@ -37,11 +37,12 @@ public class WatsonxAiRuntimeHints implements RuntimeHintsRegistrar { @Override public void registerHints(RuntimeHints hints, ClassLoader classLoader) { var mcs = MemberCategory.values(); - for (var tr : findJsonAnnotatedClassesInPackage(WatsonxAiApi.class)) + for (var tr : findJsonAnnotatedClassesInPackage(WatsonxAiApi.class)) { hints.reflection().registerType(tr, mcs); - - for (var tr : findJsonAnnotatedClassesInPackage(WatsonxAiChatOptions.class)) + } + for (var tr : findJsonAnnotatedClassesInPackage(WatsonxAiChatOptions.class)) { hints.reflection().registerType(tr, mcs); + } } diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiApi.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiApi.java index 7953f8c5608..6f04fecf8b7 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiApi.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiApi.java @@ -46,107 +46,107 @@ // @formatter:off public class WatsonxAiApi { - public static final String WATSONX_REQUEST_CANNOT_BE_NULL = "Watsonx Request cannot be null"; + public static final String WATSONX_REQUEST_CANNOT_BE_NULL = "Watsonx Request cannot be null"; - private static final Log logger = LogFactory.getLog(WatsonxAiApi.class); + private static final Log logger = LogFactory.getLog(WatsonxAiApi.class); - private final RestClient restClient; - private final WebClient webClient; - private final IamAuthenticator iamAuthenticator; - private final String streamEndpoint; - private final String textEndpoint; + private final RestClient restClient; + private final WebClient webClient; + private final IamAuthenticator iamAuthenticator; + private final String streamEndpoint; + private final String textEndpoint; private final String embeddingEndpoint; - private final String projectId; - private IamToken token; - - /** - * Create a new chat api. - * @param baseUrl api base URL. - * @param streamEndpoint streaming generation. - * @param textEndpoint text generation. + private final String projectId; + private IamToken token; + + /** + * Create a new chat api. + * @param baseUrl api base URL. + * @param streamEndpoint streaming generation. + * @param textEndpoint text generation. * @param embeddingEndpoint embedding generation - * @param projectId watsonx.ai project identifier. - * @param IAMToken IBM Cloud IAM token. - * @param restClientBuilder rest client builder. - */ - public WatsonxAiApi( - String baseUrl, - String streamEndpoint, - String textEndpoint, + * @param projectId watsonx.ai project identifier. + * @param IAMToken IBM Cloud IAM token. + * @param restClientBuilder rest client builder. + */ + public WatsonxAiApi( + String baseUrl, + String streamEndpoint, + String textEndpoint, String embeddingEndpoint, - String projectId, - String IAMToken, - RestClient.Builder restClientBuilder - ) { - this.streamEndpoint = streamEndpoint; - this.textEndpoint = textEndpoint; + String projectId, + String IAMToken, + RestClient.Builder restClientBuilder + ) { + this.streamEndpoint = streamEndpoint; + this.textEndpoint = textEndpoint; this.embeddingEndpoint = embeddingEndpoint; - this.projectId = projectId; - this.iamAuthenticator = IamAuthenticator.fromConfiguration(Map.of("APIKEY", IAMToken)); - this.token = this.iamAuthenticator.requestToken(); - - Consumer defaultHeaders = headers -> { - headers.setContentType(MediaType.APPLICATION_JSON); - headers.setAccept(List.of(MediaType.APPLICATION_JSON)); - }; - - this.restClient = restClientBuilder.baseUrl(baseUrl) - .defaultStatusHandler(RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER) - .defaultHeaders(defaultHeaders) - .build(); - - this.webClient = WebClient.builder().baseUrl(baseUrl) - .defaultHeaders(defaultHeaders) - .build(); - } - - @Retryable(retryFor = Exception.class, maxAttempts = 3, backoff = @Backoff(random = true, delay = 1200, maxDelay = 7000, multiplier = 2.5)) - public ResponseEntity generate(WatsonxAiChatRequest watsonxAiChatRequest) { - Assert.notNull(watsonxAiChatRequest, WATSONX_REQUEST_CANNOT_BE_NULL); - - if(this.token.needsRefresh()) { - this.token = this.iamAuthenticator.requestToken(); - } - - return this.restClient.post() - .uri(this.textEndpoint) - .header(HttpHeaders.AUTHORIZATION, "Bearer " + this.token.getAccessToken()) - .body(watsonxAiChatRequest.withProjectId(this.projectId)) - .retrieve() - .toEntity(WatsonxAiChatResponse.class); - } - - @Retryable(retryFor = Exception.class, maxAttempts = 3, backoff = @Backoff(random = true, delay = 1200, maxDelay = 7000, multiplier = 2.5)) - public Flux generateStreaming(WatsonxAiChatRequest watsonxAiChatRequest) { - Assert.notNull(watsonxAiChatRequest, WATSONX_REQUEST_CANNOT_BE_NULL); - - if(this.token.needsRefresh()) { - this.token = this.iamAuthenticator.requestToken(); - } - - return this.webClient.post() - .uri(this.streamEndpoint) - .header(HttpHeaders.AUTHORIZATION, "Bearer " + this.token.getAccessToken()) - .bodyValue(watsonxAiChatRequest.withProjectId(this.projectId)) - .retrieve() - .bodyToFlux(WatsonxAiChatResponse.class) - .handle((data, sink) -> { - if (logger.isTraceEnabled()) { - logger.trace(data); - } - sink.next(data); - }); - } + this.projectId = projectId; + this.iamAuthenticator = IamAuthenticator.fromConfiguration(Map.of("APIKEY", IAMToken)); + this.token = this.iamAuthenticator.requestToken(); + + Consumer defaultHeaders = headers -> { + headers.setContentType(MediaType.APPLICATION_JSON); + headers.setAccept(List.of(MediaType.APPLICATION_JSON)); + }; + + this.restClient = restClientBuilder.baseUrl(baseUrl) + .defaultStatusHandler(RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER) + .defaultHeaders(defaultHeaders) + .build(); + + this.webClient = WebClient.builder().baseUrl(baseUrl) + .defaultHeaders(defaultHeaders) + .build(); + } + + @Retryable(retryFor = Exception.class, maxAttempts = 3, backoff = @Backoff(random = true, delay = 1200, maxDelay = 7000, multiplier = 2.5)) + public ResponseEntity generate(WatsonxAiChatRequest watsonxAiChatRequest) { + Assert.notNull(watsonxAiChatRequest, WATSONX_REQUEST_CANNOT_BE_NULL); + + if (this.token.needsRefresh()) { + this.token = this.iamAuthenticator.requestToken(); + } + + return this.restClient.post() + .uri(this.textEndpoint) + .header(HttpHeaders.AUTHORIZATION, "Bearer " + this.token.getAccessToken()) + .body(watsonxAiChatRequest.withProjectId(this.projectId)) + .retrieve() + .toEntity(WatsonxAiChatResponse.class); + } + + @Retryable(retryFor = Exception.class, maxAttempts = 3, backoff = @Backoff(random = true, delay = 1200, maxDelay = 7000, multiplier = 2.5)) + public Flux generateStreaming(WatsonxAiChatRequest watsonxAiChatRequest) { + Assert.notNull(watsonxAiChatRequest, WATSONX_REQUEST_CANNOT_BE_NULL); + + if (this.token.needsRefresh()) { + this.token = this.iamAuthenticator.requestToken(); + } + + return this.webClient.post() + .uri(this.streamEndpoint) + .header(HttpHeaders.AUTHORIZATION, "Bearer " + this.token.getAccessToken()) + .bodyValue(watsonxAiChatRequest.withProjectId(this.projectId)) + .retrieve() + .bodyToFlux(WatsonxAiChatResponse.class) + .handle((data, sink) -> { + if (logger.isTraceEnabled()) { + logger.trace(data); + } + sink.next(data); + }); + } @Retryable(retryFor = Exception.class, maxAttempts = 3, backoff = @Backoff(random = true, delay = 1200, maxDelay = 7000, multiplier = 2.5)) public ResponseEntity embeddings(WatsonxAiEmbeddingRequest request) { Assert.notNull(request, WATSONX_REQUEST_CANNOT_BE_NULL); - if(this.token.needsRefresh()) { + if (this.token.needsRefresh()) { this.token = this.iamAuthenticator.requestToken(); } - return this.restClient.post() + return this.restClient.post() .uri(this.embeddingEndpoint) .header(HttpHeaders.AUTHORIZATION, "Bearer " + this.token.getAccessToken()) .body(request.withProjectId(this.projectId)) diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatRequest.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatRequest.java index 817e9802f2a..96d40275262 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatRequest.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatRequest.java @@ -32,58 +32,66 @@ */ // @formatter:off @JsonInclude(JsonInclude.Include.NON_NULL) -public class WatsonxAiChatRequest { - - @JsonProperty("input") - private String input; - @JsonProperty("parameters") - private Map parameters; - @JsonProperty("model_id") - private String modelId = ""; - @JsonProperty("project_id") - private String projectId = ""; - - private WatsonxAiChatRequest(String input, Map parameters, String modelId, String projectId) { - this.input = input; - this.parameters = parameters; - this.modelId = modelId; - this.projectId = projectId; - } - - public static Builder builder(String input) { return new Builder(input); } - - public WatsonxAiChatRequest withProjectId(String projectId) { - this.projectId = projectId; - return this; - } - - public String getInput() { return this.input; } - - public Map getParameters() { return this.parameters; } - - public String getModelId() { return this.modelId; } - - public static class Builder { - public static final String MODEL_PARAMETER_IS_REQUIRED = "Model parameter is required"; - private final String input; - private Map parameters; - private String model = ""; - - public Builder(String input) { - this.input = input; - } - - public Builder withParameters(Map parameters) { - Assert.notNull(parameters.get("model"), MODEL_PARAMETER_IS_REQUIRED); - this.model = parameters.get("model").toString(); - this.parameters = WatsonxAiChatOptions.filterNonSupportedFields(parameters); - return this; - } - - public WatsonxAiChatRequest build() { - return new WatsonxAiChatRequest(this.input, this.parameters, this.model, ""); - } - - } - -} \ No newline at end of file +public final class WatsonxAiChatRequest { + + @JsonProperty("input") + private String input; + @JsonProperty("parameters") + private Map parameters; + @JsonProperty("model_id") + private String modelId = ""; + @JsonProperty("project_id") + private String projectId = ""; + + private WatsonxAiChatRequest(String input, Map parameters, String modelId, String projectId) { + this.input = input; + this.parameters = parameters; + this.modelId = modelId; + this.projectId = projectId; + } + + public static Builder builder(String input) { + return new Builder(input); + } + + public WatsonxAiChatRequest withProjectId(String projectId) { + this.projectId = projectId; + return this; + } + + public String getInput() { + return this.input; + } + + public Map getParameters() { + return this.parameters; + } + + public String getModelId() { + return this.modelId; + } + + public static class Builder { + public static final String MODEL_PARAMETER_IS_REQUIRED = "Model parameter is required"; + private final String input; + private Map parameters; + private String model = ""; + + public Builder(String input) { + this.input = input; + } + + public Builder withParameters(Map parameters) { + Assert.notNull(parameters.get("model"), MODEL_PARAMETER_IS_REQUIRED); + this.model = parameters.get("model").toString(); + this.parameters = WatsonxAiChatOptions.filterNonSupportedFields(parameters); + return this; + } + + public WatsonxAiChatRequest build() { + return new WatsonxAiChatRequest(this.input, this.parameters, this.model, ""); + } + + } + +} diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatResponse.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatResponse.java index f90ce643645..888521c7f02 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatResponse.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatResponse.java @@ -32,8 +32,8 @@ // @formatter:off @JsonInclude(JsonInclude.Include.NON_NULL) public record WatsonxAiChatResponse( - @JsonProperty("model_id") String modelId, - @JsonProperty("created_at") Date createdAt, - @JsonProperty("results") List results, - @JsonProperty("system") Map system -) {} + @JsonProperty("model_id") String modelId, + @JsonProperty("created_at") Date createdAt, + @JsonProperty("results") List results, + @JsonProperty("system") Map system +) { } diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatResults.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatResults.java index ecb67f8d937..8fb4a9716bc 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatResults.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatResults.java @@ -28,8 +28,8 @@ // @formatter:off @JsonInclude(JsonInclude.Include.NON_NULL) public record WatsonxAiChatResults( - @JsonProperty("generated_text") String generatedText, - @JsonProperty("generated_token_count") Integer generatedTokenCount, - @JsonProperty("input_token_count") Integer inputTokenCount, - @JsonProperty("stop_reason") String stopReason + @JsonProperty("generated_text") String generatedText, + @JsonProperty("generated_token_count") Integer generatedTokenCount, + @JsonProperty("input_token_count") Integer inputTokenCount, + @JsonProperty("stop_reason") String stopReason ) { } diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingRequest.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingRequest.java index 8e8da278dff..b36c82ac6f1 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingRequest.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingRequest.java @@ -30,7 +30,7 @@ * @since 1.0.0 */ @JsonInclude(JsonInclude.Include.NON_NULL) -public class WatsonxAiEmbeddingRequest { +public final class WatsonxAiEmbeddingRequest { @JsonProperty("model_id") String model; diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/utils/MessageToPromptConverter.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/utils/MessageToPromptConverter.java index 449ec8f7349..8f23b766ba3 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/utils/MessageToPromptConverter.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/utils/MessageToPromptConverter.java @@ -23,64 +23,63 @@ import org.springframework.ai.chat.messages.MessageType; // @formatter:off -public class MessageToPromptConverter { - - public static final String TOOL_EXECUTION_NOT_SUPPORTED_FOR_WAI_MODELS = "Tool execution results are not supported for watsonx.ai models"; - - private static final String HUMAN_PROMPT = "Human: "; - - private static final String ASSISTANT_PROMPT = "Assistant: "; - - private String humanPrompt = HUMAN_PROMPT; - private String assistantPrompt = ASSISTANT_PROMPT; - - private MessageToPromptConverter() { - } - - public static MessageToPromptConverter create() { - return new MessageToPromptConverter(); - } - - public MessageToPromptConverter withHumanPrompt(String humanPrompt) { - this.humanPrompt = humanPrompt; - return this; - } - - public MessageToPromptConverter withAssistantPrompt(String assistantPrompt) { - this.assistantPrompt = assistantPrompt; - return this; - } - - public String toPrompt(List messages) { - - final String systemMessages = messages.stream() - .filter(message -> message.getMessageType() == MessageType.SYSTEM) - .map(Message::getContent) - .collect(Collectors.joining("\n")); - - final String userMessages = messages.stream() - .filter(message -> message.getMessageType() == MessageType.USER - || message.getMessageType() == MessageType.ASSISTANT) - .map(this::messageToString) - .collect(Collectors.joining("\n")); - - return String.format("%s%n%n%s%n%s", systemMessages, userMessages, this.assistantPrompt).trim(); - } - - protected String messageToString(Message message) { - switch (message.getMessageType()) { - case SYSTEM: - return message.getContent(); - case USER: - return this.humanPrompt + message.getContent(); - case ASSISTANT: - return this.assistantPrompt + message.getContent(); - case TOOL: - throw new IllegalArgumentException(TOOL_EXECUTION_NOT_SUPPORTED_FOR_WAI_MODELS); - } - - throw new IllegalArgumentException("Unknown message type: " + message.getMessageType()); - } - // @formatter:on - -} \ No newline at end of file +public final class MessageToPromptConverter { + + public static final String TOOL_EXECUTION_NOT_SUPPORTED_FOR_WAI_MODELS = "Tool execution results are not supported for watsonx.ai models"; + + private static final String HUMAN_PROMPT = "Human: "; + + private static final String ASSISTANT_PROMPT = "Assistant: "; + + private String humanPrompt = HUMAN_PROMPT; + private String assistantPrompt = ASSISTANT_PROMPT; + + private MessageToPromptConverter() { + } + + public static MessageToPromptConverter create() { + return new MessageToPromptConverter(); + } + + public MessageToPromptConverter withHumanPrompt(String humanPrompt) { + this.humanPrompt = humanPrompt; + return this; + } + + public MessageToPromptConverter withAssistantPrompt(String assistantPrompt) { + this.assistantPrompt = assistantPrompt; + return this; + } + + public String toPrompt(List messages) { + + final String systemMessages = messages.stream() + .filter(message -> message.getMessageType() == MessageType.SYSTEM) + .map(Message::getContent) + .collect(Collectors.joining("\n")); + + final String userMessages = messages.stream() + .filter(message -> message.getMessageType() == MessageType.USER + || message.getMessageType() == MessageType.ASSISTANT) + .map(this::messageToString) + .collect(Collectors.joining("\n")); + + return String.format("%s%n%n%s%n%s", systemMessages, userMessages, this.assistantPrompt).trim(); + } + + protected String messageToString(Message message) { + switch (message.getMessageType()) { + case SYSTEM: + return message.getContent(); + case USER: + return this.humanPrompt + message.getContent(); + case ASSISTANT: + return this.assistantPrompt + message.getContent(); + case TOOL: + throw new IllegalArgumentException(TOOL_EXECUTION_NOT_SUPPORTED_FOR_WAI_MODELS); + } + throw new IllegalArgumentException("Unknown message type: " + message.getMessageType()); + } + // @formatter:on + +} diff --git a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiChatModelTest.java b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiChatModelTest.java index 4a41f72f706..54c2101afbc 100644 --- a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiChatModelTest.java +++ b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiChatModelTest.java @@ -40,8 +40,8 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; /** * @author Pablo Sanchidrian Herrera @@ -57,9 +57,7 @@ public void testCreateRequestWithNoModelId() { Prompt prompt = new Prompt("Test message", options); - Exception exception = Assert.assertThrows(IllegalArgumentException.class, () -> { - WatsonxAiChatRequest request = this.chatModel.request(prompt); - }); + Exception exception = Assert.assertThrows(IllegalArgumentException.class, () -> this.chatModel.request(prompt)); } @Test @@ -171,8 +169,8 @@ public void testCallMethod() { List.of(fakeResults), Map.of("warnings", List.of(Map.of("message", "the message", "id", "disclaimer_warning")))); - when(mockChatApi.generate(any(WatsonxAiChatRequest.class))) - .thenReturn(ResponseEntity.of(Optional.of(fakeResponse))); + given(mockChatApi.generate(any(WatsonxAiChatRequest.class))) + .willReturn(ResponseEntity.of(Optional.of(fakeResponse))); Generation expectedGenerator = new Generation("LLM response") .withGenerationMetadata(ChatGenerationMetadata.from("max_tokens", @@ -205,7 +203,7 @@ public void testStreamMethod() { List.of(fakeResultsSecond), null); Flux fakeResponse = Flux.just(fakeResponseFirst, fakeResponseSecond); - when(mockChatApi.generateStreaming(any(WatsonxAiChatRequest.class))).thenReturn(fakeResponse); + given(mockChatApi.generateStreaming(any(WatsonxAiChatRequest.class))).willReturn(fakeResponse); Generation firstGen = new Generation("LLM resp") .withGenerationMetadata(ChatGenerationMetadata.from("max_tokens", diff --git a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingModelTest.java b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingModelTest.java index 4e19920ec90..b3001cf22ac 100644 --- a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingModelTest.java +++ b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingModelTest.java @@ -31,8 +31,8 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; public class WatsonxAiEmbeddingModelTest { @@ -91,7 +91,7 @@ void singleEmbeddingWithOptions() { inputTokenCount); ResponseEntity mockResponseEntity = ResponseEntity.ok(mockResponse); - when(this.watsonxAiApiMock.embeddings(any(WatsonxAiEmbeddingRequest.class))).thenReturn(mockResponseEntity); + given(this.watsonxAiApiMock.embeddings(any(WatsonxAiEmbeddingRequest.class))).willReturn(mockResponseEntity); assertThat(this.embeddingModel).isNotNull(); diff --git a/models/spring-ai-zhipuai/pom.xml b/models/spring-ai-zhipuai/pom.xml index 59df1857b1f..10bb5cbfcf9 100644 --- a/models/spring-ai-zhipuai/pom.xml +++ b/models/spring-ai-zhipuai/pom.xml @@ -35,6 +35,9 @@ git://github.com/spring-projects/spring-ai.git git@github.com:spring-projects/spring-ai.git + + false + diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java index b6b8e06590d..9694f2961fa 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java @@ -394,9 +394,8 @@ else if (message.getMessageType() == MessageType.ASSISTANT) { else if (message.getMessageType() == MessageType.TOOL) { ToolResponseMessage toolMessage = (ToolResponseMessage) message; - toolMessage.getResponses().forEach(response -> { - Assert.isTrue(response.id() != null, "ToolResponseMessage must have an id"); - }); + toolMessage.getResponses() + .forEach(response -> Assert.isTrue(response.id() != null, "ToolResponseMessage must have an id")); return toolMessage.getResponses() .stream() diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java index c0c66253a41..82a6ee49137 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java @@ -130,7 +130,7 @@ public class ZhiPuAiChatOptions implements FunctionCallingOptions, ChatOptions { @JsonIgnore private Boolean proxyToolCalls; - @NestedConfigurationProperty + @NestedConfigurationProperty @JsonIgnore private Map toolContext; // @formatter:on diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageOptions.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageOptions.java index baa1e8475f7..6def22698e6 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageOptions.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageOptions.java @@ -134,7 +134,7 @@ public String toString() { return "ZhiPuAiImageOptions{model='" + this.model + '\'' + ", user='" + this.user + '\'' + '}'; } - public static class Builder { + public static final class Builder { private final ZhiPuAiImageOptions options; diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/aot/ZhiPuAiRuntimeHints.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/aot/ZhiPuAiRuntimeHints.java index d5cc2f21e55..75673f1c445 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/aot/ZhiPuAiRuntimeHints.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/aot/ZhiPuAiRuntimeHints.java @@ -38,10 +38,12 @@ public class ZhiPuAiRuntimeHints implements RuntimeHintsRegistrar { @Override public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) { var mcs = MemberCategory.values(); - for (var tr : findJsonAnnotatedClassesInPackage(ZhiPuAiApi.class)) + for (var tr : findJsonAnnotatedClassesInPackage(ZhiPuAiApi.class)) { hints.reflection().registerType(tr, mcs); - for (var tr : findJsonAnnotatedClassesInPackage(ZhiPuAiImageApi.class)) + } + for (var tr : findJsonAnnotatedClassesInPackage(ZhiPuAiImageApi.class)) { hints.reflection().registerType(tr, mcs); + } } } diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java index 2be99709d4f..3e2d343ccd0 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java @@ -262,27 +262,33 @@ public enum ChatCompletionFinishReason { /** * The model hit a natural stop point or a provided stop sequence. */ - @JsonProperty("stop") STOP, + @JsonProperty("stop") + STOP, /** * The maximum number of tokens specified in the request was reached. */ - @JsonProperty("length") LENGTH, + @JsonProperty("length") + LENGTH, /** * The content was omitted due to a flag from our content filters. */ - @JsonProperty("content_filter") CONTENT_FILTER, + @JsonProperty("content_filter") + CONTENT_FILTER, /** * The model called a tool. */ - @JsonProperty("tool_calls") TOOL_CALLS, + @JsonProperty("tool_calls") + TOOL_CALLS, /** * (deprecated) The model called a function. */ - @JsonProperty("function_call") FUNCTION_CALL, + @JsonProperty("function_call") + FUNCTION_CALL, /** * Only for compatibility with Mistral AI API. */ - @JsonProperty("tool_call") TOOL_CALL + @JsonProperty("tool_call") + TOOL_CALL } /** @@ -334,7 +340,8 @@ public enum Type { /** * Function tool type. */ - @JsonProperty("function") FUNCTION + @JsonProperty("function") + FUNCTION } /** @@ -366,7 +373,7 @@ public Function(String description, String name, String jsonSchema) { } } - /** + /** * Creates a model response for the given chat conversation. * * @param messages A list of messages comprising the conversation so far. @@ -393,7 +400,7 @@ public Function(String description, String name, String jsonSchema) { * */ @JsonInclude(Include.NON_NULL) - public record ChatCompletionRequest ( + public record ChatCompletionRequest( @JsonProperty("messages") List messages, @JsonProperty("model") String model, @JsonProperty("max_tokens") Integer maxTokens, @@ -543,19 +550,24 @@ public enum Role { /** * System message. */ - @JsonProperty("system") SYSTEM, + @JsonProperty("system") + SYSTEM, /** * User message. */ - @JsonProperty("user") USER, + @JsonProperty("user") + USER, /** * Assistant message. */ - @JsonProperty("assistant") ASSISTANT, + @JsonProperty("assistant") + ASSISTANT, /** * Tool message. */ - @JsonProperty("tool") TOOL + @JsonProperty("tool") + TOOL + } /** @@ -799,11 +811,18 @@ public record Embedding( public Embedding(Integer index, float[] embedding) { this(index, embedding, "embedding"); } - @Override public boolean equals(Object o) { - if (this == o) return true; - if (!(o instanceof Embedding embedding1)) return false; - return Objects.equals(this.index, embedding1.index) && Arrays.equals(this.embedding, embedding1.embedding) && Objects.equals(this.object, embedding1.object); + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof Embedding embedding1)) { + return false; + } + return Objects.equals(this.index, embedding1.index) && Arrays.equals(this.embedding, embedding1.embedding) && Objects.equals(this.object, embedding1.object); } + @Override public int hashCode() { int result = Objects.hash(this.index, this.object); @@ -811,7 +830,8 @@ public int hashCode() { return result; } - @Override public String toString() { + @Override + public String toString() { return "Embedding{" + "index=" + this.index + ", embedding=" + Arrays.toString(this.embedding) + diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiImageApi.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiImageApi.java index 23bfd8404d9..dea718610c7 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiImageApi.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiImageApi.java @@ -69,10 +69,9 @@ public ZhiPuAiImageApi(String baseUrl, String zhiPuAiToken, RestClient.Builder r public ZhiPuAiImageApi(String baseUrl, String zhiPuAiToken, RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) { - this.restClient = restClientBuilder.baseUrl(baseUrl).defaultHeaders(h -> { - h.setBearerAuth(zhiPuAiToken); - // h.setContentType(MediaType.APPLICATION_JSON); - }).defaultStatusHandler(responseErrorHandler).build(); + this.restClient = restClientBuilder.baseUrl(baseUrl).defaultHeaders(h -> h.setBearerAuth(zhiPuAiToken) + // h.setContentType(MediaType.APPLICATION_JSON); + ).defaultStatusHandler(responseErrorHandler).build(); } public ResponseEntity createImage(ZhiPuAiImageRequest zhiPuAiImageRequest) { @@ -108,7 +107,7 @@ public String getValue() { // @formatter:off @JsonInclude(JsonInclude.Include.NON_NULL) - public record ZhiPuAiImageRequest ( + public record ZhiPuAiImageRequest( @JsonProperty("prompt") String prompt, @JsonProperty("model") String model, @JsonProperty("user_id") String user) { diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuApiConstants.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuApiConstants.java index 52f2427712e..e1b5bf8648a 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuApiConstants.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuApiConstants.java @@ -30,4 +30,8 @@ public final class ZhiPuApiConstants { public static final String PROVIDER_NAME = AiProvider.ZHIPUAI.value(); + private ZhiPuApiConstants() { + + } + } diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ChatCompletionRequestTests.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ChatCompletionRequestTests.java index 90dac9f579f..a4f2c08f717 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ChatCompletionRequestTests.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ChatCompletionRequestTests.java @@ -70,7 +70,7 @@ public void promptOptionsTools() { .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName(TOOL_FUNCTION_NAME) .withDescription("Get the weather in location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build())) .build()), false); @@ -97,7 +97,7 @@ public void defaultOptionsTools() { .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName(TOOL_FUNCTION_NAME) .withDescription("Get the weather in location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build())) .build()); diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/MockWeatherService.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/MockWeatherService.java index c1487282b15..2f32b5ce7b3 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/MockWeatherService.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/MockWeatherService.java @@ -65,7 +65,7 @@ public enum Unit { */ public final String unitName; - private Unit(String text) { + Unit(String text) { this.unitName = text; } @@ -78,8 +78,8 @@ private Unit(String text) { @JsonClassDescription("Weather API request") public record Request(@JsonProperty(required = true, value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, - @JsonProperty(value = "lat") @JsonPropertyDescription("The city latitude") double lat, - @JsonProperty(value = "lon") @JsonPropertyDescription("The city longitude") double lon, + @JsonProperty("lat") @JsonPropertyDescription("The city latitude") double lat, + @JsonProperty("lon") @JsonPropertyDescription("The city longitude") double lon, @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { } diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiToolFunctionCallIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiToolFunctionCallIT.java index 2c6de05af21..4817580acc0 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiToolFunctionCallIT.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiToolFunctionCallIT.java @@ -38,7 +38,6 @@ import org.springframework.http.ResponseEntity; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatModel.GLM_4; /** * @author Geng Rong @@ -99,8 +98,9 @@ public void toolFunctionCall() { List messages = new ArrayList<>(List.of(message)); - ChatCompletionRequest chatCompletionRequest = new ChatCompletionRequest(messages, GLM_4.value, - List.of(functionTool), ToolChoiceBuilder.AUTO); + ChatCompletionRequest chatCompletionRequest = new ChatCompletionRequest(messages, + org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatModel.GLM_4.value, List.of(functionTool), + ToolChoiceBuilder.AUTO); ResponseEntity chatCompletion = this.zhiPuAiApi.chatCompletionEntity(chatCompletionRequest); @@ -129,12 +129,13 @@ public void toolFunctionCall() { } } - var functionResponseRequest = new ChatCompletionRequest(messages, GLM_4.value, List.of(functionTool), + var functionResponseRequest = new ChatCompletionRequest(messages, + org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatModel.GLM_4.value, List.of(functionTool), ToolChoiceBuilder.AUTO); ResponseEntity chatCompletion2 = this.zhiPuAiApi.chatCompletionEntity(functionResponseRequest); - this.logger.info("Final response: " + chatCompletion2.getBody()); + logger.info("Final response: " + chatCompletion2.getBody()); assertThat(Objects.requireNonNull(chatCompletion2.getBody()).choices()).isNotEmpty(); diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java index af2d147505a..b50df0efb00 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java @@ -59,7 +59,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.Mockito.when; +import static org.mockito.BDDMockito.given; /** * @author Geng Rong @@ -101,13 +101,13 @@ public void zhiPuAiChatTransientError() { var choice = new ChatCompletion.Choice(ChatCompletionFinishReason.STOP, 0, new ChatCompletionMessage("Response", Role.ASSISTANT), null); - ChatCompletion expectedChatCompletion = new ChatCompletion("id", List.of(choice), 666l, "model", null, null, + ChatCompletion expectedChatCompletion = new ChatCompletion("id", List.of(choice), 666L, "model", null, null, new ZhiPuAiApi.Usage(10, 10, 10)); - when(this.zhiPuAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) - .thenThrow(new TransientAiException("Transient Error 1")) - .thenThrow(new TransientAiException("Transient Error 2")) - .thenReturn(ResponseEntity.of(Optional.of(expectedChatCompletion))); + given(this.zhiPuAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + .willThrow(new TransientAiException("Transient Error 1")) + .willThrow(new TransientAiException("Transient Error 2")) + .willReturn(ResponseEntity.of(Optional.of(expectedChatCompletion))); var result = this.chatModel.call(new Prompt("text")); @@ -119,8 +119,8 @@ public void zhiPuAiChatTransientError() { @Test public void zhiPuAiChatNonTransientError() { - when(this.zhiPuAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) - .thenThrow(new RuntimeException("Non Transient Error")); + given(this.zhiPuAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + .willThrow(new RuntimeException("Non Transient Error")); assertThrows(RuntimeException.class, () -> this.chatModel.call(new Prompt("text"))); } @@ -129,13 +129,13 @@ public void zhiPuAiChatStreamTransientError() { var choice = new ChatCompletionChunk.ChunkChoice(ChatCompletionFinishReason.STOP, 0, new ChatCompletionMessage("Response", Role.ASSISTANT), null); - ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", List.of(choice), 666l, "model", null, + ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", List.of(choice), 666L, "model", null, null); - when(this.zhiPuAiApi.chatCompletionStream(isA(ChatCompletionRequest.class))) - .thenThrow(new TransientAiException("Transient Error 1")) - .thenThrow(new TransientAiException("Transient Error 2")) - .thenReturn(Flux.just(expectedChatCompletion)); + given(this.zhiPuAiApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + .willThrow(new TransientAiException("Transient Error 1")) + .willThrow(new TransientAiException("Transient Error 2")) + .willReturn(Flux.just(expectedChatCompletion)); var result = this.chatModel.stream(new Prompt("text")); @@ -147,8 +147,8 @@ public void zhiPuAiChatStreamTransientError() { @Test public void zhiPuAiChatStreamNonTransientError() { - when(this.zhiPuAiApi.chatCompletionStream(isA(ChatCompletionRequest.class))) - .thenThrow(new RuntimeException("Non Transient Error")); + given(this.zhiPuAiApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + .willThrow(new RuntimeException("Non Transient Error")); assertThrows(RuntimeException.class, () -> this.chatModel.stream(new Prompt("text")).collectList().block()); } @@ -158,10 +158,10 @@ public void zhiPuAiEmbeddingTransientError() { EmbeddingList expectedEmbeddings = new EmbeddingList<>("list", List.of(new Embedding(0, new float[] { 9.9f, 8.8f })), "model", new ZhiPuAiApi.Usage(10, 10, 10)); - when(this.zhiPuAiApi.embeddings(isA(EmbeddingRequest.class))) - .thenThrow(new TransientAiException("Transient Error 1")) - .thenThrow(new TransientAiException("Transient Error 2")) - .thenReturn(ResponseEntity.of(Optional.of(expectedEmbeddings))); + given(this.zhiPuAiApi.embeddings(isA(EmbeddingRequest.class))) + .willThrow(new TransientAiException("Transient Error 1")) + .willThrow(new TransientAiException("Transient Error 2")) + .willReturn(ResponseEntity.of(Optional.of(expectedEmbeddings))); var result = this.embeddingModel .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null)); @@ -174,8 +174,8 @@ public void zhiPuAiEmbeddingTransientError() { @Test public void zhiPuAiEmbeddingNonTransientError() { - when(this.zhiPuAiApi.embeddings(isA(EmbeddingRequest.class))) - .thenThrow(new RuntimeException("Non Transient Error")); + given(this.zhiPuAiApi.embeddings(isA(EmbeddingRequest.class))) + .willThrow(new RuntimeException("Non Transient Error")); assertThrows(RuntimeException.class, () -> this.embeddingModel .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null))); } @@ -183,12 +183,12 @@ public void zhiPuAiEmbeddingNonTransientError() { @Test public void zhiPuAiImageTransientError() { - var expectedResponse = new ZhiPuAiImageResponse(678l, List.of(new Data("url678"))); + var expectedResponse = new ZhiPuAiImageResponse(678L, List.of(new Data("url678"))); - when(this.zhiPuAiImageApi.createImage(isA(ZhiPuAiImageRequest.class))) - .thenThrow(new TransientAiException("Transient Error 1")) - .thenThrow(new TransientAiException("Transient Error 2")) - .thenReturn(ResponseEntity.of(Optional.of(expectedResponse))); + given(this.zhiPuAiImageApi.createImage(isA(ZhiPuAiImageRequest.class))) + .willThrow(new TransientAiException("Transient Error 1")) + .willThrow(new TransientAiException("Transient Error 2")) + .willReturn(ResponseEntity.of(Optional.of(expectedResponse))); var result = this.imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message")))); @@ -200,8 +200,8 @@ public void zhiPuAiImageTransientError() { @Test public void zhiPuAiImageNonTransientError() { - when(this.zhiPuAiImageApi.createImage(isA(ZhiPuAiImageRequest.class))) - .thenThrow(new RuntimeException("Transient Error 1")); + given(this.zhiPuAiImageApi.createImage(isA(ZhiPuAiImageRequest.class))) + .willThrow(new RuntimeException("Transient Error 1")); assertThrows(RuntimeException.class, () -> this.imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message"))))); } diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelIT.java index 5c0b3737491..53e9cd367cb 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelIT.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelIT.java @@ -233,7 +233,7 @@ void functionCallTest() { .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("getCurrentWeather") .withDescription("Get the weather in location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build())) .build(); @@ -259,7 +259,7 @@ void streamFunctionCallTest() { .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("getCurrentWeather") .withDescription("Get the weather in location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build())) .build(); diff --git a/spring-ai-retry/pom.xml b/spring-ai-retry/pom.xml index 848ac898383..e479fa22059 100644 --- a/spring-ai-retry/pom.xml +++ b/spring-ai-retry/pom.xml @@ -35,6 +35,10 @@ git@github.com:spring-projects/spring-ai.git + + false + + diff --git a/spring-ai-spring-boot-autoconfigure/pom.xml b/spring-ai-spring-boot-autoconfigure/pom.xml index 200acce9ee1..639a8a04973 100644 --- a/spring-ai-spring-boot-autoconfigure/pom.xml +++ b/spring-ai-spring-boot-autoconfigure/pom.xml @@ -15,9 +15,9 @@ ~ limitations under the License. --> - + 4.0.0 org.springframework.ai @@ -36,6 +36,10 @@ git@github.com:spring-projects/spring-ai.git + + false + + @@ -461,7 +465,7 @@ oracle-free 1.19.8 test - + org.testcontainers diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfiguration.java index 10db7578206..3afca10874d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfiguration.java @@ -68,7 +68,7 @@ static class StaticRegionProvider implements AwsRegionProvider { private final Region region; - public StaticRegionProvider(String region) { + StaticRegionProvider(String region) { try { this.region = Region.of(region); } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfiguration.java index 0941200366d..ea892b7b995 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfiguration.java @@ -64,8 +64,6 @@ public void onError(RetryContext context RetryCallback callback, Throwable throwable) { logger.warn("Retry error. Retry count:" + context.getRetryCount(), throwable); } - - ; }) .build(); } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelProperties.java index 67e0922b65e..060114561ee 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/transformers/TransformersEmbeddingModelProperties.java @@ -29,12 +29,10 @@ import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.NestedConfigurationProperty; -import static org.springframework.ai.autoconfigure.transformers.TransformersEmbeddingModelProperties.CONFIG_PREFIX; - /** * @author Christian Tzolov */ -@ConfigurationProperties(CONFIG_PREFIX) +@ConfigurationProperties(org.springframework.ai.autoconfigure.transformers.TransformersEmbeddingModelProperties.CONFIG_PREFIX) public class TransformersEmbeddingModelProperties { public static final String CONFIG_PREFIX = "spring.ai.embedding.transformer"; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfiguration.java index f500e1bdf6a..4b9efc2471f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cassandra/CassandraVectorStoreAutoConfiguration.java @@ -86,7 +86,7 @@ public CassandraVectorStore vectorStore(EmbeddingModel embeddingModel, Cassandra public DriverConfigLoaderBuilderCustomizer driverConfigLoaderBuilderCustomizer() { // this replaces spring-ai-cassandra-*.jar!application.conf // as spring-boot autoconfigure will not resolve the default driver configs - return (builder) -> builder.startProfile(CassandraVectorStore.DRIVER_PROFILE_UPDATES) + return builder -> builder.startProfile(CassandraVectorStore.DRIVER_PROFILE_UPDATES) .withString(DefaultDriverOption.REQUEST_CONSISTENCY, "LOCAL_QUORUM") .withDuration(DefaultDriverOption.REQUEST_TIMEOUT, Duration.ofSeconds(1)) .withBoolean(DefaultDriverOption.REQUEST_DEFAULT_IDEMPOTENCE, true) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreProperties.java index ac716cbd7d2..431d3aa73b0 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreProperties.java @@ -109,4 +109,4 @@ public void setVectorDimensions(long vectorDimensions) { this.vectorDimensions = vectorDimensions; } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreProperties.java index 9a17543b5db..3044179d6a0 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/milvus/MilvusVectorStoreProperties.java @@ -138,14 +138,14 @@ public enum MilvusMetricType { /** * Jaccard distance */ - JACCARD; + JACCARD } public enum MilvusIndexType { INVALID, FLAT, IVF_FLAT, IVF_SQ8, IVF_PQ, HNSW, DISKANN, AUTOINDEX, SCANN, GPU_IVF_FLAT, GPU_IVF_PQ, BIN_FLAT, - BIN_IVF_FLAT, TRIE, STL_SORT; + BIN_IVF_FLAT, TRIE, STL_SORT } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreProperties.java index 27fd396c4b1..de1c42f3ba7 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/oracle/OracleVectorStoreProperties.java @@ -20,8 +20,6 @@ import org.springframework.ai.vectorstore.OracleVectorStore; import org.springframework.boot.context.properties.ConfigurationProperties; -import static org.springframework.ai.vectorstore.OracleVectorStore.DEFAULT_SEARCH_ACCURACY; - /** * @author Loïc Lefèvre */ @@ -42,7 +40,7 @@ public class OracleVectorStoreProperties extends CommonVectorStoreProperties { private boolean forcedNormalization; - private int searchAccuracy = DEFAULT_SEARCH_ACCURACY; + private int searchAccuracy = org.springframework.ai.vectorstore.OracleVectorStore.DEFAULT_SEARCH_ACCURACY; public String getTableName() { return this.tableName; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithFunctionBeanIT.java index 0c4e933c1c1..13d503f9456 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithFunctionBeanIT.java @@ -68,14 +68,14 @@ void functionCallTest() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), AnthropicChatOptions.builder().withFunction("weatherFunction").build())); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); response = chatModel.call(new Prompt(List.of(userMessage), AnthropicChatOptions.builder().withFunction("weatherFunction3").build())); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -98,7 +98,7 @@ void functionCallWithPortableFunctionCallingOptions() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), PortableFunctionCallingOptions.builder().withFunction("weatherFunction").build())); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); }); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithPromptFunctionIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithPromptFunctionIT.java index 9dccd4c50c5..aad0cbc2447 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithPromptFunctionIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithPromptFunctionIT.java @@ -66,7 +66,7 @@ void functionCallTest() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); }); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/MockWeatherService.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/MockWeatherService.java index e27e66300ee..4d39e27f447 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/MockWeatherService.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/MockWeatherService.java @@ -67,7 +67,7 @@ public enum Unit { */ public final String unitName; - private Unit(String text) { + Unit(String text) { this.unitName = text; } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java index c50ce1d4a58..eecd8c0251f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java @@ -185,38 +185,31 @@ void transcribe() { void chatActivation() { // Disable the chat auto-configuration. - this.contextRunner.withPropertyValues("spring.ai.azure.openai.chat.enabled=false").run(context -> { - assertThat(context.getBeansOfType(AzureOpenAiChatModel.class)).isEmpty(); - }); + this.contextRunner.withPropertyValues("spring.ai.azure.openai.chat.enabled=false") + .run(context -> assertThat(context.getBeansOfType(AzureOpenAiChatModel.class)).isEmpty()); // The chat auto-configuration is enabled by default. - this.contextRunner.run(context -> { - assertThat(context.getBeansOfType(AzureOpenAiChatModel.class)).isNotEmpty(); - }); + this.contextRunner.run(context -> assertThat(context.getBeansOfType(AzureOpenAiChatModel.class)).isNotEmpty()); // Explicitly enable the chat auto-configuration. - this.contextRunner.withPropertyValues("spring.ai.azure.openai.chat.enabled=true").run(context -> { - assertThat(context.getBeansOfType(AzureOpenAiChatModel.class)).isNotEmpty(); - }); + this.contextRunner.withPropertyValues("spring.ai.azure.openai.chat.enabled=true") + .run(context -> assertThat(context.getBeansOfType(AzureOpenAiChatModel.class)).isNotEmpty()); } @Test void embeddingActivation() { // Disable the embedding auto-configuration. - this.contextRunner.withPropertyValues("spring.ai.azure.openai.embedding.enabled=false").run(context -> { - assertThat(context.getBeansOfType(AzureOpenAiEmbeddingModel.class)).isEmpty(); - }); + this.contextRunner.withPropertyValues("spring.ai.azure.openai.embedding.enabled=false") + .run(context -> assertThat(context.getBeansOfType(AzureOpenAiEmbeddingModel.class)).isEmpty()); // The embedding auto-configuration is enabled by default. - this.contextRunner.run(context -> { - assertThat(context.getBeansOfType(AzureOpenAiEmbeddingModel.class)).isNotEmpty(); - }); + this.contextRunner + .run(context -> assertThat(context.getBeansOfType(AzureOpenAiEmbeddingModel.class)).isNotEmpty()); // Explicitly enable the embedding auto-configuration. - this.contextRunner.withPropertyValues("spring.ai.azure.openai.embedding.enabled=true").run(context -> { - assertThat(context.getBeansOfType(AzureOpenAiEmbeddingModel.class)).isNotEmpty(); - }); + this.contextRunner.withPropertyValues("spring.ai.azure.openai.embedding.enabled=true") + .run(context -> assertThat(context.getBeansOfType(AzureOpenAiEmbeddingModel.class)).isNotEmpty()); } @Test @@ -224,20 +217,15 @@ void audioTranscriptionActivation() { // Disable the transcription auto-configuration. this.contextRunner.withPropertyValues("spring.ai.azure.openai.audio.transcription.enabled=false") - .run(context -> { - assertThat(context.getBeansOfType(AzureOpenAiAudioTranscriptionModel.class)).isEmpty(); - }); + .run(context -> assertThat(context.getBeansOfType(AzureOpenAiAudioTranscriptionModel.class)).isEmpty()); // The transcription auto-configuration is enabled by default. - this.contextRunner.run(context -> { - assertThat(context.getBeansOfType(AzureOpenAiAudioTranscriptionModel.class)).isNotEmpty(); - }); + this.contextRunner + .run(context -> assertThat(context.getBeansOfType(AzureOpenAiAudioTranscriptionModel.class)).isNotEmpty()); // Explicitly enable the transcription auto-configuration. this.contextRunner.withPropertyValues("spring.ai.azure.openai.audio.transcription.enabled=true") - .run(context -> { - assertThat(context.getBeansOfType(AzureOpenAiAudioTranscriptionModel.class)).isNotEmpty(); - }); + .run(context -> assertThat(context.getBeansOfType(AzureOpenAiAudioTranscriptionModel.class)).isNotEmpty()); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/DeploymentNameUtil.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/DeploymentNameUtil.java index fa2c77b1f5b..435fc9235bb 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/DeploymentNameUtil.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/DeploymentNameUtil.java @@ -18,7 +18,11 @@ import org.springframework.util.StringUtils; -public class DeploymentNameUtil { +public final class DeploymentNameUtil { + + private DeploymentNameUtil() { + + } public static String getDeploymentName() { String deploymentName = System.getenv("AZURE_OPENAI_DEPLOYMENT_NAME"); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionBeanIT.java index dda06c82429..7991bb3e004 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionBeanIT.java @@ -39,7 +39,6 @@ import org.springframework.context.annotation.Description; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.autoconfigure.azure.tool.DeploymentNameUtil.getDeploymentName; @EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_API_KEY", matches = ".+") @EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_ENDPOINT", matches = ".+") @@ -58,7 +57,8 @@ class FunctionCallWithFunctionBeanIT { @Test void functionCallTest() { this.contextRunner - .withPropertyValues("spring.ai.azure.openai.chat.options..deployment-name=" + getDeploymentName()) + .withPropertyValues("spring.ai.azure.openai.chat.options..deployment-name=" + + org.springframework.ai.autoconfigure.azure.tool.DeploymentNameUtil.getDeploymentName()) .run(context -> { ChatModel chatModel = context.getBean(AzureOpenAiChatModel.class); @@ -69,14 +69,14 @@ void functionCallTest() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), AzureOpenAiChatOptions.builder().withFunction("weatherFunction").build())); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); response = chatModel.call(new Prompt(List.of(userMessage), AzureOpenAiChatOptions.builder().withFunction("weatherFunction3").build())); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -86,7 +86,8 @@ void functionCallTest() { @Test void functionCallWithPortableFunctionCallingOptions() { this.contextRunner - .withPropertyValues("spring.ai.azure.openai.chat.options..deployment-name=" + getDeploymentName()) + .withPropertyValues("spring.ai.azure.openai.chat.options..deployment-name=" + + org.springframework.ai.autoconfigure.azure.tool.DeploymentNameUtil.getDeploymentName()) .run(context -> { ChatModel chatModel = context.getBean(AzureOpenAiChatModel.class); @@ -97,7 +98,7 @@ void functionCallWithPortableFunctionCallingOptions() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), PortableFunctionCallingOptions.builder().withFunction("weatherFunction").build())); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionWrapperIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionWrapperIT.java index 62071ed4168..178a61d15db 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionWrapperIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionWrapperIT.java @@ -37,7 +37,6 @@ import org.springframework.context.annotation.Configuration; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.autoconfigure.azure.tool.DeploymentNameUtil.getDeploymentName; @EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_API_KEY", matches = ".+") @EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_ENDPOINT", matches = ".+") @@ -56,7 +55,8 @@ public class FunctionCallWithFunctionWrapperIT { @Test void functionCallTest() { this.contextRunner - .withPropertyValues("spring.ai.azure.openai.chat.options.deployment-name=" + getDeploymentName()) + .withPropertyValues("spring.ai.azure.openai.chat.options.deployment-name=" + + org.springframework.ai.autoconfigure.azure.tool.DeploymentNameUtil.getDeploymentName()) .run(context -> { AzureOpenAiChatModel chatModel = context.getBean(AzureOpenAiChatModel.class); @@ -67,7 +67,7 @@ void functionCallTest() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), AzureOpenAiChatOptions.builder().withFunction("WeatherInfo").build())); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("30", "10", "15"); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithPromptFunctionIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithPromptFunctionIT.java index 00a9145354f..7d5a7b913d6 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithPromptFunctionIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithPromptFunctionIT.java @@ -34,7 +34,6 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.autoconfigure.azure.tool.DeploymentNameUtil.getDeploymentName; @EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_API_KEY", matches = ".+") @EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_ENDPOINT", matches = ".+") @@ -52,7 +51,8 @@ public class FunctionCallWithPromptFunctionIT { @Test void functionCallTest() { this.contextRunner - .withPropertyValues("spring.ai.azure.openai.chat.options.deployment-name=" + getDeploymentName()) + .withPropertyValues("spring.ai.azure.openai.chat.options.deployment-name=" + + org.springframework.ai.autoconfigure.azure.tool.DeploymentNameUtil.getDeploymentName()) .run(context -> { AzureOpenAiChatModel chatModel = context.getBean(AzureOpenAiChatModel.class); @@ -69,7 +69,7 @@ void functionCallTest() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); }); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/MockWeatherService.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/MockWeatherService.java index 0d390e57ef0..a6d2658ea0d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/MockWeatherService.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/MockWeatherService.java @@ -67,7 +67,7 @@ public enum Unit { */ public final String unitName; - private Unit(String text) { + Unit(String text) { this.unitName = text; } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfigurationIT.java index e3d9b2ff6eb..99ed8e6c849 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/BedrockAwsConnectionConfigurationIT.java @@ -48,7 +48,7 @@ public void autoConfigureAWSCredentialAndRegionProvider() { "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"), "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id()) .withConfiguration(AutoConfigurations.of(TestAutoConfiguration.class)) - .run((context) -> { + .run(context -> { var awsCredentialsProvider = context.getBean(AwsCredentialsProvider.class); var awsRegionProvider = context.getBean(AwsRegionProvider.class); @@ -72,7 +72,7 @@ public void autoConfigureWithCustomAWSCredentialAndRegionProvider() { "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id()) .withConfiguration(AutoConfigurations.of(TestAutoConfiguration.class, CustomAwsCredentialsProviderAndAwsRegionProviderAutoConfiguration.class)) - .run((context) -> { + .run(context -> { var awsCredentialsProvider = context.getBean(AwsCredentialsProvider.class); var awsRegionProvider = context.getBean(AwsRegionProvider.class); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/client/ChatClientAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/client/ChatClientAutoConfigurationIT.java index 4dd65c789bc..ae04bbd8659 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/client/ChatClientAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/client/ChatClientAutoConfigurationIT.java @@ -51,23 +51,19 @@ public class ChatClientAutoConfigurationIT { @Test void implicitlyEnabled() { - this.contextRunner.run(context -> { - assertThat(context.getBeansOfType(ChatClient.Builder.class)).isNotEmpty(); - }); + this.contextRunner.run(context -> assertThat(context.getBeansOfType(ChatClient.Builder.class)).isNotEmpty()); } @Test void explicitlyEnabled() { - this.contextRunner.withPropertyValues("spring.ai.chat.client.enabled=true").run(context -> { - assertThat(context.getBeansOfType(ChatClient.Builder.class)).isNotEmpty(); - }); + this.contextRunner.withPropertyValues("spring.ai.chat.client.enabled=true") + .run(context -> assertThat(context.getBeansOfType(ChatClient.Builder.class)).isNotEmpty()); } @Test void explicitlyDisabled() { - this.contextRunner.withPropertyValues("spring.ai.chat.client.enabled=false").run(context -> { - assertThat(context.getBeansOfType(ChatClient.Builder.class)).isEmpty(); - }); + this.contextRunner.withPropertyValues("spring.ai.chat.client.enabled=false") + .run(context -> assertThat(context.getBeansOfType(ChatClient.Builder.class)).isEmpty()); } @Test diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/client/ChatClientObservationAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/client/ChatClientObservationAutoConfigurationTests.java index 658a758b05c..1a97de7c756 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/client/ChatClientObservationAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/client/ChatClientObservationAutoConfigurationTests.java @@ -36,16 +36,14 @@ class ChatClientObservationAutoConfigurationTests { @Test void inputContentFilterDefault() { - this.contextRunner.run(context -> { - assertThat(context).doesNotHaveBean(ChatClientInputContentObservationFilter.class); - }); + this.contextRunner + .run(context -> assertThat(context).doesNotHaveBean(ChatClientInputContentObservationFilter.class)); } @Test void inputContentFilterEnabled() { - this.contextRunner.withPropertyValues("spring.ai.chat.client.observations.include-input=true").run(context -> { - assertThat(context).hasSingleBean(ChatClientInputContentObservationFilter.class); - }); + this.contextRunner.withPropertyValues("spring.ai.chat.client.observations.include-input=true") + .run(context -> assertThat(context).hasSingleBean(ChatClientInputContentObservationFilter.class)); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/observation/ChatObservationAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/observation/ChatObservationAutoConfigurationTests.java index cafd64873fc..d1601098168 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/observation/ChatObservationAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/observation/ChatObservationAutoConfigurationTests.java @@ -44,30 +44,25 @@ class ChatObservationAutoConfigurationTests { @Test void meterObservationHandlerEnabled() { - this.contextRunner.withBean(CompositeMeterRegistry.class).run(context -> { - assertThat(context).hasSingleBean(ChatModelMeterObservationHandler.class); - }); + this.contextRunner.withBean(CompositeMeterRegistry.class) + .run(context -> assertThat(context).hasSingleBean(ChatModelMeterObservationHandler.class)); } @Test void meterObservationHandlerDisabled() { - this.contextRunner.run(context -> { - assertThat(context).doesNotHaveBean(ChatModelMeterObservationHandler.class); - }); + this.contextRunner.run(context -> assertThat(context).doesNotHaveBean(ChatModelMeterObservationHandler.class)); } @Test void promptFilterDefault() { - this.contextRunner.run(context -> { - assertThat(context).doesNotHaveBean(ChatModelPromptContentObservationFilter.class); - }); + this.contextRunner + .run(context -> assertThat(context).doesNotHaveBean(ChatModelPromptContentObservationFilter.class)); } @Test void promptHandlerDefault() { - this.contextRunner.run(context -> { - assertThat(context).doesNotHaveBean(ChatModelPromptContentObservationHandler.class); - }); + this.contextRunner + .run(context -> assertThat(context).doesNotHaveBean(ChatModelPromptContentObservationHandler.class)); } @Test @@ -75,30 +70,25 @@ void promptHandlerEnabled() { this.contextRunner .withBean(OtelTracer.class, OpenTelemetry.noop().getTracer("test"), new OtelCurrentTraceContext(), null) .withPropertyValues("spring.ai.chat.observations.include-prompt=true") - .run(context -> { - assertThat(context).hasSingleBean(ChatModelPromptContentObservationHandler.class); - }); + .run(context -> assertThat(context).hasSingleBean(ChatModelPromptContentObservationHandler.class)); } @Test void promptHandlerDisabled() { - this.contextRunner.withPropertyValues("spring.ai.chat.observations.include-prompt=true").run(context -> { - assertThat(context).doesNotHaveBean(ChatModelPromptContentObservationHandler.class); - }); + this.contextRunner.withPropertyValues("spring.ai.chat.observations.include-prompt=true") + .run(context -> assertThat(context).doesNotHaveBean(ChatModelPromptContentObservationHandler.class)); } @Test void completionFilterDefault() { - this.contextRunner.run(context -> { - assertThat(context).doesNotHaveBean(ChatModelCompletionObservationFilter.class); - }); + this.contextRunner + .run(context -> assertThat(context).doesNotHaveBean(ChatModelCompletionObservationFilter.class)); } @Test void completionHandlerDefault() { - this.contextRunner.run(context -> { - assertThat(context).doesNotHaveBean(ChatModelCompletionObservationHandler.class); - }); + this.contextRunner + .run(context -> assertThat(context).doesNotHaveBean(ChatModelCompletionObservationHandler.class)); } @Test @@ -106,16 +96,13 @@ void completionHandlerEnabled() { this.contextRunner .withBean(OtelTracer.class, OpenTelemetry.noop().getTracer("test"), new OtelCurrentTraceContext(), null) .withPropertyValues("spring.ai.chat.observations.include-completion=true") - .run(context -> { - assertThat(context).hasSingleBean(ChatModelCompletionObservationHandler.class); - }); + .run(context -> assertThat(context).hasSingleBean(ChatModelCompletionObservationHandler.class)); } @Test void completionHandlerDisabled() { - this.contextRunner.withPropertyValues("spring.ai.chat.observations.include-completion=true").run(context -> { - assertThat(context).doesNotHaveBean(ChatModelCompletionObservationHandler.class); - }); + this.contextRunner.withPropertyValues("spring.ai.chat.observations.include-completion=true") + .run(context -> assertThat(context).doesNotHaveBean(ChatModelCompletionObservationHandler.class)); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/embedding/observation/EmbeddingObservationAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/embedding/observation/EmbeddingObservationAutoConfigurationTests.java index ad19103371b..42037c9e064 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/embedding/observation/EmbeddingObservationAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/embedding/observation/EmbeddingObservationAutoConfigurationTests.java @@ -37,16 +37,14 @@ class EmbeddingObservationAutoConfigurationTests { @Test void meterObservationHandlerEnabled() { - this.contextRunner.withBean(CompositeMeterRegistry.class).run(context -> { - assertThat(context).hasSingleBean(EmbeddingModelMeterObservationHandler.class); - }); + this.contextRunner.withBean(CompositeMeterRegistry.class) + .run(context -> assertThat(context).hasSingleBean(EmbeddingModelMeterObservationHandler.class)); } @Test void meterObservationHandlerDisabled() { - this.contextRunner.run(context -> { - assertThat(context).doesNotHaveBean(EmbeddingModelMeterObservationHandler.class); - }); + this.contextRunner + .run(context -> assertThat(context).doesNotHaveBean(EmbeddingModelMeterObservationHandler.class)); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/image/observation/ImageObservationAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/image/observation/ImageObservationAutoConfigurationTests.java index b4bd232039a..deb0a22ade3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/image/observation/ImageObservationAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/image/observation/ImageObservationAutoConfigurationTests.java @@ -36,16 +36,14 @@ class ImageObservationAutoConfigurationTests { @Test void promptFilterDefault() { - this.contextRunner.run(context -> { - assertThat(context).doesNotHaveBean(ImageModelPromptContentObservationFilter.class); - }); + this.contextRunner + .run(context -> assertThat(context).doesNotHaveBean(ImageModelPromptContentObservationFilter.class)); } @Test void promptFilterEnabled() { - this.contextRunner.withPropertyValues("spring.ai.image.observations.include-prompt=true").run(context -> { - assertThat(context).hasSingleBean(ImageModelPromptContentObservationFilter.class); - }); + this.contextRunner.withPropertyValues("spring.ai.image.observations.include-prompt=true") + .run(context -> assertThat(context).hasSingleBean(ImageModelPromptContentObservationFilter.class)); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackInPromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackInPromptIT.java index 167102cb8c7..a53c1fa59ed 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackInPromptIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackInPromptIT.java @@ -66,13 +66,13 @@ void functionCallTest() { .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("CurrentWeatherService") .withDescription("Get the weather in location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build())) .build(); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); }); @@ -92,7 +92,7 @@ void streamingFunctionCallTest() { .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("CurrentWeatherService") .withDescription("Get the weather in location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build())) .build(); @@ -106,7 +106,7 @@ void streamingFunctionCallTest() { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - this.logger.info("Response: {}", content); + logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackWithPlainFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackWithPlainFunctionBeanIT.java index 5b33d4c673b..3cd864d164d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackWithPlainFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackWithPlainFunctionBeanIT.java @@ -73,7 +73,7 @@ void functionCallTest() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), MiniMaxChatOptions.builder().withFunction("weatherFunction").build())); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -81,7 +81,7 @@ void functionCallTest() { response = chatModel.call(new Prompt(List.of(userMessage), MiniMaxChatOptions.builder().withFunction("weatherFunctionTwo").build())); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -104,7 +104,7 @@ void functionCallWithPortableFunctionCallingOptions() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); }); } @@ -130,7 +130,7 @@ void streamFunctionCallTest() { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - this.logger.info("Response: {}", content); + logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -148,7 +148,7 @@ void streamFunctionCallTest() { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - this.logger.info("Response: {}", content); + logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackWrapperIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackWrapperIT.java index 780476ea2fe..28d9eb6c689 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackWrapperIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackWrapperIT.java @@ -69,7 +69,7 @@ void functionCallTest() { ChatResponse response = chatModel.call( new Prompt(List.of(userMessage), MiniMaxChatOptions.builder().withFunction("WeatherInfo").build())); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -96,7 +96,7 @@ void streamFunctionCallTest() { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - this.logger.info("Response: {}", content); + logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -114,7 +114,7 @@ public FunctionCallback weatherFunctionInfo() { return FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("WeatherInfo") .withDescription("Get the weather in location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build(); } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/MiniMaxAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/MiniMaxAutoConfigurationIT.java index 06e8c2b3ec7..ec0ed5c8100 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/MiniMaxAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/MiniMaxAutoConfigurationIT.java @@ -66,9 +66,11 @@ void generateStreaming() { this.contextRunner.run(context -> { MiniMaxChatModel chatModel = context.getBean(MiniMaxChatModel.class); Flux responseFlux = chatModel.stream(new Prompt(new UserMessage("Hello"))); - String response = responseFlux.collectList().block().stream().map(chatResponse -> { - return chatResponse.getResults().get(0).getOutput().getContent(); - }).collect(Collectors.joining()); + String response = responseFlux.collectList() + .block() + .stream() + .map(chatResponse -> chatResponse.getResults().get(0).getOutput().getContent()) + .collect(Collectors.joining()); assertThat(response).isNotEmpty(); logger.info("Response: " + response); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/MockWeatherService.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/MockWeatherService.java index 61a5394dba8..bb0ef923c79 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/MockWeatherService.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/MockWeatherService.java @@ -67,7 +67,7 @@ public enum Unit { */ public final String unitName; - private Unit(String text) { + Unit(String text) { this.unitName = text; } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfigurationIT.java index 441f88450ac..68c80242ab1 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfigurationIT.java @@ -64,9 +64,11 @@ void generateStreaming() { this.contextRunner.run(context -> { MistralAiChatModel chatModel = context.getBean(MistralAiChatModel.class); Flux responseFlux = chatModel.stream(new Prompt(new UserMessage("Hello"))); - String response = responseFlux.collectList().block().stream().map(chatResponse -> { - return chatResponse.getResults().get(0).getOutput().getContent(); - }).collect(Collectors.joining()); + String response = responseFlux.collectList() + .block() + .stream() + .map(chatResponse -> chatResponse.getResults().get(0).getOutput().getContent()) + .collect(Collectors.joining()); assertThat(response).isNotEmpty(); logger.info("Response: " + response); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanIT.java index caa18b90b3d..965b9f56215 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanIT.java @@ -72,7 +72,7 @@ void functionCallTest() { .withFunction("retrievePaymentDate") .build())); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).containsIgnoringCase("T1001"); assertThat(response.getResult().getOutput().getContent()).containsIgnoringCase("paid"); @@ -89,13 +89,13 @@ static class Config { @Bean @Description("Get payment status of a transaction") public Function retrievePaymentStatus() { - return (transaction) -> new Status(DATA.get(transaction.transactionId).status()); + return transaction -> new Status(DATA.get(transaction.transactionId).status()); } @Bean @Description("Get payment date of a transaction") public Function retrievePaymentDate() { - return (transaction) -> new Date(DATA.get(transaction.transactionId).date()); + return transaction -> new Date(DATA.get(transaction.transactionId).date()); } public record Transaction(@JsonProperty(required = true, value = "transaction_id") String transactionId) { diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanOpenAiIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanOpenAiIT.java index 5a428b91b3d..08e84427980 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanOpenAiIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanOpenAiIT.java @@ -79,7 +79,7 @@ void functionCallTest() { .withFunction("retrievePaymentDate") .build())); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).containsIgnoringCase("T1001"); assertThat(response.getResult().getOutput().getContent()).containsIgnoringCase("paid"); @@ -96,13 +96,13 @@ static class Config { @Bean @Description("Get payment status of a transaction") public Function retrievePaymentStatus() { - return (transaction) -> new Status(DATA.get(transaction.transactionId).status()); + return transaction -> new Status(DATA.get(transaction.transactionId).status()); } @Bean @Description("Get payment date of a transaction") public Function retrievePaymentDate() { - return (transaction) -> new Date(DATA.get(transaction.transactionId).date()); + return transaction -> new Date(DATA.get(transaction.transactionId).date()); } public record Transaction(@JsonProperty(required = true, value = "transaction_id") String transactionId) { diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusPromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusPromptIT.java index 2efdad49317..187d4da6604 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusPromptIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusPromptIT.java @@ -78,7 +78,7 @@ public Status apply(Transaction transaction) { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).containsIgnoringCase("T1001"); assertThat(response.getResult().getOutput().getContent()).containsIgnoringCase("paid"); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/WeatherServicePromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/WeatherServicePromptIT.java index f546eb8bbd1..417a2b6a9e6 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/WeatherServicePromptIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/WeatherServicePromptIT.java @@ -81,7 +81,7 @@ void promptFunctionCall() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("15", "15.0"); // assertThat(response.getResult().getOutput().getContent()).contains("30.0", @@ -109,7 +109,7 @@ void functionCallWithPortableFunctionCallingOptions() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).containsAnyOf("15", "15.0"); }); @@ -137,7 +137,7 @@ public enum Unit { C, F } @JsonInclude(Include.NON_NULL) public record Request( @JsonProperty(required = true, value = "location") String location, - @JsonProperty(required = true, value = "unit") Unit unit) {} + @JsonProperty(required = true, value = "unit") Unit unit) { } // @formatter:on public record Response(double temperature, Unit unit) { diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackInPromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackInPromptIT.java index 2853c4a461b..574057e3e73 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackInPromptIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackInPromptIT.java @@ -67,13 +67,13 @@ void functionCallTest() { .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("CurrentWeatherService") .withDescription("Get the weather in location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build())) .build(); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); }); @@ -93,7 +93,7 @@ void streamingFunctionCallTest() { .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("CurrentWeatherService") .withDescription("Get the weather in location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build())) .build(); @@ -107,7 +107,7 @@ void streamingFunctionCallTest() { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - this.logger.info("Response: {}", content); + logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackWithPlainFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackWithPlainFunctionBeanIT.java index e94be42200d..ba1d4d41604 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackWithPlainFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackWithPlainFunctionBeanIT.java @@ -73,7 +73,7 @@ void functionCallTest() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), MoonshotChatOptions.builder().withFunction("weatherFunction").build())); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -81,7 +81,7 @@ void functionCallTest() { response = chatModel.call(new Prompt(List.of(userMessage), MoonshotChatOptions.builder().withFunction("weatherFunctionTwo").build())); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -104,7 +104,7 @@ void functionCallWithPortableFunctionCallingOptions() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); }); } @@ -129,7 +129,7 @@ void streamFunctionCallTest() { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - this.logger.info("Response: {}", content); + logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -147,7 +147,7 @@ void streamFunctionCallTest() { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - this.logger.info("Response: {}", content); + logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackWrapperIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackWrapperIT.java index 9de829cc77c..c97a2aee236 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackWrapperIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackWrapperIT.java @@ -71,7 +71,7 @@ void functionCallTest() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), MoonshotChatOptions.builder().withFunction("WeatherInfo").build())); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -99,7 +99,7 @@ void streamFunctionCallTest() { .map(AssistantMessage::getContent) .filter(Objects::nonNull) .collect(Collectors.joining()); - this.logger.info("Response: {}", content); + logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -117,7 +117,7 @@ public FunctionCallback weatherFunctionInfo() { return FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("WeatherInfo") .withDescription("Get the weather in location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build(); } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/MockWeatherService.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/MockWeatherService.java index 3d8e96ba6e4..935464dca4e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/MockWeatherService.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/MockWeatherService.java @@ -67,7 +67,7 @@ public enum Unit { */ public final String unitName; - private Unit(String text) { + Unit(String text) { this.unitName = text; } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/BaseOllamaIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/BaseOllamaIT.java index 3af059f33a4..3472a8e0c18 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/BaseOllamaIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/BaseOllamaIT.java @@ -16,10 +16,10 @@ package org.springframework.ai.autoconfigure.ollama; -import org.testcontainers.ollama.OllamaContainer; - import java.time.Duration; +import org.testcontainers.ollama.OllamaContainer; + import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.management.ModelManagementOptions; import org.springframework.ai.ollama.management.OllamaModelManager; @@ -42,7 +42,7 @@ public class BaseOllamaIT { * * to the file ".testcontainers.properties" located in your home directory */ - public static boolean isDisabled() { + public boolean isDisabled() { return false; } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaImage.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaImage.java index ebabcc72217..75d243ecd8c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaImage.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaImage.java @@ -16,8 +16,12 @@ package org.springframework.ai.autoconfigure.ollama; -public class OllamaImage { +public final class OllamaImage { public static final String IMAGE = "ollama/ollama:0.3.14"; + private OllamaImage() { + + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackInPromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackInPromptIT.java index 53f2973cc5e..57c38896279 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackInPromptIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackInPromptIT.java @@ -81,7 +81,7 @@ void functionCallTest() { .withName("CurrentWeatherService") .withDescription( "Find the weather conditions, forecasts, and temperatures for a location, like a city or state.") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build())) .build(); @@ -108,7 +108,7 @@ void streamingFunctionCallTest() { .withName("CurrentWeatherService") .withDescription( "Find the weather conditions, forecasts, and temperatures for a location, like a city or state.") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build())) .build(); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackWrapperIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackWrapperIT.java index 451a970cb75..ce686de1299 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackWrapperIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackWrapperIT.java @@ -150,7 +150,7 @@ public FunctionCallback weatherFunctionInfo() { .withName("WeatherInfo") .withDescription( "Find the weather conditions, forecasts, and temperatures for a location, like a city or state.") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build(); } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/MockWeatherService.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/MockWeatherService.java index e4a1487c2bd..88154dd3e73 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/MockWeatherService.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/MockWeatherService.java @@ -67,7 +67,7 @@ public enum Unit { */ public final String unitName; - private Unit(String text) { + Unit(String text) { this.unitName = text; } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java index e02e4244276..821a6be49b7 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java @@ -106,9 +106,11 @@ void generateStreaming() { this.contextRunner.run(context -> { OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); Flux responseFlux = chatModel.stream(new Prompt(new UserMessage("Hello"))); - String response = responseFlux.collectList().block().stream().map(chatResponse -> { - return chatResponse.getResults().get(0).getOutput().getContent(); - }).collect(Collectors.joining()); + String response = responseFlux.collectList() + .block() + .stream() + .map(chatResponse -> chatResponse.getResults().get(0).getOutput().getContent()) + .collect(Collectors.joining()); assertThat(response).isNotEmpty(); logger.info("Response: " + response); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiResponseFormatPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiResponseFormatPropertiesTests.java index 03d6af56cc8..d3c3c52ee0b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiResponseFormatPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiResponseFormatPropertiesTests.java @@ -42,15 +42,15 @@ public void responseFormatJsonSchema() { String responseFormatJsonSchema = """ { - "$schema" : "https://json-schema.org/draft/2020-12/schema", - "type" : "object", - "properties" : { - "someString" : { - "type" : "string" - } - }, - "additionalProperties" : false - } + "$schema" : "https://json-schema.org/draft/2020-12/schema", + "type" : "object", + "properties" : { + "someString" : { + "type" : "string" + } + }, + "additionalProperties" : false + } """; new ApplicationContextRunner().withPropertyValues( diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java index 538c8456abb..b6e2c3024e0 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java @@ -62,7 +62,7 @@ void functionCallTest() { .call().content(); // @formatter:on - this.logger.info("Response: {}", content); + logger.info("Response: {}", content); assertThat(content).contains("30", "10", "15"); }); @@ -87,7 +87,7 @@ public String apply(MockWeatherService.Request request) { }) .call().content(); // @formatter:on - this.logger.info("Response: {}", content); + logger.info("Response: {}", content); assertThat(content).contains("18"); }); @@ -109,7 +109,7 @@ void streamingFunctionCallTest() { .collectList().block().stream().collect(Collectors.joining()); // @formatter:on - this.logger.info("Response: {}", content); + logger.info("Response: {}", content); assertThat(content).contains("30", "10", "15"); }); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPromptIT.java index 4de98c17762..53bf98c610b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPromptIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPromptIT.java @@ -65,13 +65,13 @@ void functionCallTest() { .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("CurrentWeatherService") .withDescription("Get the weather in location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build())) .build(); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); }); @@ -94,7 +94,7 @@ void streamingFunctionCallTest() { .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("CurrentWeatherService") .withDescription("Get the weather in location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build())) .build(); @@ -108,7 +108,7 @@ void streamingFunctionCallTest() { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - this.logger.info("Response: {}", content); + logger.info("Response: {}", content); assertThat(content).contains("30", "10", "15"); }); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java index c4b8438214a..6408edc1b9a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.openai.tool; import java.util.List; @@ -60,7 +61,7 @@ class FunctionCallbackWithPlainFunctionBeanIT { @Test void functionCallWithDirectBiFunction() { - contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) + this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) .run(context -> { OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); @@ -93,7 +94,7 @@ void functionCallWithDirectBiFunction() { @Test void functionCallWithBiFunctionClass() { - contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) + this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) .run(context -> { OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); @@ -126,7 +127,7 @@ void functionCallWithBiFunctionClass() { @Test void functionCallTest() { - contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) + this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) .run(context -> { OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); @@ -155,7 +156,7 @@ void functionCallTest() { @Test void functionCallWithPortableFunctionCallingOptions() { - contextRunner + this.contextRunner .withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName(), "spring.ai.openai.chat.options.temperature=0.1") .run(context -> { @@ -180,7 +181,7 @@ void functionCallWithPortableFunctionCallingOptions() { @Test void streamFunctionCallTest() { - contextRunner + this.contextRunner .withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName(), "spring.ai.openai.chat.options.temperature=0.1") .run(context -> { @@ -238,9 +239,7 @@ public MyBiFunction weatherFunctionWithClassBiFunction() { @Bean @Description("Get the weather in location") public BiFunction weatherFunctionWithContext() { - return (request, context) -> { - return new MockWeatherService().apply(request); - }; + return (request, context) -> new MockWeatherService().apply(request); } @Bean diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapper2IT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapper2IT.java index 0056fc20ba2..062ce7f7671 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapper2IT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapper2IT.java @@ -64,7 +64,7 @@ void functionCallTest() { .call().content(); // @formatter:on - this.logger.info("Response: {}", content); + logger.info("Response: {}", content); assertThat(content).contains("30", "10", "15"); }); @@ -84,7 +84,7 @@ void streamFunctionCallTest() { .collectList().block().stream().collect(Collectors.joining()); // @formatter:on - this.logger.info("Response: {}", content); + logger.info("Response: {}", content); assertThat(content).contains("30", "10", "15"); }); @@ -99,7 +99,7 @@ public FunctionCallback weatherFunctionInfo() { return FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("WeatherInfo") .withDescription("Get the weather in location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build(); } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapperIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapperIT.java index 01a1bd1a93e..b438cd42a50 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapperIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapperIT.java @@ -65,7 +65,7 @@ void functionCallTest() { ChatResponse response = chatModel.call( new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withFunction("WeatherInfo").build())); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -92,7 +92,7 @@ void streamFunctionCallTest() { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - this.logger.info("Response: {}", content); + logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -110,7 +110,7 @@ public FunctionCallback weatherFunctionInfo() { return FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("WeatherInfo") .withDescription("Get the weather in location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build(); } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/MockWeatherService.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/MockWeatherService.java index f0026ca9f6e..489942fcc20 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/MockWeatherService.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/MockWeatherService.java @@ -67,7 +67,7 @@ public enum Unit { */ public final String unitName; - private Unit(String text) { + Unit(String text) { this.unitName = text; } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/qianfan/QianFanAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/qianfan/QianFanAutoConfigurationIT.java index 002a5f578bb..9e5b082c3d2 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/qianfan/QianFanAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/qianfan/QianFanAutoConfigurationIT.java @@ -46,7 +46,7 @@ /** * @author Geng Rong */ -@EnabledIfEnvironmentVariables(value = { @EnabledIfEnvironmentVariable(named = "QIANFAN_API_KEY", matches = ".+"), +@EnabledIfEnvironmentVariables({ @EnabledIfEnvironmentVariable(named = "QIANFAN_API_KEY", matches = ".+"), @EnabledIfEnvironmentVariable(named = "QIANFAN_SECRET_KEY", matches = ".+") }) public class QianFanAutoConfigurationIT { diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfigurationIT.java index 64759b68091..0c8dba03b7f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfigurationIT.java @@ -36,7 +36,7 @@ public class SpringAiRetryAutoConfigurationIT { @Test void testRetryAutoConfiguration() { - this.contextRunner.run((context) -> { + this.contextRunner.run(context -> { assertThat(context).hasSingleBean(RetryTemplate.class); assertThat(context).hasSingleBean(ResponseErrorHandler.class); }); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryPropertiesTests.java index c3baab91cb6..3f20e19f8db 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryPropertiesTests.java @@ -59,7 +59,7 @@ public void retryCustomProperties() { "spring.ai.retry.on-http-codes=429", "spring.ai.retry.backoff.initial-interval=1000", "spring.ai.retry.backoff.multiplier=2", - "spring.ai.retry.backoff.max-interval=60000" ) + "spring.ai.retry.backoff.max-interval=60000") // @formatter:on .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class)) .run(context -> { diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImagePropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImagePropertiesTests.java index c267dd765ac..606dbace005 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImagePropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/stabilityai/StabilityAiImagePropertiesTests.java @@ -35,7 +35,7 @@ public void chatPropertiesTest() { new ApplicationContextRunner().withPropertyValues( // @formatter:off - "spring.ai.stabilityai.image.api-key=API_KEY", + "spring.ai.stabilityai.image.api-key=API_KEY", "spring.ai.stabilityai.image.base-url=ENDPOINT", "spring.ai.stabilityai.image.options.n=10", "spring.ai.stabilityai.image.options.model=MODEL_XYZ", diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreAutoConfigurationIT.java index 1590d750d3b..8fe308bee76 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/azure/AzureVectorStoreAutoConfigurationIT.java @@ -45,7 +45,6 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.Matchers.hasSize; -import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; /** * @author Christian Tzolov @@ -109,12 +108,12 @@ public void addAndSearchTest() { vectorStore.add(this.documents); - Awaitility.await().until(() -> { - return vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); - }, hasSize(1)); + Awaitility.await() + .until(() -> vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)), hasSize(1)); - assertObservationRegistry(observationRegistry, VectorStoreProvider.AZURE, - VectorStoreObservationContext.Operation.ADD); + org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil + .assertObservationRegistry(observationRegistry, VectorStoreProvider.AZURE, + VectorStoreObservationContext.Operation.ADD); observationRegistry.clear(); List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); @@ -127,19 +126,20 @@ public void addAndSearchTest() { assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKeys("spring", "distance"); - assertObservationRegistry(observationRegistry, VectorStoreProvider.AZURE, - VectorStoreObservationContext.Operation.QUERY); + org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil + .assertObservationRegistry(observationRegistry, VectorStoreProvider.AZURE, + VectorStoreObservationContext.Operation.QUERY); observationRegistry.clear(); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); - Awaitility.await().until(() -> { - return vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); - }, hasSize(0)); + Awaitility.await() + .until(() -> vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)), hasSize(0)); - assertObservationRegistry(observationRegistry, VectorStoreProvider.AZURE, - VectorStoreObservationContext.Operation.DELETE); + org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil + .assertObservationRegistry(observationRegistry, VectorStoreProvider.AZURE, + VectorStoreObservationContext.Operation.DELETE); observationRegistry.clear(); }); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfigurationIT.java index a22b26c30cf..481488d0a30 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfigurationIT.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.vectorstore.chroma; import java.util.List; @@ -41,8 +42,7 @@ import org.springframework.context.annotation.Configuration; import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.Assert.assertThrows; -import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; /** * @author Christian Tzolov @@ -66,7 +66,7 @@ public class ChromaVectorStoreAutoConfigurationIT { @Test public void addAndSearchWithFilters() { - contextRunner.withPropertyValues("spring.ai.vectorstore.chroma.initializeSchema=true").run(context -> { + this.contextRunner.withPropertyValues("spring.ai.vectorstore.chroma.initializeSchema=true").run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); TestObservationRegistry observationRegistry = context.getBean(TestObservationRegistry.class); @@ -78,8 +78,8 @@ public void addAndSearchWithFilters() { vectorStore.add(List.of(bgDocument, nlDocument)); - assertObservationRegistry(observationRegistry, VectorStoreProvider.CHROMA, - VectorStoreObservationContext.Operation.ADD); + org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry( + observationRegistry, VectorStoreProvider.CHROMA, VectorStoreObservationContext.Operation.ADD); observationRegistry.clear(); var request = SearchRequest.query("The World").withTopK(5); @@ -128,11 +128,11 @@ public void addAndSearchWithFilters() { @Test public void throwExceptionOnMissingCollectionAndDisabledInitializedSchema() { - contextRunner.withPropertyValues("spring.ai.vectorstore.chroma.initializeSchema=false").run(context -> { - assertThrows( - "Collection TestCollection doesn't exist and won't be created as the initializeSchema is set to false.", - java.lang.RuntimeException.class, () -> context.getBean(VectorStore.class)); - }); + this.contextRunner.withPropertyValues("spring.ai.vectorstore.chroma.initializeSchema=false") + .run(context -> assertThatThrownBy(() -> context.getBean(VectorStore.class)) + .isInstanceOf(RuntimeException.class) + .hasMessage( + "Collection TestCollection doesn't exist and won't be created as the initializeSchema is set to false.")); } @Configuration(proxyBeanMethods = false) diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreAutoConfigurationIT.java index 1a8e0920123..82614734b66 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreAutoConfigurationIT.java @@ -64,9 +64,7 @@ public class CosmosDBVectorStoreAutoConfigurationIT { @BeforeEach public void setup() { - this.contextRunner.run(context -> { - this.vectorStore = context.getBean(VectorStore.class); - }); + this.contextRunner.run(context -> this.vectorStore = context.getBean(VectorStore.class)); } @Test diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfigurationIT.java index 54dbc442041..749e7c2cc5f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/gemfire/GemFireVectorStoreAutoConfigurationIT.java @@ -152,9 +152,8 @@ public void addAndSearchTest() { assertObservationRegistry(observationRegistry, VectorStoreProvider.GEMFIRE, VectorStoreObservationContext.Operation.ADD); - Awaitility.await().until(() -> { - return vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); - }, hasSize(1)); + Awaitility.await() + .until(() -> vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)), hasSize(1)); observationRegistry.clear(); List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); @@ -177,9 +176,8 @@ public void addAndSearchTest() { assertObservationRegistry(observationRegistry, VectorStoreProvider.GEMFIRE, VectorStoreObservationContext.Operation.DELETE); - Awaitility.await().until(() -> { - return vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); - }, hasSize(0)); + Awaitility.await() + .until(() -> vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)), hasSize(0)); observationRegistry.clear(); }); } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/observation/ObservationTestUtil.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/observation/ObservationTestUtil.java index b0e4cdafaf1..9ee0a74c09e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/observation/ObservationTestUtil.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/observation/ObservationTestUtil.java @@ -28,7 +28,11 @@ * @since 1.0.0 */ -public class ObservationTestUtil { +public final class ObservationTestUtil { + + private ObservationTestUtil() { + + } public static void assertObservationRegistry(TestObservationRegistry observationRegistry, VectorStoreProvider vectorStoreProvider, VectorStoreObservationContext.Operation operation) { diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/observation/VectorStoreObservationAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/observation/VectorStoreObservationAutoConfigurationTests.java index 29b8a387895..b787ea6e827 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/observation/VectorStoreObservationAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/observation/VectorStoreObservationAutoConfigurationTests.java @@ -40,16 +40,14 @@ class VectorStoreObservationAutoConfigurationTests { @Test void queryResponseFilterDefault() { - this.contextRunner.run(context -> { - assertThat(context).doesNotHaveBean(VectorStoreQueryResponseObservationFilter.class); - }); + this.contextRunner + .run(context -> assertThat(context).doesNotHaveBean(VectorStoreQueryResponseObservationFilter.class)); } @Test void queryResponseHandlerDefault() { - this.contextRunner.run(context -> { - assertThat(context).doesNotHaveBean(VectorStoreQueryResponseObservationHandler.class); - }); + this.contextRunner + .run(context -> assertThat(context).doesNotHaveBean(VectorStoreQueryResponseObservationHandler.class)); } @Test @@ -57,9 +55,7 @@ void queryResponseHandlerEnabled() { this.contextRunner .withBean(OtelTracer.class, OpenTelemetry.noop().getTracer("test"), new OtelCurrentTraceContext(), null) .withPropertyValues("spring.ai.vectorstore.observations.include-query-response=true") - .run(context -> { - assertThat(context).hasSingleBean(VectorStoreQueryResponseObservationHandler.class); - }); + .run(context -> assertThat(context).hasSingleBean(VectorStoreQueryResponseObservationHandler.class)); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/AwsOpenSearchVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/AwsOpenSearchVectorStoreAutoConfigurationIT.java index 7b23132d7b6..7dc1265db21 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/AwsOpenSearchVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/AwsOpenSearchVectorStoreAutoConfigurationIT.java @@ -74,12 +74,12 @@ class AwsOpenSearchVectorStoreAutoConfigurationIT { OpenSearchVectorStoreProperties.CONFIG_PREFIX + ".indexName=" + DOCUMENT_INDEX, OpenSearchVectorStoreProperties.CONFIG_PREFIX + ".mappingJson=" + """ { - "properties":{ - "embedding":{ - "type":"knn_vector", - "dimension":384 - } - } + "properties":{ + "embedding":{ + "type":"knn_vector", + "dimension":384 + } + } } """); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfigurationIT.java index 5445f6de7d1..d5b4876076d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/opensearch/OpenSearchVectorStoreAutoConfigurationIT.java @@ -69,12 +69,12 @@ class OpenSearchVectorStoreAutoConfigurationIT { OpenSearchVectorStoreProperties.CONFIG_PREFIX + ".indexName=" + DOCUMENT_INDEX, OpenSearchVectorStoreProperties.CONFIG_PREFIX + ".mappingJson=" + """ { - "properties":{ - "embedding":{ - "type":"knn_vector", - "dimension":384 - } - } + "properties":{ + "embedding":{ + "type":"knn_vector", + "dimension":384 + } + } } """); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfigurationIT.java index 25a969fceda..e6f12372ba7 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfigurationIT.java @@ -147,9 +147,7 @@ public void customSchemaNames(String schemaTableName) { this.contextRunner .withPropertyValues("spring.ai.vectorstore.pgvector.schema-name=" + schemaName, "spring.ai.vectorstore.pgvector.table-name=" + tableName) - .run(context -> { - assertThat(isFullyQualifiedTableExists(context, schemaName, tableName)).isTrue(); - }); + .run(context -> assertThat(isFullyQualifiedTableExists(context, schemaName, tableName)).isTrue()); } @ParameterizedTest(name = "{0} : {displayName} ") @@ -162,9 +160,7 @@ public void disableSchemaInitialization(String schemaTableName) { .withPropertyValues("spring.ai.vectorstore.pgvector.schema-name=" + schemaName, "spring.ai.vectorstore.pgvector.table-name=" + tableName, "spring.ai.vectorstore.pgvector.initialize-schema=false") - .run(context -> { - assertThat(isFullyQualifiedTableExists(context, schemaName, tableName)).isFalse(); - }); + .run(context -> assertThat(isFullyQualifiedTableExists(context, schemaName, tableName)).isFalse()); } @Configuration(proxyBeanMethods = false) diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreAutoConfigurationIT.java index e5cc9f97d80..95ea2339e6b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/pinecone/PineconeVectorStoreAutoConfigurationIT.java @@ -44,7 +44,6 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.Matchers.hasSize; -import static org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry; /** * @author Christian Tzolov @@ -96,12 +95,11 @@ public void addAndSearchTest() { vectorStore.add(this.documents); - assertObservationRegistry(observationRegistry, VectorStoreProvider.PINECONE, - VectorStoreObservationContext.Operation.ADD); + org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry( + observationRegistry, VectorStoreProvider.PINECONE, VectorStoreObservationContext.Operation.ADD); - Awaitility.await().until(() -> { - return vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); - }, hasSize(1)); + Awaitility.await() + .until(() -> vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)), hasSize(1)); observationRegistry.clear(); List results = vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); @@ -114,20 +112,19 @@ public void addAndSearchTest() { assertThat(resultDoc.getMetadata()).hasSize(2); assertThat(resultDoc.getMetadata()).containsKeys("spring", "customDistanceField"); - assertObservationRegistry(observationRegistry, VectorStoreProvider.PINECONE, - VectorStoreObservationContext.Operation.QUERY); + org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry( + observationRegistry, VectorStoreProvider.PINECONE, VectorStoreObservationContext.Operation.QUERY); observationRegistry.clear(); // Remove all documents from the store vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); - assertObservationRegistry(observationRegistry, VectorStoreProvider.PINECONE, - VectorStoreObservationContext.Operation.DELETE); + org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.assertObservationRegistry( + observationRegistry, VectorStoreProvider.PINECONE, VectorStoreObservationContext.Operation.DELETE); observationRegistry.clear(); - Awaitility.await().until(() -> { - return vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)); - }, hasSize(0)); + Awaitility.await() + .until(() -> vectorStore.similaritySearch(SearchRequest.query("Spring").withTopK(1)), hasSize(0)); }); } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfigurationIT.java index a9d9716a515..d2888269cab 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfigurationIT.java @@ -59,9 +59,11 @@ void generateStreaming() { this.contextRunner.run(context -> { VertexAiGeminiChatModel chatModel = context.getBean(VertexAiGeminiChatModel.class); Flux responseFlux = chatModel.stream(new Prompt(new UserMessage("Hello"))); - String response = responseFlux.collectList().block().stream().map(chatResponse -> { - return chatResponse.getResults().get(0).getOutput().getContent(); - }).collect(Collectors.joining()); + String response = responseFlux.collectList() + .block() + .stream() + .map(chatResponse -> chatResponse.getResults().get(0).getOutput().getContent()) + .collect(Collectors.joining()); assertThat(response).isNotEmpty(); logger.info("Response: " + response); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionBeanIT.java index 4d17b12cf32..b65315a749a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionBeanIT.java @@ -71,21 +71,21 @@ void functionCallTest() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), VertexAiGeminiChatOptions.builder().withFunction("weatherFunction").build())); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); response = chatModel.call(new Prompt(List.of(userMessage), VertexAiGeminiChatOptions.builder().withFunction("weatherFunction3").build())); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); response = chatModel .call(new Prompt(List.of(userMessage), VertexAiGeminiChatOptions.builder().build())); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).doesNotContain("30", "10", "15"); @@ -111,14 +111,14 @@ void functionCallWithPortableFunctionCallingOptions() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), PortableFunctionCallingOptions.builder().withFunction("weatherFunction").build())); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); response = chatModel.call(new Prompt(List.of(userMessage), VertexAiGeminiChatOptions.builder().withFunction("weatherFunction3").build())); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionWrapperIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionWrapperIT.java index 34688fcef77..7b1f2287cd3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionWrapperIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionWrapperIT.java @@ -68,7 +68,7 @@ void functionCallTest() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), VertexAiGeminiChatOptions.builder().withFunction("WeatherInfo").build())); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); }); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithPromptFunctionIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithPromptFunctionIT.java index 2cde310ab6a..3f848374656 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithPromptFunctionIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithPromptFunctionIT.java @@ -77,7 +77,7 @@ void functionCallTest() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -85,7 +85,7 @@ void functionCallTest() { response = chatModel .call(new Prompt(List.of(userMessage), VertexAiGeminiChatOptions.builder().build())); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).doesNotContain("30", "10", "15"); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/MockWeatherService.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/MockWeatherService.java index aa78f759467..4518a5194f5 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/MockWeatherService.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/MockWeatherService.java @@ -68,7 +68,7 @@ public enum Unit { */ public final String unitName; - private Unit(String text) { + Unit(String text) { this.unitName = text; } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiAutoConfigurationTests.java index 049fd61c71b..1017332b342 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiAutoConfigurationTests.java @@ -30,12 +30,12 @@ public class WatsonxAiAutoConfigurationTests { public void propertiesTest() { new ApplicationContextRunner().withPropertyValues( // @formatter:off - "spring.ai.watsonx.ai.base-url=TEST_BASE_URL", - "spring.ai.watsonx.ai.stream-endpoint=ml/v1/text/generation_stream?version=2023-05-29", - "spring.ai.watsonx.ai.text-endpoint=ml/v1/text/generation?version=2023-05-29", - "spring.ai.watsonx.ai.embedding-endpoint=ml/v1/text/embeddings?version=2023-05-29", - "spring.ai.watsonx.ai.projectId=1", - "spring.ai.watsonx.ai.IAMToken=123456") + "spring.ai.watsonx.ai.base-url=TEST_BASE_URL", + "spring.ai.watsonx.ai.stream-endpoint=ml/v1/text/generation_stream?version=2023-05-29", + "spring.ai.watsonx.ai.text-endpoint=ml/v1/text/generation?version=2023-05-29", + "spring.ai.watsonx.ai.embedding-endpoint=ml/v1/text/embeddings?version=2023-05-29", + "spring.ai.watsonx.ai.projectId=1", + "spring.ai.watsonx.ai.IAMToken=123456") // @formatter:on .withConfiguration( AutoConfigurations.of(RestClientAutoConfiguration.class, WatsonxAiAutoConfiguration.class)) diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiAutoConfigurationIT.java index f15f82f82e4..992af88753b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/ZhiPuAiAutoConfigurationIT.java @@ -69,9 +69,11 @@ void generateStreaming() { this.contextRunner.run(context -> { ZhiPuAiChatModel chatModel = context.getBean(ZhiPuAiChatModel.class); Flux responseFlux = chatModel.stream(new Prompt(new UserMessage("Hello"))); - String response = responseFlux.collectList().block().stream().map(chatResponse -> { - return chatResponse.getResults().get(0).getOutput().getContent(); - }).collect(Collectors.joining()); + String response = responseFlux.collectList() + .block() + .stream() + .map(chatResponse -> chatResponse.getResults().get(0).getOutput().getContent()) + .collect(Collectors.joining()); assertThat(response).isNotEmpty(); logger.info("Response: " + response); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackInPromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackInPromptIT.java index ca91b63c3d0..e10d5310322 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackInPromptIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackInPromptIT.java @@ -67,13 +67,13 @@ void functionCallTest() { .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("CurrentWeatherService") .withDescription("Get the weather in location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build())) .build(); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); }); @@ -93,7 +93,7 @@ void streamingFunctionCallTest() { .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("CurrentWeatherService") .withDescription("Get the weather in location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build())) .build(); @@ -107,7 +107,7 @@ void streamingFunctionCallTest() { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - this.logger.info("Response: {}", content); + logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWithPlainFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWithPlainFunctionBeanIT.java index 5b2657c6918..8c5fcb525c8 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWithPlainFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWithPlainFunctionBeanIT.java @@ -73,7 +73,7 @@ void functionCallTest() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), ZhiPuAiChatOptions.builder().withFunction("weatherFunction").build())); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -81,7 +81,7 @@ void functionCallTest() { response = chatModel.call(new Prompt(List.of(userMessage), ZhiPuAiChatOptions.builder().withFunction("weatherFunctionTwo").build())); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -104,7 +104,7 @@ void functionCallWithPortableFunctionCallingOptions() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); }); } @@ -129,7 +129,7 @@ void streamFunctionCallTest() { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - this.logger.info("Response: {}", content); + logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -147,7 +147,7 @@ void streamFunctionCallTest() { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - this.logger.info("Response: {}", content); + logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWrapperIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWrapperIT.java index 9016104f214..15118dbabfc 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWrapperIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWrapperIT.java @@ -70,7 +70,7 @@ void functionCallTest() { ChatResponse response = chatModel.call( new Prompt(List.of(userMessage), ZhiPuAiChatOptions.builder().withFunction("WeatherInfo").build())); - this.logger.info("Response: {}", response); + logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); @@ -97,7 +97,7 @@ void streamFunctionCallTest() { .map(Generation::getOutput) .map(AssistantMessage::getContent) .collect(Collectors.joining()); - this.logger.info("Response: {}", content); + logger.info("Response: {}", content); assertThat(content).containsAnyOf("30.0", "30"); assertThat(content).containsAnyOf("10.0", "10"); @@ -115,7 +115,7 @@ public FunctionCallback weatherFunctionInfo() { return FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("WeatherInfo") .withDescription("Get the weather in location") - .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .withResponseConverter(response -> "" + response.temp() + response.unit()) .build(); } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/MockWeatherService.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/MockWeatherService.java index 75d562648f6..43b4af8a57e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/MockWeatherService.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/MockWeatherService.java @@ -67,7 +67,7 @@ public enum Unit { */ public final String unitName; - private Unit(String text) { + Unit(String text) { this.unitName = text; } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/resources/oracle/initialize.sql b/spring-ai-spring-boot-autoconfigure/src/test/resources/oracle/initialize.sql index 0b42b6ff7ea..6cd9780f80e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/resources/oracle/initialize.sql +++ b/spring-ai-spring-boot-autoconfigure/src/test/resources/oracle/initialize.sql @@ -18,9 +18,11 @@ WHENEVER SQLERROR EXIT SQL.SQLCODE -- Configure the size of the Vector Pool to 1 GiB. -ALTER SYSTEM SET vector_memory_size=1G SCOPE=SPFILE; +ALTER +SYSTEM SET vector_memory_size=1G SCOPE=SPFILE; -SHUTDOWN ABORT; +SHUTDOWN +ABORT; STARTUP; exit; diff --git a/spring-ai-spring-boot-docker-compose/pom.xml b/spring-ai-spring-boot-docker-compose/pom.xml index 0e714abf082..70115f2a57b 100644 --- a/spring-ai-spring-boot-docker-compose/pom.xml +++ b/spring-ai-spring-boot-docker-compose/pom.xml @@ -35,6 +35,10 @@ git@github.com:spring-projects/spring-ai.git + + false + + diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaDockerComposeConnectionDetailsFactoryTests.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaDockerComposeConnectionDetailsFactoryTests.java index 66e670c467e..190ffa9ad95 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaDockerComposeConnectionDetailsFactoryTests.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaDockerComposeConnectionDetailsFactoryTests.java @@ -37,4 +37,4 @@ void runCreatesConnectionDetails() { assertThat(connectionDetails.getPort()).isGreaterThan(0); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaWithTokenDockerComposeConnectionDetailsFactoryTests.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaWithTokenDockerComposeConnectionDetailsFactoryTests.java index 8795dc4f69c..87e11ed4ce9 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaWithTokenDockerComposeConnectionDetailsFactoryTests.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaWithTokenDockerComposeConnectionDetailsFactoryTests.java @@ -38,4 +38,4 @@ void runCreatesConnectionDetails() { assertThat(connectionDetails.getKeyToken()).isEqualTo("secret"); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/mongo/MongoDbAtlasLocalDockerComposeConnectionDetailsFactoryTests.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/mongo/MongoDbAtlasLocalDockerComposeConnectionDetailsFactoryTests.java index c88bf10cd0f..61cbe8a95e1 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/mongo/MongoDbAtlasLocalDockerComposeConnectionDetailsFactoryTests.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/mongo/MongoDbAtlasLocalDockerComposeConnectionDetailsFactoryTests.java @@ -36,4 +36,4 @@ void runCreatesConnectionDetails() { assertThat(connectionDetails.getConnectionString()).isNotNull(); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchDockerComposeConnectionDetailsFactoryTests.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchDockerComposeConnectionDetailsFactoryTests.java index a162d05d23b..097d4d14364 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchDockerComposeConnectionDetailsFactoryTests.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchDockerComposeConnectionDetailsFactoryTests.java @@ -38,4 +38,4 @@ void runCreatesConnectionDetails() { assertThat(connectionDetails.getPassword()).isEqualTo("D3v3l0p-ment"); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchEnvironmentTests.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchEnvironmentTests.java index c1457232776..96051bfd27f 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchEnvironmentTests.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/opensearch/OpenSearchEnvironmentTests.java @@ -38,4 +38,4 @@ void getPasswordWhenHasPassword() { assertThat(environment.getPassword()).isEqualTo("secret"); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseDockerComposeConnectionDetailsFactoryTests.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseDockerComposeConnectionDetailsFactoryTests.java index b766c31cbec..baa0d731d6f 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseDockerComposeConnectionDetailsFactoryTests.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/ai/docker/compose/service/connection/typesense/TypesenseDockerComposeConnectionDetailsFactoryTests.java @@ -39,4 +39,4 @@ void runCreatesConnectionDetails() { assertThat(connectionDetails.getApiKey()).isEqualTo("secret"); } -} \ No newline at end of file +} diff --git a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/testsupport/DisabledIfProcessUnavailableCondition.java b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/testsupport/DisabledIfProcessUnavailableCondition.java index 62dd2f20ee6..fb2e74ae484 100644 --- a/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/testsupport/DisabledIfProcessUnavailableCondition.java +++ b/spring-ai-spring-boot-docker-compose/src/test/java/org/springframework/boot/testsupport/DisabledIfProcessUnavailableCondition.java @@ -61,7 +61,7 @@ public ConditionEvaluationResult evaluateExecutionCondition(ExtensionContext con private Stream getAnnotationValue(AnnotatedElement testElement) { return MergedAnnotations.from(testElement, SearchStrategy.TYPE_HIERARCHY) .stream(DisabledIfProcessUnavailable.class) - .map((annotation) -> annotation.getStringArray(MergedAnnotation.VALUE)); + .map(annotation -> annotation.getStringArray(MergedAnnotation.VALUE)); } private void check(String[] command) { diff --git a/spring-ai-spring-boot-testcontainers/pom.xml b/spring-ai-spring-boot-testcontainers/pom.xml index d1d5d14853d..13acbba6858 100644 --- a/spring-ai-spring-boot-testcontainers/pom.xml +++ b/spring-ai-spring-boot-testcontainers/pom.xml @@ -15,121 +15,126 @@ ~ limitations under the License. --> - - 4.0.0 - - org.springframework.ai - spring-ai - 1.0.0-SNAPSHOT - - spring-ai-spring-boot-testcontainers - jar - Spring AI Testcontainers - Spring AI Testcontainers - https://github.com/spring-projects/spring-ai - - - https://github.com/spring-projects/spring-ai - git://github.com/spring-projects/spring-ai.git - git@github.com:spring-projects/spring-ai.git - - - - - - org.springframework.ai - spring-ai-spring-boot-autoconfigure - ${project.parent.version} - - - - com.google.protobuf - protobuf-java - ${protobuf-java.version} - - - - - - org.springframework.boot - spring-boot-starter - - - - org.springframework.boot - spring-boot-testcontainers - - - - org.springframework.ai - spring-ai-openai - ${project.parent.version} - true - - - - org.springframework.ai - spring-ai-ollama - ${project.parent.version} - true - - - - - org.springframework.ai - spring-ai-transformers - ${project.parent.version} - true - - - - - org.springframework.ai - spring-ai-milvus-store - ${project.parent.version} - true - - - - - org.springframework.ai - spring-ai-chroma-store - ${project.parent.version} - true - - - - - org.springframework.ai - spring-ai-weaviate-store - ${project.parent.version} - true - - - - - org.springframework.ai - spring-ai-redis-store - ${project.parent.version} - true - - - - - redis.clients - jedis - 5.1.0 - test - - - - - org.springframework.ai - spring-ai-qdrant-store - ${project.parent.version} - true - + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-SNAPSHOT + + spring-ai-spring-boot-testcontainers + jar + Spring AI Testcontainers + Spring AI Testcontainers + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + false + + + + + + org.springframework.ai + spring-ai-spring-boot-autoconfigure + ${project.parent.version} + + + + com.google.protobuf + protobuf-java + ${protobuf-java.version} + + + + + + org.springframework.boot + spring-boot-starter + + + + org.springframework.boot + spring-boot-testcontainers + + + + org.springframework.ai + spring-ai-openai + ${project.parent.version} + true + + + + org.springframework.ai + spring-ai-ollama + ${project.parent.version} + true + + + + + org.springframework.ai + spring-ai-transformers + ${project.parent.version} + true + + + + + org.springframework.ai + spring-ai-milvus-store + ${project.parent.version} + true + + + + + org.springframework.ai + spring-ai-chroma-store + ${project.parent.version} + true + + + + + org.springframework.ai + spring-ai-weaviate-store + ${project.parent.version} + true + + + + + org.springframework.ai + spring-ai-redis-store + ${project.parent.version} + true + + + + + redis.clients + jedis + 5.1.0 + test + + + + + org.springframework.ai + spring-ai-qdrant-store + ${project.parent.version} + true + @@ -154,92 +159,92 @@ true - - - - org.springframework.ai - spring-ai-test - ${project.parent.version} - test - - - - org.springframework.boot - spring-boot-starter-test - - - org.skyscreamer - jsonassert - - - test - - - org.springframework.boot - spring-boot-starter-jdbc - test - - - org.postgresql - postgresql - ${postgresql.version} - test - - - - org.testcontainers - testcontainers - true - - - - org.testcontainers - junit-jupiter - test - - - com.vaadin.external.google - android-json - - - - - - com.redis - testcontainers-redis - 2.2.0 - true - - - - org.awaitility - awaitility - test - - - - org.testcontainers - qdrant - true - - - - org.testcontainers - weaviate - true - - - - org.testcontainers - chromadb - true - - - - org.testcontainers - milvus - true - + + + + org.springframework.ai + spring-ai-test + ${project.parent.version} + test + + + + org.springframework.boot + spring-boot-starter-test + + + org.skyscreamer + jsonassert + + + test + + + org.springframework.boot + spring-boot-starter-jdbc + test + + + org.postgresql + postgresql + ${postgresql.version} + test + + + + org.testcontainers + testcontainers + true + + + + org.testcontainers + junit-jupiter + test + + + com.vaadin.external.google + android-json + + + + + + com.redis + testcontainers-redis + 2.2.0 + true + + + + org.awaitility + awaitility + test + + + + org.testcontainers + qdrant + true + + + + org.testcontainers + weaviate + true + + + + org.testcontainers + chromadb + true + + + + org.testcontainers + milvus + true + org.testcontainers @@ -248,11 +253,11 @@ true - - org.testcontainers - ollama - true - + + org.testcontainers + ollama + true + org.opensearch @@ -261,6 +266,6 @@ true - + diff --git a/spring-ai-spring-boot-testcontainers/src/main/resources/META-INF/spring.factories b/spring-ai-spring-boot-testcontainers/src/main/resources/META-INF/spring.factories index a4d88b8ef6b..9f5994436f3 100644 --- a/spring-ai-spring-boot-testcontainers/src/main/resources/META-INF/spring.factories +++ b/spring-ai-spring-boot-testcontainers/src/main/resources/META-INF/spring.factories @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # - org.springframework.boot.autoconfigure.service.connection.ConnectionDetailsFactory=\ org.springframework.ai.testcontainers.service.connection.chroma.ChromaContainerConnectionDetailsFactory,\ org.springframework.ai.testcontainers.service.connection.milvus.MilvusContainerConnectionDetailsFactory,\ diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaContainerConnectionDetailsFactoryTest.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaContainerConnectionDetailsFactoryTest.java index 2c387f2c1cb..f8509a97aff 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaContainerConnectionDetailsFactoryTest.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaContainerConnectionDetailsFactoryTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.chroma; import java.util.List; @@ -60,25 +61,25 @@ public void addAndSearchWithFilters() { var nlDocument = new Document("The World is Big and Salvation Lurks Around the Corner", Map.of("country", "Netherlands")); - vectorStore.add(List.of(bgDocument, nlDocument)); + this.vectorStore.add(List.of(bgDocument, nlDocument)); var request = SearchRequest.query("The World").withTopK(5); - List results = vectorStore.similaritySearch(request); + List results = this.vectorStore.similaritySearch(request); assertThat(results).hasSize(2); - results = vectorStore + results = this.vectorStore .similaritySearch(request.withSimilarityThresholdAll().withFilterExpression("country == 'Bulgaria'")); assertThat(results).hasSize(1); assertThat(results.get(0).getId()).isEqualTo(bgDocument.getId()); - results = vectorStore + results = this.vectorStore .similaritySearch(request.withSimilarityThresholdAll().withFilterExpression("country == 'Netherlands'")); assertThat(results).hasSize(1); assertThat(results.get(0).getId()).isEqualTo(nlDocument.getId()); // Remove all documents from the store - vectorStore.delete(List.of(bgDocument, nlDocument).stream().map(doc -> doc.getId()).toList()); + this.vectorStore.delete(List.of(bgDocument, nlDocument).stream().map(doc -> doc.getId()).toList()); } @Configuration(proxyBeanMethods = false) diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java index efd038e5d38..47491a9f80c 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java @@ -21,8 +21,12 @@ /** * @author Thomas Vitale */ -public class ChromaImage { +public final class ChromaImage { public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.11"); + private ChromaImage() { + + } + } diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaWithToken2ContainerConnectionDetailsFactoryTest.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaWithToken2ContainerConnectionDetailsFactoryTest.java index 86c9e2718a5..a9c64aa4cfc 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaWithToken2ContainerConnectionDetailsFactoryTest.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaWithToken2ContainerConnectionDetailsFactoryTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.chroma; import java.util.List; @@ -64,25 +65,25 @@ public void addAndSearchWithFilters() { var nlDocument = new Document("The World is Big and Salvation Lurks Around the Corner", Map.of("country", "Netherlands")); - vectorStore.add(List.of(bgDocument, nlDocument)); + this.vectorStore.add(List.of(bgDocument, nlDocument)); var request = SearchRequest.query("The World").withTopK(5); - List results = vectorStore.similaritySearch(request); + List results = this.vectorStore.similaritySearch(request); assertThat(results).hasSize(2); - results = vectorStore + results = this.vectorStore .similaritySearch(request.withSimilarityThresholdAll().withFilterExpression("country == 'Bulgaria'")); assertThat(results).hasSize(1); assertThat(results.get(0).getId()).isEqualTo(bgDocument.getId()); - results = vectorStore + results = this.vectorStore .similaritySearch(request.withSimilarityThresholdAll().withFilterExpression("country == 'Netherlands'")); assertThat(results).hasSize(1); assertThat(results.get(0).getId()).isEqualTo(nlDocument.getId()); // Remove all documents from the store - vectorStore.delete(List.of(bgDocument, nlDocument).stream().map(doc -> doc.getId()).toList()); + this.vectorStore.delete(List.of(bgDocument, nlDocument).stream().map(doc -> doc.getId()).toList()); } @Configuration(proxyBeanMethods = false) diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaWithTokenContainerConnectionDetailsFactoryTest.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaWithTokenContainerConnectionDetailsFactoryTest.java index fbe935032c3..614024b7714 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaWithTokenContainerConnectionDetailsFactoryTest.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaWithTokenContainerConnectionDetailsFactoryTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.testcontainers.service.connection.chroma; import java.util.List; @@ -62,25 +63,25 @@ public void addAndSearchWithFilters() { var nlDocument = new Document("The World is Big and Salvation Lurks Around the Corner", Map.of("country", "Netherlands")); - vectorStore.add(List.of(bgDocument, nlDocument)); + this.vectorStore.add(List.of(bgDocument, nlDocument)); var request = SearchRequest.query("The World").withTopK(5); - List results = vectorStore.similaritySearch(request); + List results = this.vectorStore.similaritySearch(request); assertThat(results).hasSize(2); - results = vectorStore + results = this.vectorStore .similaritySearch(request.withSimilarityThresholdAll().withFilterExpression("country == 'Bulgaria'")); assertThat(results).hasSize(1); assertThat(results.get(0).getId()).isEqualTo(bgDocument.getId()); - results = vectorStore + results = this.vectorStore .similaritySearch(request.withSimilarityThresholdAll().withFilterExpression("country == 'Netherlands'")); assertThat(results).hasSize(1); assertThat(results.get(0).getId()).isEqualTo(nlDocument.getId()); // Remove all documents from the store - vectorStore.delete(List.of(bgDocument, nlDocument).stream().map(doc -> doc.getId()).toList()); + this.vectorStore.delete(List.of(bgDocument, nlDocument).stream().map(doc -> doc.getId()).toList()); } @Configuration(proxyBeanMethods = false) diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/milvus/MilvusImage.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/milvus/MilvusImage.java index 168e854e799..83848c93ea8 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/milvus/MilvusImage.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/milvus/MilvusImage.java @@ -21,8 +21,12 @@ /** * @author Thomas Vitale */ -public class MilvusImage { +public final class MilvusImage { public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("milvusdb/milvus:v2.4.9"); + private MilvusImage() { + + } + } diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbAtlasLocalContainerConnectionDetailsFactoryIt.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbAtlasLocalContainerConnectionDetailsFactoryIT.java similarity index 100% rename from spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbAtlasLocalContainerConnectionDetailsFactoryIt.java rename to spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbAtlasLocalContainerConnectionDetailsFactoryIT.java diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbImage.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbImage.java index af0cb68f5ad..4e4d26fe482 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbImage.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbImage.java @@ -21,8 +21,12 @@ /** * @author Thomas Vitale */ -public class MongoDbImage { +public final class MongoDbImage { public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("mongodb/mongodb-atlas-local:8.0.0"); + private MongoDbImage() { + + } + } diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaImage.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaImage.java index c1bce0c70f0..4734e71de85 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaImage.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaImage.java @@ -21,8 +21,12 @@ /** * @author Thomas Vitale */ -public class OllamaImage { +public final class OllamaImage { public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ollama/ollama:0.3.14"); + private OllamaImage() { + + } + } diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/opensearch/OpenSearchImage.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/opensearch/OpenSearchImage.java index 26a615e6a32..be05889cb54 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/opensearch/OpenSearchImage.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/opensearch/OpenSearchImage.java @@ -21,8 +21,12 @@ /** * @author Thomas Vitale */ -public class OpenSearchImage { +public final class OpenSearchImage { public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("opensearchproject/opensearch:2.17.1"); + private OpenSearchImage() { + + } + } diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantImage.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantImage.java index 618e61cfccd..2c36d1e6c0c 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantImage.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/qdrant/QdrantImage.java @@ -21,8 +21,12 @@ /** * @author Thomas Vitale */ -public class QdrantImage { +public final class QdrantImage { public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("qdrant/qdrant:v1.9.7"); + private QdrantImage() { + + } + } diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/typesense/TypesenseImage.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/typesense/TypesenseImage.java index bf9982363b0..17f8971cf29 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/typesense/TypesenseImage.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/typesense/TypesenseImage.java @@ -21,8 +21,12 @@ /** * @author Thomas Vitale */ -public class TypesenseImage { +public final class TypesenseImage { public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("typesense/typesense:27.1"); + private TypesenseImage() { + + } + } diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/weaviate/WeaviateImage.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/weaviate/WeaviateImage.java index 8157d3c2e8d..945f70a9af0 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/weaviate/WeaviateImage.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/weaviate/WeaviateImage.java @@ -21,8 +21,12 @@ /** * @author Thomas Vitale */ -public class WeaviateImage { +public final class WeaviateImage { public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("semitechnologies/weaviate:1.25.9"); + private WeaviateImage() { + + } + } diff --git a/spring-ai-spring-cloud-bindings/pom.xml b/spring-ai-spring-cloud-bindings/pom.xml index da3ffe77fb2..60de4a9514e 100644 --- a/spring-ai-spring-cloud-bindings/pom.xml +++ b/spring-ai-spring-cloud-bindings/pom.xml @@ -35,6 +35,10 @@ git@github.com:spring-projects/spring-ai.git + + false + + diff --git a/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/BindingsValidator.java b/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/BindingsValidator.java index bc4dc5087cd..e1d34848c13 100644 --- a/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/BindingsValidator.java +++ b/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/BindingsValidator.java @@ -26,6 +26,10 @@ final class BindingsValidator { static final String CONFIG_PATH = "spring.ai.cloud.bindings"; + private BindingsValidator() { + + } + /** * Whether the given binding type should be used to contribute properties. */ diff --git a/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/OllamaBindingsPropertiesProcessor.java b/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/OllamaBindingsPropertiesProcessor.java index 8afe9e393e2..4114860d7dc 100644 --- a/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/OllamaBindingsPropertiesProcessor.java +++ b/spring-ai-spring-cloud-bindings/src/main/java/org/springframework/ai/bindings/OllamaBindingsPropertiesProcessor.java @@ -42,9 +42,8 @@ public void process(Environment environment, Bindings bindings, Map { - properties.put("spring.ai.ollama.base-url", binding.getSecret().get("uri")); - }); + bindings.filterBindings(TYPE) + .forEach(binding -> properties.put("spring.ai.ollama.base-url", binding.getSecret().get("uri"))); } } diff --git a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/ChromaBindingsPropertiesProcessorTests.java b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/ChromaBindingsPropertiesProcessorTests.java index 885d8f25984..9efb9d8897e 100644 --- a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/ChromaBindingsPropertiesProcessorTests.java +++ b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/ChromaBindingsPropertiesProcessorTests.java @@ -27,7 +27,6 @@ import org.springframework.mock.env.MockEnvironment; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.bindings.BindingsValidator.CONFIG_PATH; /** * Unit tests for {@link ChromaBindingsPropertiesProcessor}. @@ -38,12 +37,12 @@ class ChromaBindingsPropertiesProcessorTests { private final Bindings bindings = new Bindings(new Binding("test-name", Paths.get("test-path"), // @formatter:off - Map.of( + Map.of( Binding.TYPE, ChromaBindingsPropertiesProcessor.TYPE, "uri", "https://example.net:8000", "username", "itsme", "password", "youknowit" - ))); + ))); // @formatter:on private final MockEnvironment environment = new MockEnvironment(); @@ -61,7 +60,8 @@ void propertiesAreContributed() { @Test void whenDisabledThenPropertiesAreNotContributed() { - this.environment.setProperty("%s.chroma.enabled".formatted(CONFIG_PATH), "false"); + this.environment.setProperty( + "%s.chroma.enabled".formatted(org.springframework.ai.bindings.BindingsValidator.CONFIG_PATH), "false"); new ChromaBindingsPropertiesProcessor().process(this.environment, this.bindings, this.properties); assertThat(this.properties).isEmpty(); diff --git a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/MistralAiBindingsPropertiesProcessorTests.java b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/MistralAiBindingsPropertiesProcessorTests.java index 0c2d1db356b..d99e991bf3e 100644 --- a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/MistralAiBindingsPropertiesProcessorTests.java +++ b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/MistralAiBindingsPropertiesProcessorTests.java @@ -27,7 +27,6 @@ import org.springframework.mock.env.MockEnvironment; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.bindings.BindingsValidator.CONFIG_PATH; /** * Unit tests for {@link MistralAiBindingsPropertiesProcessor}. @@ -38,11 +37,11 @@ class MistralAiBindingsPropertiesProcessorTests { private final Bindings bindings = new Bindings(new Binding("test-name", Paths.get("test-path"), // @formatter:off - Map.of( - Binding.TYPE, MistralAiBindingsPropertiesProcessor.TYPE, - "api-key", "demo", - "uri", "https://my.mistralai.example.net" - ))); + Map.of( + Binding.TYPE, MistralAiBindingsPropertiesProcessor.TYPE, + "api-key", "demo", + "uri", "https://my.mistralai.example.net" + ))); // @formatter:on private final MockEnvironment environment = new MockEnvironment(); @@ -58,7 +57,9 @@ void propertiesAreContributed() { @Test void whenDisabledThenPropertiesAreNotContributed() { - this.environment.setProperty("%s.mistralai.enabled".formatted(CONFIG_PATH), "false"); + this.environment.setProperty( + "%s.mistralai.enabled".formatted(org.springframework.ai.bindings.BindingsValidator.CONFIG_PATH), + "false"); new MistralAiBindingsPropertiesProcessor().process(this.environment, this.bindings, this.properties); assertThat(this.properties).isEmpty(); diff --git a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/OllamaBindingsPropertiesProcessorTests.java b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/OllamaBindingsPropertiesProcessorTests.java index b308fae9ac0..aa2d55a3445 100644 --- a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/OllamaBindingsPropertiesProcessorTests.java +++ b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/OllamaBindingsPropertiesProcessorTests.java @@ -27,7 +27,6 @@ import org.springframework.mock.env.MockEnvironment; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.bindings.BindingsValidator.CONFIG_PATH; /** * Unit tests for {@link OllamaBindingsPropertiesProcessor}. @@ -38,10 +37,10 @@ class OllamaBindingsPropertiesProcessorTests { private final Bindings bindings = new Bindings(new Binding("test-name", Paths.get("test-path"), // @formatter:off - Map.of( - Binding.TYPE, OllamaBindingsPropertiesProcessor.TYPE, - "uri", "https://example.net/ollama:11434" - ))); + Map.of( + Binding.TYPE, OllamaBindingsPropertiesProcessor.TYPE, + "uri", "https://example.net/ollama:11434" + ))); // @formatter:on private final MockEnvironment environment = new MockEnvironment(); @@ -56,7 +55,8 @@ void propertiesAreContributed() { @Test void whenDisabledThenPropertiesAreNotContributed() { - this.environment.setProperty("%s.ollama.enabled".formatted(CONFIG_PATH), "false"); + this.environment.setProperty( + "%s.ollama.enabled".formatted(org.springframework.ai.bindings.BindingsValidator.CONFIG_PATH), "false"); new OllamaBindingsPropertiesProcessor().process(this.environment, this.bindings, this.properties); assertThat(this.properties).isEmpty(); diff --git a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/OpenAiBindingsPropertiesProcessorTests.java b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/OpenAiBindingsPropertiesProcessorTests.java index fdebd11a2ef..75bcf144662 100644 --- a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/OpenAiBindingsPropertiesProcessorTests.java +++ b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/OpenAiBindingsPropertiesProcessorTests.java @@ -27,7 +27,6 @@ import org.springframework.mock.env.MockEnvironment; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.bindings.BindingsValidator.CONFIG_PATH; /** * Unit tests for {@link OpenAiBindingsPropertiesProcessor}. @@ -38,11 +37,11 @@ class OpenAiBindingsPropertiesProcessorTests { private final Bindings bindings = new Bindings(new Binding("test-name", Paths.get("test-path"), // @formatter:off - Map.of( - Binding.TYPE, OpenAiBindingsPropertiesProcessor.TYPE, - "api-key", "demo", - "uri", "https://my.openai.example.net" - ))); + Map.of( + Binding.TYPE, OpenAiBindingsPropertiesProcessor.TYPE, + "api-key", "demo", + "uri", "https://my.openai.example.net" + ))); // @formatter:on private final MockEnvironment environment = new MockEnvironment(); @@ -58,7 +57,8 @@ void propertiesAreContributed() { @Test void whenDisabledThenPropertiesAreNotContributed() { - this.environment.setProperty("%s.openai.enabled".formatted(CONFIG_PATH), "false"); + this.environment.setProperty( + "%s.openai.enabled".formatted(org.springframework.ai.bindings.BindingsValidator.CONFIG_PATH), "false"); new OpenAiBindingsPropertiesProcessor().process(this.environment, this.bindings, this.properties); assertThat(this.properties).isEmpty(); diff --git a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/TanzuBindingsPropertiesProcessorTests.java b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/TanzuBindingsPropertiesProcessorTests.java index 40492754fca..9a5ca99bdde 100644 --- a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/TanzuBindingsPropertiesProcessorTests.java +++ b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/TanzuBindingsPropertiesProcessorTests.java @@ -27,7 +27,6 @@ import org.springframework.mock.env.MockEnvironment; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.bindings.BindingsValidator.CONFIG_PATH; /** * Unit tests for {@link TanzuBindingsPropertiesProcessor}. @@ -38,13 +37,13 @@ class TanzuBindingsPropertiesProcessorTests { private final Bindings bindings = new Bindings(new Binding("test-name", Paths.get("test-path"), // @formatter:off - Map.of( - Binding.TYPE, TanzuBindingsPropertiesProcessor.TYPE, - "api-key", "demo", - "uri", "https://my.openai.example.net", - "model-name", "llava1.6", - "model-capabilities", " chat , vision " - )), + Map.of( + Binding.TYPE, TanzuBindingsPropertiesProcessor.TYPE, + "api-key", "demo", + "uri", "https://my.openai.example.net", + "model-name", "llava1.6", + "model-capabilities", " chat , vision " + )), new Binding("test-name2", Paths.get("test-path2"), Map.of( Binding.TYPE, TanzuBindingsPropertiesProcessor.TYPE, @@ -57,11 +56,11 @@ class TanzuBindingsPropertiesProcessorTests { private final Bindings bindingsMissingModelCapabilities = new Bindings( new Binding("test-name", Paths.get("test-path"), // @formatter:off - Map.of( - Binding.TYPE, TanzuBindingsPropertiesProcessor.TYPE, - "api-key", "demo", - "uri", "https://my.openai.example.net" - ))); + Map.of( + Binding.TYPE, TanzuBindingsPropertiesProcessor.TYPE, + "api-key", "demo", + "uri", "https://my.openai.example.net" + ))); // @formatter:on private final MockEnvironment environment = new MockEnvironment(); @@ -89,7 +88,8 @@ void propertiesAreMissingModelCapabilities() { @Test void whenDisabledThenPropertiesAreNotContributed() { - this.environment.setProperty("%s.genai.enabled".formatted(CONFIG_PATH), "false"); + this.environment.setProperty( + "%s.genai.enabled".formatted(org.springframework.ai.bindings.BindingsValidator.CONFIG_PATH), "false"); new TanzuBindingsPropertiesProcessor().process(this.environment, this.bindings, this.properties); assertThat(this.properties).isEmpty(); diff --git a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/WeaviateBindingsPropertiesProcessorTests.java b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/WeaviateBindingsPropertiesProcessorTests.java index f48638ff233..d83372addb0 100644 --- a/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/WeaviateBindingsPropertiesProcessorTests.java +++ b/spring-ai-spring-cloud-bindings/src/test/java/org/springframework/ai/bindings/WeaviateBindingsPropertiesProcessorTests.java @@ -27,7 +27,6 @@ import org.springframework.mock.env.MockEnvironment; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.bindings.BindingsValidator.CONFIG_PATH; /** * Unit tests for {@link WeaviateBindingsPropertiesProcessor}. @@ -38,11 +37,11 @@ class WeaviateBindingsPropertiesProcessorTests { private final Bindings bindings = new Bindings(new Binding("test-name", Paths.get("test-path"), // @formatter:off - Map.of( - Binding.TYPE, WeaviateBindingsPropertiesProcessor.TYPE, - "uri", "https://example.net:8000", - "api-key", "demo" - ))); + Map.of( + Binding.TYPE, WeaviateBindingsPropertiesProcessor.TYPE, + "uri", "https://example.net:8000", + "api-key", "demo" + ))); // @formatter:on private final MockEnvironment environment = new MockEnvironment(); @@ -59,7 +58,9 @@ void propertiesAreContributed() { @Test void whenDisabledThenPropertiesAreNotContributed() { - this.environment.setProperty("%s.weaviate.enabled".formatted(CONFIG_PATH), "false"); + this.environment.setProperty( + "%s.weaviate.enabled".formatted(org.springframework.ai.bindings.BindingsValidator.CONFIG_PATH), + "false"); new WeaviateBindingsPropertiesProcessor().process(this.environment, this.bindings, this.properties); assertThat(this.properties).isEmpty(); diff --git a/spring-ai-test/pom.xml b/spring-ai-test/pom.xml index 3397a92ae7b..81a928a263a 100644 --- a/spring-ai-test/pom.xml +++ b/spring-ai-test/pom.xml @@ -37,6 +37,7 @@ 17 17 + false diff --git a/src/checkstyle/checkstyle.xml b/src/checkstyle/checkstyle.xml index 0a224dabb03..b84a97688c8 100644 --- a/src/checkstyle/checkstyle.xml +++ b/src/checkstyle/checkstyle.xml @@ -100,7 +100,7 @@ + value="org.springframework.ai.autoconfigure.vectorstore.observation.ObservationTestUtil.*, org.awaitility.Awaitility.*, org.springframework.ai.aot.AiRuntimeHints.*, org.springframework.ai.image.observation.ImageModelObservationDocumentation.*, org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.*, org.springframework.aot.hint.predicate.RuntimeHintsPredicates.*, org.springframework.ai.vectorstore.filter.Filter.ExpressionType.*, org.springframework.ai.chat.observation.ChatModelObservationDocumentation.*, org.assertj.core.api.AssertionsForClassTypes.*, org.junit.jupiter.api.Assertions.*, org.assertj.core.api.Assertions.*, org.junit.Assert.*, org.junit.Assume.*, org.junit.internal.matchers.ThrowableMessageMatcher.*, org.hamcrest.CoreMatchers.*, org.hamcrest.Matchers.*, org.springframework.boot.configurationprocessor.ConfigurationMetadataMatchers.*, org.springframework.boot.configurationprocessor.TestCompiler.*, org.springframework.boot.test.autoconfigure.AutoConfigurationImportedCondition.*, org.mockito.Mockito.*, org.mockito.BDDMockito.*, org.mockito.Matchers.*, org.mockito.ArgumentMatchers.*, org.springframework.restdocs.mockmvc.MockMvcRestDocumentation.*, org.springframework.restdocs.hypermedia.HypermediaDocumentation.*, org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*, org.springframework.test.web.servlet.result.MockMvcResultMatchers.*, org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.*, org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.*, org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.*, org.springframework.hateoas.mvc.ControllerLinkBuilder.linkTo, org.springframework.test.web.client.match.MockRestRequestMatchers.*, org.springframework.test.web.client.response.MockRestResponseCreators.*, org.springframework.web.reactive.function.server.RequestPredicates.*, org.springframework.web.reactive.function.server.RouterFunctions.*, org.springframework.test.web.servlet.setup.MockMvcBuilders.*" /> diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/pom.xml b/vector-stores/spring-ai-azure-cosmos-db-store/pom.xml index 8c1e57ccdea..a65ce671d1a 100644 --- a/vector-stores/spring-ai-azure-cosmos-db-store/pom.xml +++ b/vector-stores/spring-ai-azure-cosmos-db-store/pom.xml @@ -39,6 +39,7 @@ 17 17 + false diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBFilterExpressionConverter.java b/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBFilterExpressionConverter.java index 294dffd8596..7e8428ad704 100644 --- a/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBFilterExpressionConverter.java +++ b/vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBFilterExpressionConverter.java @@ -40,7 +40,7 @@ class CosmosDBFilterExpressionConverter extends AbstractFilterExpressionConverte private Map metadataFields; - public CosmosDBFilterExpressionConverter(Collection columns) { + CosmosDBFilterExpressionConverter(Collection columns) { this.metadataFields = columns.stream().collect(Collectors.toMap(Function.identity(), Function.identity())); } diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreIT.java b/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreIT.java index d8432fa71d1..97ac8081c74 100644 --- a/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreIT.java +++ b/vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreIT.java @@ -55,9 +55,7 @@ public class CosmosDBVectorStoreIT { @BeforeEach public void setup() { - this.contextRunner.run(context -> { - this.vectorStore = context.getBean(VectorStore.class); - }); + this.contextRunner.run(context -> this.vectorStore = context.getBean(VectorStore.class)); } @Test diff --git a/vector-stores/spring-ai-azure-cosmos-db-store/src/test/resources/application.properties b/vector-stores/spring-ai-azure-cosmos-db-store/src/test/resources/application.properties index 82882acded4..0ff11d471fd 100644 --- a/vector-stores/spring-ai-azure-cosmos-db-store/src/test/resources/application.properties +++ b/vector-stores/spring-ai-azure-cosmos-db-store/src/test/resources/application.properties @@ -1,20 +1,4 @@ -# -# Copyright 2023-2024 the original author or authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - spring.ai.vectorstore.cosmosdb.databaseName=db spring.ai.vectorstore.cosmosdb.containerName=container spring.ai.vectorstore.cosmosdb.key=${COSMOSDB_AI_ENDPOINT} -spring.ai.vectorstore.cosmosdb.uri=${COSMOSDB_AI_KEY} \ No newline at end of file +spring.ai.vectorstore.cosmosdb.uri=${COSMOSDB_AI_KEY} diff --git a/vector-stores/spring-ai-azure-store/pom.xml b/vector-stores/spring-ai-azure-store/pom.xml index 25bf7e5f11e..015199de527 100644 --- a/vector-stores/spring-ai-azure-store/pom.xml +++ b/vector-stores/spring-ai-azure-store/pom.xml @@ -39,6 +39,7 @@ 17 17 + false diff --git a/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureAiSearchFilterExpressionConverter.java b/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureAiSearchFilterExpressionConverter.java index ed127ea11cf..a04aa40b1ce 100644 --- a/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureAiSearchFilterExpressionConverter.java +++ b/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureAiSearchFilterExpressionConverter.java @@ -179,4 +179,4 @@ public void doEndGroup(Group group, StringBuilder context) { context.append(")"); } -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureAiSearchFilterExpressionConverterTests.java b/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureAiSearchFilterExpressionConverterTests.java index 2bd63829050..a97c83740e6 100644 --- a/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureAiSearchFilterExpressionConverterTests.java +++ b/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureAiSearchFilterExpressionConverterTests.java @@ -53,9 +53,8 @@ public void testMissingFilterName() { FilterExpressionConverter converter = new AzureAiSearchFilterExpressionConverter(List.of()); - assertThatThrownBy(() -> { - converter.convertExpression(new Expression(EQ, new Key("country"), new Value("BG"))); - }).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> converter.convertExpression(new Expression(EQ, new Key("country"), new Value("BG")))) + .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("Not allowed filter identifier name: country"); } diff --git a/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java b/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java index 03418bb2028..dc87de60fee 100644 --- a/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java +++ b/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java @@ -128,9 +128,8 @@ public void searchWithFilters() throws InterruptedException { vectorStore.add(List.of(bgDocument, nlDocument, bgDocument2)); - Awaitility.await().until(() -> { - return vectorStore.similaritySearch(SearchRequest.query("The World").withTopK(5)); - }, hasSize(3)); + Awaitility.await() + .until(() -> vectorStore.similaritySearch(SearchRequest.query("The World").withTopK(5)), hasSize(3)); List results = vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(5) diff --git a/vector-stores/spring-ai-cassandra-store/pom.xml b/vector-stores/spring-ai-cassandra-store/pom.xml index 1032f363547..ea9892b3810 100644 --- a/vector-stores/spring-ai-cassandra-store/pom.xml +++ b/vector-stores/spring-ai-cassandra-store/pom.xml @@ -39,6 +39,7 @@ 17 17 + false diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/CassandraChatMemory.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/CassandraChatMemory.java index c12c81dd65a..854afa52b4e 100644 --- a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/CassandraChatMemory.java +++ b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/CassandraChatMemory.java @@ -53,7 +53,13 @@ public final class CassandraChatMemory implements ChatMemory { final CassandraChatMemoryConfig conf; - private final PreparedStatement addUserStmt, addAssistantStmt, getStmt, deleteStmt; + private final PreparedStatement addUserStmt; + + private final PreparedStatement addAssistantStmt; + + private final PreparedStatement getStmt; + + private final PreparedStatement deleteStmt; public CassandraChatMemory(CassandraChatMemoryConfig config) { this.conf = config; @@ -71,7 +77,7 @@ public static CassandraChatMemory create(CassandraChatMemoryConfig conf) { @Override public void add(String conversationId, List messages) { final AtomicLong instantSeq = new AtomicLong(Instant.now().toEpochMilli()); - messages.forEach((msg) -> { + messages.forEach(msg -> { if (msg.getMetadata().containsKey(CONVERSATION_TS)) { msg.getMetadata().put(CONVERSATION_TS, Instant.ofEpochMilli(instantSeq.getAndIncrement())); } @@ -89,12 +95,7 @@ public void add(String sessionId, Message msg) { msg.getMetadata().putIfAbsent(CONVERSATION_TS, Instant.now()); - PreparedStatement stmt; - switch (msg.getMessageType()) { - case USER -> stmt = this.addUserStmt; - case ASSISTANT -> stmt = this.addAssistantStmt; - default -> throw new IllegalArgumentException("Cant add type " + msg); - } + PreparedStatement stmt = getStatement(msg); List primaryKeys = this.conf.primaryKeyTranslator.apply(sessionId); BoundStatementBuilder builder = stmt.boundStatementBuilder(); @@ -112,6 +113,14 @@ public void add(String sessionId, Message msg) { this.conf.session.execute(builder.build()); } + PreparedStatement getStatement(Message msg) { + return switch (msg.getMessageType()) { + case USER -> this.addUserStmt; + case ASSISTANT -> this.addAssistantStmt; + default -> throw new IllegalArgumentException("Cant add type " + msg); + }; + } + @Override public void clear(String sessionId) { diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/CassandraChatMemoryConfig.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/CassandraChatMemoryConfig.java index 3c9f329b6ec..c8470c897b4 100644 --- a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/CassandraChatMemoryConfig.java +++ b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/chat/memory/CassandraChatMemoryConfig.java @@ -213,7 +213,7 @@ public GenericType javaType() { } - public static class Builder { + public static final class Builder { private CqlSession session = null; @@ -236,7 +236,7 @@ public static class Builder { private boolean disallowSchemaChanges = false; - private SessionIdToPrimaryKeysTranslator primaryKeyTranslator = (sessionId) -> List.of(sessionId); + private SessionIdToPrimaryKeysTranslator primaryKeyTranslator = List::of; private Builder() { } diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverter.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverter.java index ddb3104092e..f1f0e8b5b61 100644 --- a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverter.java +++ b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverter.java @@ -44,10 +44,10 @@ class CassandraFilterExpressionConverter extends AbstractFilterExpressionConvert private final Map columnsByName; - public CassandraFilterExpressionConverter(Collection columns) { + CassandraFilterExpressionConverter(Collection columns) { this.columnsByName = columns.stream() - .collect(Collectors.toMap((c) -> c.getName().asInternal(), Function.identity())); + .collect(Collectors.toMap(c -> c.getName().asInternal(), Function.identity())); } private static void doOperand(ExpressionType type, StringBuilder context) { diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java index 349463646f6..3ad7f5a916d 100644 --- a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java +++ b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java @@ -201,7 +201,7 @@ public void doAdd(List documents) { for (var metadataColumn : this.conf.schema.metadataColumns() .stream() - .filter((mc) -> d.getMetadata().containsKey(mc.name())) + .filter(mc -> d.getMetadata().containsKey(mc.name())) .toList()) { builder = builder.set(metadataColumn.name(), d.getMetadata().get(metadataColumn.name()), @@ -307,11 +307,11 @@ private PreparedStatement prepareAddStatement(Set metadataFields) { // metadata fields that are not configured as metadata columns are not added Set fieldsThatAreColumns = new HashSet<>(this.conf.schema.metadataColumns() .stream() - .map((mc) -> mc.name()) - .filter((mc) -> metadataFields.contains(mc)) + .map(mc -> mc.name()) + .filter(mc -> metadataFields.contains(mc)) .toList()); - return this.addStmts.computeIfAbsent(fieldsThatAreColumns, (fields) -> { + return this.addStmts.computeIfAbsent(fieldsThatAreColumns, fields -> { RegularInsert stmt = null; InsertInto stmtStart = QueryBuilder.insertInto(this.conf.schema.keyspace(), this.conf.schema.table()); @@ -400,7 +400,7 @@ private String getSimilarityMetric() { */ public enum Similarity { - COSINE, DOT_PRODUCT, EUCLIDEAN; + COSINE, DOT_PRODUCT, EUCLIDEAN } diff --git a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java index 65bcba011de..007d2c08c3a 100644 --- a/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java +++ b/vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java @@ -66,7 +66,7 @@ * @author Mick Semb Wever * @since 1.0.0 */ -public class CassandraVectorStoreConfig implements AutoCloseable { +public final class CassandraVectorStoreConfig implements AutoCloseable { public static final String DEFAULT_KEYSPACE_NAME = "springframework"; @@ -181,7 +181,7 @@ void checkSchemaValid(int vectorDimension) { if (m.indexed()) { Preconditions.checkState( - tableMetadata.getIndexes().values().stream().anyMatch((i) -> i.getTarget().equals(m.name())), + tableMetadata.getIndexes().values().stream().anyMatch(i -> i.getTarget().equals(m.name())), "index %s does not exist", m.name()); } } @@ -189,32 +189,32 @@ void checkSchemaValid(int vectorDimension) { } private void ensureIndexesExists() { - { - SimpleStatement indexStmt = SchemaBuilder.createIndex(this.schema.index) - .ifNotExists() - .custom("StorageAttachedIndex") - .onTable(this.schema.keyspace, this.schema.table) - .andColumn(this.schema.embedding) - .build(); - - logger.debug("Executing {}", indexStmt.getQuery()); - this.session.execute(indexStmt); - } + + SimpleStatement indexStmt = SchemaBuilder.createIndex(this.schema.index) + .ifNotExists() + .custom("StorageAttachedIndex") + .onTable(this.schema.keyspace, this.schema.table) + .andColumn(this.schema.embedding) + .build(); + + logger.debug("Executing {}", indexStmt.getQuery()); + this.session.execute(indexStmt); + Stream .concat(this.schema.partitionKeys.stream(), Stream.concat(this.schema.clusteringKeys.stream(), this.schema.metadataColumns.stream())) - .filter((cs) -> cs.indexed()) - .forEach((metadata) -> { + .filter(cs -> cs.indexed()) + .forEach(metadata -> { - SimpleStatement indexStmt = SchemaBuilder.createIndex(String.format("%s_idx", metadata.name())) + SimpleStatement indexStatement = SchemaBuilder.createIndex(String.format("%s_idx", metadata.name())) .ifNotExists() .custom("StorageAttachedIndex") .onTable(this.schema.keyspace, this.schema.table) .andColumn(metadata.name()) .build(); - logger.debug("Executing {}", indexStmt.getQuery()); - this.session.execute(indexStmt); + logger.debug("Executing {}", indexStatement.getQuery()); + this.session.execute(indexStatement); }); } @@ -362,7 +362,7 @@ public boolean indexed() { } - public static class Builder { + public static final class Builder { private CqlSession session = null; @@ -485,8 +485,7 @@ public Builder addMetadataColumns(List columns) { public Builder addMetadataColumn(SchemaColumn column) { - Preconditions.checkArgument( - this.metadataColumns.stream().noneMatch((sc) -> sc.name().equals(column.name())), + Preconditions.checkArgument(this.metadataColumns.stream().noneMatch(sc -> sc.name().equals(column.name())), "A metadata column with name %s has already been added", column.name()); this.metadataColumns.add(column); @@ -532,11 +531,11 @@ public CassandraVectorStoreConfig build() { for (SchemaColumn metadata : this.metadataColumns) { Preconditions.checkArgument( - !this.partitionKeys.stream().anyMatch((c) -> c.name().equals(metadata.name())), + !this.partitionKeys.stream().anyMatch(c -> c.name().equals(metadata.name())), "metadataColumn %s cannot have same name as a partition key", metadata.name()); Preconditions.checkArgument( - !this.clusteringKeys.stream().anyMatch((c) -> c.name().equals(metadata.name())), + !this.clusteringKeys.stream().anyMatch(c -> c.name().equals(metadata.name())), "metadataColumn %s cannot have same name as a clustering key", metadata.name()); Preconditions.checkArgument(!metadata.name().equals(this.contentColumnName), @@ -546,19 +545,19 @@ public CassandraVectorStoreConfig build() { "metadataColumn %s cannot have same name as embedding column name", this.embeddingColumnName); } - { - int primaryKeyColumnsCount = this.partitionKeys.size() + this.clusteringKeys.size(); - String exampleId = this.primaryKeyTranslator.apply(Collections.emptyList()); - List testIdTranslation = this.documentIdTranslator.apply(exampleId); - Preconditions.checkArgument(testIdTranslation.size() == primaryKeyColumnsCount, - "documentIdTranslator results length %s doesn't match number of primary key columns %s", - String.valueOf(testIdTranslation.size()), String.valueOf(primaryKeyColumnsCount)); + int primaryKeyColumnsCount = this.partitionKeys.size() + this.clusteringKeys.size(); + String exampleId = this.primaryKeyTranslator.apply(Collections.emptyList()); + List testIdTranslation = this.documentIdTranslator.apply(exampleId); + + Preconditions.checkArgument(testIdTranslation.size() == primaryKeyColumnsCount, + "documentIdTranslator results length %s doesn't match number of primary key columns %s", + String.valueOf(testIdTranslation.size()), String.valueOf(primaryKeyColumnsCount)); + + Preconditions.checkArgument( + exampleId.equals(this.primaryKeyTranslator.apply(this.documentIdTranslator.apply(exampleId))), + "primaryKeyTranslator is not an inverse function to documentIdTranslator"); - Preconditions.checkArgument( - exampleId.equals(this.primaryKeyTranslator.apply(this.documentIdTranslator.apply(exampleId))), - "primaryKeyTranslator is not an inverse function to documentIdTranslator"); - } return new CassandraVectorStoreConfig(this); } diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/CassandraImage.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/CassandraImage.java index cc70cd97ad9..9bfd6eb2060 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/CassandraImage.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/CassandraImage.java @@ -21,8 +21,12 @@ /** * @author Thomas Vitale */ -public class CassandraImage { +public final class CassandraImage { public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("cassandra:5.0"); + private CassandraImage() { + + } + } diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverterTests.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverterTests.java index 21d07e4be1b..89db32064ea 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverterTests.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraFilterExpressionConverterTests.java @@ -83,9 +83,8 @@ void testNoSuchColumn() { CassandraFilterExpressionConverter filter = new CassandraFilterExpressionConverter(COLUMNS); - Assertions.assertThrows(IllegalArgumentException.class, () -> { - filter.convertExpression(new Expression(EQ, new Key("unknown_column"), new Value("BG"))); - }); + Assertions.assertThrows(IllegalArgumentException.class, + () -> filter.convertExpression(new Expression(EQ, new Key("unknown_column"), new Value("BG")))); } @Test diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraRichSchemaVectorStoreIT.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraRichSchemaVectorStoreIT.java index b7684bcc6f6..e3a3ee65feb 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraRichSchemaVectorStoreIT.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraRichSchemaVectorStoreIT.java @@ -55,7 +55,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import static java.lang.String.format; import static org.assertj.core.api.Assertions.assertThat; /** @@ -95,20 +94,18 @@ class CassandraRichSchemaVectorStoreIT { static CassandraVectorStoreConfig.Builder storeBuilder(ApplicationContext context, List columnOverrides) throws IOException { - Optional wikiOverride = columnOverrides.stream() - .filter((f) -> "wiki".equals(f.name())) - .findFirst(); + Optional wikiOverride = columnOverrides.stream().filter(f -> "wiki".equals(f.name())).findFirst(); Optional langOverride = columnOverrides.stream() - .filter((f) -> "language".equals(f.name())) + .filter(f -> "language".equals(f.name())) .findFirst(); Optional titleOverride = columnOverrides.stream() - .filter((f) -> "title".equals(f.name())) + .filter(f -> "title".equals(f.name())) .findFirst(); Optional chunkNoOverride = columnOverrides.stream() - .filter((f) -> "chunk_no".equals(f.name())) + .filter(f -> "chunk_no".equals(f.name())) .findFirst(); SchemaColumn wikiSC = wikiOverride.orElse(new SchemaColumn("wiki", DataTypes.TEXT)); @@ -138,9 +135,9 @@ static CassandraVectorStoreConfig.Builder storeBuilder(ApplicationContext contex if (primaryKeys.isEmpty()) { return "test§¶0"; } - return format("%s§¶%s", primaryKeys.get(2), primaryKeys.get(3)); + return java.lang.String.format("%s§¶%s", primaryKeys.get(2), primaryKeys.get(3)); }) - .withDocumentIdTranslator((id) -> { + .withDocumentIdTranslator(id -> { String[] parts = id.split("§¶"); String title = parts[0]; int chunk_no = 0 < parts.length ? Integer.parseInt(parts[1]) : 0; @@ -176,9 +173,8 @@ void ensureSchemaNoCreation() { executeCqlFile(context, "test_wiki_partial_3_schema.cql"); // IllegalStateException: column all_minilm_l6_v2_embedding does not exist - IllegalStateException ise = Assertions.assertThrows(IllegalStateException.class, () -> { - createStore(context, List.of(), true, false); - }); + IllegalStateException ise = Assertions.assertThrows(IllegalStateException.class, + () -> createStore(context, List.of(), true, false)); Assertions.assertEquals("column all_minilm_l6_v2_embedding does not exist", ise.getMessage()); } @@ -194,7 +190,7 @@ void ensureSchemaPartialCreation() { this.contextRunner.run(context -> { int PARTIAL_FILES = 5; for (int i = 0; i < PARTIAL_FILES; ++i) { - executeCqlFile(context, format("test_wiki_partial_%d_schema.cql", i)); + executeCqlFile(context, java.lang.String.format("test_wiki_partial_%d_schema.cql", i)); var wrapper = createStore(context, List.of(), false, false); try { Assertions.assertNotNull(wrapper.store()); @@ -208,9 +204,8 @@ void ensureSchemaPartialCreation() { } } // make sure there's not more files to test - Assertions.assertThrows(IOException.class, () -> { - executeCqlFile(context, format("test_wiki_partial_%d_schema.cql", PARTIAL_FILES)); - }); + Assertions.assertThrows(IOException.class, () -> executeCqlFile(context, + java.lang.String.format("test_wiki_partial_%d_schema.cql", PARTIAL_FILES))); }); } @@ -328,13 +323,12 @@ void searchWithPartitionFilter() throws InterruptedException { assertThat(results).hasSize(3); // cassandra server will throw an error - Assertions.assertThrows(SyntaxError.class, () -> { - store.similaritySearch(SearchRequest.query("Great Dark Spot") - .withTopK(5) - .withSimilarityThresholdAll() - .withFilterExpression( - "NOT(wiki == 'simplewiki' && language == 'en' && title == 'Neptune' && id == 1)")); - }); + Assertions.assertThrows(SyntaxError.class, + () -> store.similaritySearch(SearchRequest.query("Great Dark Spot") + .withTopK(5) + .withSimilarityThresholdAll() + .withFilterExpression( + "NOT(wiki == 'simplewiki' && language == 'en' && title == 'Neptune' && id == 1)"))); } }); } @@ -348,12 +342,11 @@ void unsearchableFilters() throws InterruptedException { List results = store.similaritySearch(SearchRequest.query("Great Dark Spot").withTopK(5)); assertThat(results).hasSize(3); - Assertions.assertThrows(InvalidQueryException.class, () -> { - store.similaritySearch(SearchRequest.query("The World") - .withTopK(5) - .withSimilarityThresholdAll() - .withFilterExpression("revision == 9385813")); - }); + Assertions.assertThrows(InvalidQueryException.class, + () -> store.similaritySearch(SearchRequest.query("The World") + .withTopK(5) + .withSimilarityThresholdAll() + .withFilterExpression("revision == 9385813"))); } }); } @@ -396,29 +389,26 @@ void searchWithFilters() throws InterruptedException { // note, it is possible to have SAI indexes on primary key columns to // achieve // e.g. searchWithFilterOnPrimaryKeys() - Assertions.assertThrows(InvalidQueryException.class, () -> { - store.similaritySearch(SearchRequest.query(URANUS_ORBIT_QUERY) - .withTopK(5) - .withSimilarityThresholdAll() - .withFilterExpression("id > 557 && \"chunk_no\" == 1")); - }); + Assertions.assertThrows(InvalidQueryException.class, + () -> store.similaritySearch(SearchRequest.query(URANUS_ORBIT_QUERY) + .withTopK(5) + .withSimilarityThresholdAll() + .withFilterExpression("id > 557 && \"chunk_no\" == 1"))); // cassandra server will throw an error, // as revision is not searchable (i.e. no SAI index on it) - Assertions.assertThrows(SyntaxError.class, () -> { - store.similaritySearch(SearchRequest.query("Great Dark Spot") - .withTopK(5) - .withSimilarityThresholdAll() - .withFilterExpression("id == 558 || revision == 2020")); - }); + Assertions.assertThrows(SyntaxError.class, + () -> store.similaritySearch(SearchRequest.query("Great Dark Spot") + .withTopK(5) + .withSimilarityThresholdAll() + .withFilterExpression("id == 558 || revision == 2020"))); // cassandra java-driver will throw an error - Assertions.assertThrows(InvalidQueryException.class, () -> { - store.similaritySearch(SearchRequest.query("Great Dark Spot") - .withTopK(5) - .withSimilarityThresholdAll() - .withFilterExpression("NOT(id == 557 || revision == 2020)")); - }); + Assertions.assertThrows(InvalidQueryException.class, + () -> store.similaritySearch(SearchRequest.query("Great Dark Spot") + .withTopK(5) + .withSimilarityThresholdAll() + .withFilterExpression("NOT(id == 557 || revision == 2020)"))); } }); } diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java index 03dd67c27dc..e17091d5ea0 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java @@ -48,8 +48,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import static java.lang.String.format; -import static java.util.Collections.emptyMap; import static org.assertj.core.api.Assertions.assertThat; /** @@ -197,7 +195,7 @@ void searchWithPartitionFilter() throws InterruptedException { var bgDocument = new Document("BG", "The World is Big and Salvation Lurks Around the Corner", Map.of("year", (short) 2020)); var nlDocument = new Document("NL", "The World is Big and Salvation Lurks Around the Corner", - emptyMap()); + java.util.Collections.emptyMap()); var bgDocument2 = new Document("BG2", "The World is Big and Salvation Lurks Around the Corner", Map.of("year", (short) 2023)); @@ -209,7 +207,8 @@ void searchWithPartitionFilter() throws InterruptedException { results = store.similaritySearch(SearchRequest.query("The World") .withTopK(5) .withSimilarityThresholdAll() - .withFilterExpression(format("%s == 'NL'", CassandraVectorStoreConfig.DEFAULT_ID_NAME))); + .withFilterExpression( + java.lang.String.format("%s == 'NL'", CassandraVectorStoreConfig.DEFAULT_ID_NAME))); assertThat(results).hasSize(1); assertThat(results.get(0).getId()).isEqualTo(nlDocument.getId()); @@ -217,7 +216,8 @@ void searchWithPartitionFilter() throws InterruptedException { results = store.similaritySearch(SearchRequest.query("The World") .withTopK(5) .withSimilarityThresholdAll() - .withFilterExpression(format("%s == 'BG2'", CassandraVectorStoreConfig.DEFAULT_ID_NAME))); + .withFilterExpression( + java.lang.String.format("%s == 'BG2'", CassandraVectorStoreConfig.DEFAULT_ID_NAME))); assertThat(results).hasSize(1); assertThat(results.get(0).getId()).isEqualTo(bgDocument2.getId()); @@ -225,26 +225,25 @@ void searchWithPartitionFilter() throws InterruptedException { results = store.similaritySearch(SearchRequest.query("The World") .withTopK(5) .withSimilarityThresholdAll() - .withFilterExpression( - format("%s == 'BG' && year == 2020", CassandraVectorStoreConfig.DEFAULT_ID_NAME))); + .withFilterExpression(java.lang.String.format("%s == 'BG' && year == 2020", + CassandraVectorStoreConfig.DEFAULT_ID_NAME))); assertThat(results).hasSize(1); assertThat(results.get(0).getId()).isEqualTo(bgDocument.getId()); // cassandra server will throw an error - Assertions.assertThrows(SyntaxError.class, () -> { - store.similaritySearch(SearchRequest.query("The World") - .withTopK(5) - .withSimilarityThresholdAll() - .withFilterExpression( - format("NOT(%s == 'BG' && year == 2020)", CassandraVectorStoreConfig.DEFAULT_ID_NAME))); - }); + Assertions.assertThrows(SyntaxError.class, + () -> store.similaritySearch(SearchRequest.query("The World") + .withTopK(5) + .withSimilarityThresholdAll() + .withFilterExpression(java.lang.String.format("NOT(%s == 'BG' && year == 2020)", + CassandraVectorStoreConfig.DEFAULT_ID_NAME)))); } }); } @Test - void unsearchableFilters() throws InterruptedException { + void unsearchableFilters() { this.contextRunner.run(context -> { try (CassandraVectorStore store = context.getBean(CassandraVectorStore.class)) { @@ -260,12 +259,11 @@ void unsearchableFilters() throws InterruptedException { List results = store.similaritySearch(SearchRequest.query("The World").withTopK(5)); assertThat(results).hasSize(3); - Assertions.assertThrows(InvalidQueryException.class, () -> { - store.similaritySearch(SearchRequest.query("The World") - .withTopK(5) - .withSimilarityThresholdAll() - .withFilterExpression("country == 'NL'")); - }); + Assertions.assertThrows(InvalidQueryException.class, + () -> store.similaritySearch(SearchRequest.query("The World") + .withTopK(5) + .withSimilarityThresholdAll() + .withFilterExpression("country == 'NL'"))); } }); } @@ -315,20 +313,18 @@ void searchWithFilters() throws InterruptedException { assertThat(results.get(0).getId()).isEqualTo(bgDocument.getId()); // cassandra server will throw an error - Assertions.assertThrows(SyntaxError.class, () -> { - store.similaritySearch(SearchRequest.query("The World") - .withTopK(5) - .withSimilarityThresholdAll() - .withFilterExpression("country == 'BG' || year == 2020")); - }); + Assertions.assertThrows(SyntaxError.class, + () -> store.similaritySearch(SearchRequest.query("The World") + .withTopK(5) + .withSimilarityThresholdAll() + .withFilterExpression("country == 'BG' || year == 2020"))); // cassandra server will throw an error - Assertions.assertThrows(SyntaxError.class, () -> { - store.similaritySearch(SearchRequest.query("The World") - .withTopK(5) - .withSimilarityThresholdAll() - .withFilterExpression("NOT(country == 'BG' && year == 2020)")); - }); + Assertions.assertThrows(SyntaxError.class, + () -> store.similaritySearch(SearchRequest.query("The World") + .withTopK(5) + .withSimilarityThresholdAll() + .withFilterExpression("NOT(country == 'BG' && year == 2020)"))); } }); } diff --git a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/WikiVectorStoreExample.java b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/WikiVectorStoreExample.java index 7189351da95..8330b3a443f 100644 --- a/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/WikiVectorStoreExample.java +++ b/vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/WikiVectorStoreExample.java @@ -34,7 +34,6 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; -import static java.lang.String.format; import static org.assertj.core.api.Assertions.assertThat; /** @@ -109,10 +108,10 @@ public CassandraVectorStore store(CqlSession cqlSession, EmbeddingModel embeddin if (primaryKeys.isEmpty()) { return "test§¶0"; } - return format("%s§¶%s", primaryKeys.get(2), primaryKeys.get(3)); + return java.lang.String.format("%s§¶%s", primaryKeys.get(2), primaryKeys.get(3)); }) - .withDocumentIdTranslator((id) -> { + .withDocumentIdTranslator(id -> { String[] parts = id.split("§¶"); String title = parts[0]; int chunk_no = 0 < parts.length ? Integer.parseInt(parts[1]) : 0; diff --git a/vector-stores/spring-ai-chroma-store/pom.xml b/vector-stores/spring-ai-chroma-store/pom.xml index 51b1fa4d89e..3e7b35c4a8b 100644 --- a/vector-stores/spring-ai-chroma-store/pom.xml +++ b/vector-stores/spring-ai-chroma-store/pom.xml @@ -89,4 +89,4 @@ - \ No newline at end of file + diff --git a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java index 7afb13c4bda..90899516ae9 100644 --- a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java +++ b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -209,7 +209,7 @@ public String getCollectionId() { public void afterPropertiesSet() throws Exception { var collection = this.chromaApi.getCollection(this.collectionName); if (collection == null) { - if (initializeSchema) { + if (this.initializeSchema) { collection = this.chromaApi .createCollection(new ChromaApi.CreateCollectionRequest(this.collectionName)); } diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/ChromaImage.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/ChromaImage.java index 51208d3d9cd..ca51285ddec 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/ChromaImage.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/ChromaImage.java @@ -21,8 +21,12 @@ /** * @author Thomas Vitale */ -public class ChromaImage { +public final class ChromaImage { public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.11"); + private ChromaImage() { + + } + } diff --git a/vector-stores/spring-ai-elasticsearch-store/pom.xml b/vector-stores/spring-ai-elasticsearch-store/pom.xml index 89cb9e6e6de..824842bd002 100644 --- a/vector-stores/spring-ai-elasticsearch-store/pom.xml +++ b/vector-stores/spring-ai-elasticsearch-store/pom.xml @@ -37,6 +37,9 @@ + 17 + 17 + false 4.0.3 diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchAiSearchFilterExpressionConverter.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchAiSearchFilterExpressionConverter.java index e7b2c5a01ae..4a7dc588fde 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchAiSearchFilterExpressionConverter.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchAiSearchFilterExpressionConverter.java @@ -152,4 +152,4 @@ public void doEndGroup(Filter.Group group, StringBuilder context) { context.append(")"); } -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java index 32731f754c4..9c058936b44 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/main/java/org/springframework/ai/vectorstore/ElasticsearchVectorStore.java @@ -56,8 +56,6 @@ import org.springframework.beans.factory.InitializingBean; import org.springframework.util.Assert; -import static java.lang.Math.sqrt; - /** * The ElasticsearchVectorStore class implements the VectorStore interface and provides * functionality for managing and querying documents in Elasticsearch. It uses an @@ -225,7 +223,7 @@ private float calculateDistance(Float score) { // (closest to zero means more accurate), so to make it consistent // with the other functions the reverse is returned applying a "1-" // to the standard transformation - return (float) (1 - (sqrt((1 / score) - 1))); + return (float) (1 - (java.lang.Math.sqrt((1 / score) - 1))); // cosine and dot_product default: return (2 * score) - 1; diff --git a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchImage.java b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchImage.java index db8b68f3b8e..b0a178aca11 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchImage.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchImage.java @@ -21,9 +21,13 @@ /** * @author Thomas Vitale */ -public class ElasticsearchImage { +public final class ElasticsearchImage { public static final DockerImageName DEFAULT_IMAGE = DockerImageName .parse("docker.elastic.co/elasticsearch/elasticsearch:8.15.2"); + private ElasticsearchImage() { + + } + } diff --git a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreObservationIT.java b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreObservationIT.java index 7040b97a600..1d074fba087 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/ElasticsearchVectorStoreObservationIT.java @@ -64,8 +64,6 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.Matchers.greaterThan; -; - /** * @author Christian Tzolov * @author Thomas Vitale diff --git a/vector-stores/spring-ai-gemfire-store/pom.xml b/vector-stores/spring-ai-gemfire-store/pom.xml index 25ab313c88b..2d1f84a66c1 100644 --- a/vector-stores/spring-ai-gemfire-store/pom.xml +++ b/vector-stores/spring-ai-gemfire-store/pom.xml @@ -36,10 +36,11 @@ git@github.com:spring-projects/spring-ai.git - - 17 - 17 - + + 17 + 17 + false + diff --git a/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java b/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java index 7f3390d6d2d..753eb1e7bdb 100644 --- a/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java +++ b/vector-stores/spring-ai-gemfire-store/src/main/java/org/springframework/ai/vectorstore/GemFireVectorStore.java @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -52,9 +52,6 @@ import org.springframework.web.reactive.function.client.WebClientResponseException; import org.springframework.web.util.UriComponentsBuilder; -import static org.springframework.http.HttpStatus.BAD_REQUEST; -import static org.springframework.http.HttpStatus.NOT_FOUND; - /** * A VectorStore implementation backed by GemFire. This store supports creating, updating, * deleting, and similarity searching of documents in a GemFire index. @@ -139,37 +136,37 @@ public GemFireVectorStore(GemFireVectorStoreConfig config, EmbeddingModel embedd private String indexName; public String getIndexName() { - return indexName; + return this.indexName; } private int beamWidth; public int getBeamWidth() { - return beamWidth; + return this.beamWidth; } private int maxConnections; public int getMaxConnections() { - return maxConnections; + return this.maxConnections; } private int buckets; public int getBuckets() { - return buckets; + return this.buckets; } private String vectorSimilarityFunction; public String getVectorSimilarityFunction() { - return vectorSimilarityFunction; + return this.vectorSimilarityFunction; } private String[] fields; public String[] getFields() { - return fields; + return this.fields; } // Query Defaults @@ -202,7 +199,150 @@ public boolean indexExists() { } public String getIndex() { - return client.get().uri("/" + indexName).retrieve().bodyToMono(String.class).onErrorReturn("").block(); + return this.client.get() + .uri("/" + this.indexName) + .retrieve() + .bodyToMono(String.class) + .onErrorReturn("") + .block(); + } + + @Override + public void doAdd(List documents) { + this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + UploadRequest upload = new UploadRequest(documents.stream() + .map(document -> new UploadRequest.Embedding(document.getId(), document.getEmbedding(), DOCUMENT_FIELD, + document.getContent(), document.getMetadata())) + .toList()); + + String embeddingsJson = null; + try { + String embeddingString = this.objectMapper.writeValueAsString(upload); + embeddingsJson = embeddingString.substring("{\"embeddings\":".length()); + } + catch (JsonProcessingException e) { + throw new RuntimeException(String.format("Embedding JSON parsing error: %s", e.getMessage())); + } + + this.client.post() + .uri("/" + this.indexName + EMBEDDINGS) + .contentType(MediaType.APPLICATION_JSON) + .bodyValue(embeddingsJson) + .retrieve() + .bodyToMono(Void.class) + .onErrorMap(WebClientException.class, this::handleHttpClientException) + .block(); + } + + @Override + public Optional doDelete(List idList) { + try { + this.client.method(HttpMethod.DELETE) + .uri("/" + this.indexName + EMBEDDINGS) + .body(BodyInserters.fromValue(idList)) + .retrieve() + .bodyToMono(Void.class) + .block(); + } + catch (Exception e) { + logger.warn("Error removing embedding: {}", e.getMessage(), e); + return Optional.of(false); + } + return Optional.of(true); + } + + @Override + public List doSimilaritySearch(SearchRequest request) { + if (request.hasFilterExpression()) { + throw new UnsupportedOperationException("GemFire currently does not support metadata filter expressions."); + } + float[] floatVector = this.embeddingModel.embed(request.getQuery()); + return this.client.post() + .uri("/" + this.indexName + QUERY) + .contentType(MediaType.APPLICATION_JSON) + .bodyValue(new QueryRequest(floatVector, request.getTopK(), request.getTopK(), // TopKPerBucket + true)) + .retrieve() + .bodyToFlux(QueryResponse.class) + .filter(r -> r.score >= request.getSimilarityThreshold()) + .map(r -> { + Map metadata = r.metadata; + if (r.metadata == null) { + metadata = new HashMap<>(); + metadata.put(DOCUMENT_FIELD, "--Deleted--"); + } + metadata.put(DISTANCE_METADATA_FIELD_NAME, 1 - r.score); + String content = (String) metadata.remove(DOCUMENT_FIELD); + return new Document(r.key, content, metadata); + }) + .collectList() + .onErrorMap(WebClientException.class, this::handleHttpClientException) + .block(); + } + + /** + * Creates a new index in the GemFireVectorStore using specified parameters. This + * method is invoked during initialization. + * @throws JsonProcessingException if an error occurs during JSON processing + */ + public void createIndex() throws JsonProcessingException { + CreateRequest createRequest = new CreateRequest(this.indexName); + createRequest.setBeamWidth(this.beamWidth); + createRequest.setMaxConnections(this.maxConnections); + createRequest.setBuckets(this.buckets); + createRequest.setVectorSimilarityFunction(this.vectorSimilarityFunction); + createRequest.setFields(this.fields); + + String index = this.objectMapper.writeValueAsString(createRequest); + + this.client.post() + .contentType(MediaType.APPLICATION_JSON) + .bodyValue(index) + .retrieve() + .bodyToMono(Void.class) + .onErrorMap(WebClientException.class, this::handleHttpClientException) + .block(); + } + + public void deleteIndex() { + DeleteRequest deleteRequest = new DeleteRequest(); + this.client.method(HttpMethod.DELETE) + .uri("/" + this.indexName) + .body(BodyInserters.fromValue(deleteRequest)) + .retrieve() + .bodyToMono(Void.class) + .onErrorMap(WebClientException.class, this::handleHttpClientException) + .block(); + } + + /** + * Handles exceptions that occur during HTTP client operations and maps them to + * appropriate runtime exceptions. + * @param ex the exception that occurred during HTTP client operation + * @return a mapped runtime exception corresponding to the HTTP client exception + */ + private Throwable handleHttpClientException(Throwable ex) { + if (!(ex instanceof WebClientResponseException clientException)) { + throw new RuntimeException(String.format("Got an unexpected error: %s", ex)); + } + + if (clientException.getStatusCode().equals(org.springframework.http.HttpStatus.NOT_FOUND)) { + throw new RuntimeException(String.format("Index %s not found: %s", this.indexName, ex)); + } + else if (clientException.getStatusCode().equals(org.springframework.http.HttpStatus.BAD_REQUEST)) { + throw new RuntimeException(String.format("Bad Request: %s", ex)); + } + else { + throw new RuntimeException(String.format("Got an unexpected HTTP error: %s", ex)); + } + } + + @Override + public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) { + return VectorStoreObservationContext.builder(VectorStoreProvider.GEMFIRE.value(), operationName) + .withCollectionName(this.indexName) + .withDimensions(this.embeddingModel.dimensions()) + .withFieldName(EMBEDDINGS); } public static class CreateRequest { @@ -233,7 +373,7 @@ public CreateRequest(String indexName) { } public String getIndexName() { - return indexName; + return this.indexName; } public void setIndexName(String indexName) { @@ -241,7 +381,7 @@ public void setIndexName(String indexName) { } public int getBeamWidth() { - return beamWidth; + return this.beamWidth; } public void setBeamWidth(int beamWidth) { @@ -249,7 +389,7 @@ public void setBeamWidth(int beamWidth) { } public int getMaxConnections() { - return maxConnections; + return this.maxConnections; } public void setMaxConnections(int maxConnections) { @@ -257,7 +397,7 @@ public void setMaxConnections(int maxConnections) { } public String getVectorSimilarityFunction() { - return vectorSimilarityFunction; + return this.vectorSimilarityFunction; } public void setVectorSimilarityFunction(String vectorSimilarityFunction) { @@ -265,7 +405,7 @@ public void setVectorSimilarityFunction(String vectorSimilarityFunction) { } public String[] getFields() { - return fields; + return this.fields; } public void setFields(String[] fields) { @@ -273,7 +413,7 @@ public void setFields(String[] fields) { } public int getBuckets() { - return buckets; + return this.buckets; } public void setBuckets(int buckets) { @@ -287,11 +427,11 @@ private static final class UploadRequest { private final List embeddings; public List getEmbeddings() { - return embeddings; + return this.embeddings; } @JsonCreator - public UploadRequest(@JsonProperty("embeddings") List embeddings) { + UploadRequest(@JsonProperty("embeddings") List embeddings) { this.embeddings = embeddings; } @@ -304,8 +444,8 @@ private static final class Embedding { @JsonInclude(JsonInclude.Include.NON_NULL) private Map metadata; - public Embedding(@JsonProperty("key") String key, @JsonProperty("vector") float[] vector, - String contentName, String content, @JsonProperty("metadata") Map metadata) { + Embedding(@JsonProperty("key") String key, @JsonProperty("vector") float[] vector, String contentName, + String content, @JsonProperty("metadata") Map metadata) { this.key = key; this.vector = vector; this.metadata = new HashMap<>(metadata); @@ -313,15 +453,15 @@ public Embedding(@JsonProperty("key") String key, @JsonProperty("vector") float[ } public String getKey() { - return key; + return this.key; } public float[] getVector() { - return vector; + return this.vector; } public Map getMetadata() { - return metadata; + return this.metadata; } } @@ -343,7 +483,7 @@ private static final class QueryRequest { @JsonProperty("include-metadata") private final boolean includeMetadata; - public QueryRequest(float[] vector, int k, int kPerBucket, boolean includeMetadata) { + QueryRequest(float[] vector, int k, int kPerBucket, boolean includeMetadata) { this.vector = vector; this.k = k; this.kPerBucket = kPerBucket; @@ -351,19 +491,19 @@ public QueryRequest(float[] vector, int k, int kPerBucket, boolean includeMetada } public float[] getVector() { - return vector; + return this.vector; } public int getK() { - return k; + return this.k; } public int getkPerBucket() { - return kPerBucket; + return this.kPerBucket; } public boolean isIncludeMetadata() { - return includeMetadata; + return this.includeMetadata; } } @@ -377,7 +517,7 @@ private static final class QueryResponse { private Map metadata; private String getContent(String field) { - return (String) metadata.get(field); + return (String) this.metadata.get(field); } public void setKey(String key) { @@ -399,15 +539,15 @@ private static class DeleteRequest { @JsonProperty("delete-data") private boolean deleteData = true; - public DeleteRequest() { + DeleteRequest() { } - public DeleteRequest(boolean deleteData) { + DeleteRequest(boolean deleteData) { this.deleteData = deleteData; } public boolean isDeleteData() { - return deleteData; + return this.deleteData; } public void setDeleteData(boolean deleteData) { @@ -416,145 +556,7 @@ public void setDeleteData(boolean deleteData) { } - @Override - public void doAdd(List documents) { - this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); - UploadRequest upload = new UploadRequest(documents.stream() - .map(document -> new UploadRequest.Embedding(document.getId(), document.getEmbedding(), DOCUMENT_FIELD, - document.getContent(), document.getMetadata())) - .toList()); - - String embeddingsJson = null; - try { - String embeddingString = this.objectMapper.writeValueAsString(upload); - embeddingsJson = embeddingString.substring("{\"embeddings\":".length()); - } - catch (JsonProcessingException e) { - throw new RuntimeException(String.format("Embedding JSON parsing error: %s", e.getMessage())); - } - - client.post() - .uri("/" + indexName + EMBEDDINGS) - .contentType(MediaType.APPLICATION_JSON) - .bodyValue(embeddingsJson) - .retrieve() - .bodyToMono(Void.class) - .onErrorMap(WebClientException.class, this::handleHttpClientException) - .block(); - } - - @Override - public Optional doDelete(List idList) { - try { - client.method(HttpMethod.DELETE) - .uri("/" + indexName + EMBEDDINGS) - .body(BodyInserters.fromValue(idList)) - .retrieve() - .bodyToMono(Void.class) - .block(); - } - catch (Exception e) { - logger.warn("Error removing embedding: {}", e.getMessage(), e); - return Optional.of(false); - } - return Optional.of(true); - } - - @Override - public List doSimilaritySearch(SearchRequest request) { - if (request.hasFilterExpression()) { - throw new UnsupportedOperationException("GemFire currently does not support metadata filter expressions."); - } - float[] floatVector = this.embeddingModel.embed(request.getQuery()); - return client.post() - .uri("/" + indexName + QUERY) - .contentType(MediaType.APPLICATION_JSON) - .bodyValue(new QueryRequest(floatVector, request.getTopK(), request.getTopK(), // TopKPerBucket - true)) - .retrieve() - .bodyToFlux(QueryResponse.class) - .filter(r -> r.score >= request.getSimilarityThreshold()) - .map(r -> { - Map metadata = r.metadata; - if (r.metadata == null) { - metadata = new HashMap<>(); - metadata.put(DOCUMENT_FIELD, "--Deleted--"); - } - metadata.put(DISTANCE_METADATA_FIELD_NAME, 1 - r.score); - String content = (String) metadata.remove(DOCUMENT_FIELD); - return new Document(r.key, content, metadata); - }) - .collectList() - .onErrorMap(WebClientException.class, this::handleHttpClientException) - .block(); - } - - /** - * Creates a new index in the GemFireVectorStore using specified parameters. This - * method is invoked during initialization. - * @throws JsonProcessingException if an error occurs during JSON processing - */ - public void createIndex() throws JsonProcessingException { - CreateRequest createRequest = new CreateRequest(indexName); - createRequest.setBeamWidth(beamWidth); - createRequest.setMaxConnections(maxConnections); - createRequest.setBuckets(buckets); - createRequest.setVectorSimilarityFunction(vectorSimilarityFunction); - createRequest.setFields(fields); - - String index = this.objectMapper.writeValueAsString(createRequest); - - client.post() - .contentType(MediaType.APPLICATION_JSON) - .bodyValue(index) - .retrieve() - .bodyToMono(Void.class) - .onErrorMap(WebClientException.class, this::handleHttpClientException) - .block(); - } - - public void deleteIndex() { - DeleteRequest deleteRequest = new DeleteRequest(); - client.method(HttpMethod.DELETE) - .uri("/" + indexName) - .body(BodyInserters.fromValue(deleteRequest)) - .retrieve() - .bodyToMono(Void.class) - .onErrorMap(WebClientException.class, this::handleHttpClientException) - .block(); - } - - /** - * Handles exceptions that occur during HTTP client operations and maps them to - * appropriate runtime exceptions. - * @param ex the exception that occurred during HTTP client operation - * @return a mapped runtime exception corresponding to the HTTP client exception - */ - private Throwable handleHttpClientException(Throwable ex) { - if (!(ex instanceof WebClientResponseException clientException)) { - throw new RuntimeException(String.format("Got an unexpected error: %s", ex)); - } - - if (clientException.getStatusCode().equals(NOT_FOUND)) { - throw new RuntimeException(String.format("Index %s not found: %s", indexName, ex)); - } - else if (clientException.getStatusCode().equals(BAD_REQUEST)) { - throw new RuntimeException(String.format("Bad Request: %s", ex)); - } - else { - throw new RuntimeException(String.format("Got an unexpected HTTP error: %s", ex)); - } - } - - @Override - public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) { - return VectorStoreObservationContext.builder(VectorStoreProvider.GEMFIRE.value(), operationName) - .withCollectionName(this.indexName) - .withDimensions(this.embeddingModel.dimensions()) - .withFieldName(EMBEDDINGS); - } - - public static class GemFireVectorStoreConfig { + public static final class GemFireVectorStoreConfig { // Create Index DEFAULT Values public static final String DEFAULT_HOST = "localhost"; diff --git a/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireImage.java b/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireImage.java index 3d204767a50..f267634a249 100644 --- a/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireImage.java +++ b/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireImage.java @@ -21,8 +21,12 @@ /** * @author Thomas Vitale */ -public class GemFireImage { +public final class GemFireImage { public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("gemfire/gemfire-all:10.1-jdk17"); + private GemFireImage() { + + } + } diff --git a/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java b/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java index 86c71724a66..972352f6a31 100644 --- a/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java +++ b/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreIT.java @@ -41,7 +41,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import static java.util.concurrent.TimeUnit.MINUTES; import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.Matchers.hasSize; @@ -110,7 +109,7 @@ public void addAndDeleteEmbeddingTest() { vectorStore.add(this.documents); vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); Awaitility.await() - .atMost(1, MINUTES) + .atMost(1, java.util.concurrent.TimeUnit.MINUTES) .until(() -> vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(3)), hasSize(0)); }); @@ -123,7 +122,7 @@ public void addAndSearchTest() { vectorStore.add(this.documents); Awaitility.await() - .atMost(1, MINUTES) + .atMost(1, java.util.concurrent.TimeUnit.MINUTES) .until(() -> vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1)), hasSize(1)); @@ -147,7 +146,7 @@ public void documentUpdateTest() { vectorStore.add(List.of(document)); SearchRequest springSearchRequest = SearchRequest.query("Spring").withTopK(5); Awaitility.await() - .atMost(1, MINUTES) + .atMost(1, java.util.concurrent.TimeUnit.MINUTES) .until(() -> vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1)), hasSize(1)); List results = vectorStore.similaritySearch(springSearchRequest); @@ -182,7 +181,7 @@ public void searchThresholdTest() { vectorStore.add(this.documents); Awaitility.await() - .atMost(1, MINUTES) + .atMost(1, java.util.concurrent.TimeUnit.MINUTES) .until(() -> vectorStore .similaritySearch(SearchRequest.query("Great Depression").withTopK(5).withSimilarityThresholdAll()), hasSize(3)); diff --git a/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreObservationIT.java b/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreObservationIT.java index abf2374c068..47e924a24e3 100644 --- a/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-gemfire-store/src/test/java/org/springframework/ai/vectorstore/GemFireVectorStoreObservationIT.java @@ -48,7 +48,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; -import static java.util.concurrent.TimeUnit.MINUTES; import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.Matchers.hasSize; @@ -145,7 +144,7 @@ void observationVectorStoreAddAndQueryOperations() { .hasBeenStopped(); Awaitility.await() - .atMost(1, MINUTES) + .atMost(1, java.util.concurrent.TimeUnit.MINUTES) .until(() -> vectorStore .similaritySearch(SearchRequest.query("Great Depression").withTopK(5).withSimilarityThresholdAll()), hasSize(3)); diff --git a/vector-stores/spring-ai-hanadb-store/pom.xml b/vector-stores/spring-ai-hanadb-store/pom.xml index b794e916490..0e72805977f 100644 --- a/vector-stores/spring-ai-hanadb-store/pom.xml +++ b/vector-stores/spring-ai-hanadb-store/pom.xml @@ -37,6 +37,12 @@ git@github.com:spring-projects/spring-ai.git + + 17 + 17 + false + + org.springframework.ai @@ -94,4 +100,4 @@ test - \ No newline at end of file + diff --git a/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaCloudVectorStore.java b/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaCloudVectorStore.java index 89fbf5e273c..a0d4ddc9c60 100644 --- a/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaCloudVectorStore.java +++ b/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaCloudVectorStore.java @@ -1,11 +1,11 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.vectorstore; import java.util.Collections; @@ -110,27 +111,27 @@ public void doAdd(List documents) { document.getId()); String content = document.getContent().replaceAll("\\s+", " "); String embedding = getEmbedding(document); - repository.save(config.getTableName(), document.getId(), embedding, content); + this.repository.save(this.config.getTableName(), document.getId(), embedding, content); } logger.info("Embeddings saved in HanaCloudVectorStore for {} documents", count - 1); } @Override public Optional doDelete(List idList) { - int deleteCount = repository.deleteEmbeddingsById(config.getTableName(), idList); + int deleteCount = this.repository.deleteEmbeddingsById(this.config.getTableName(), idList); logger.info("{} embeddings deleted", deleteCount); return Optional.of(deleteCount == idList.size()); } public int purgeEmbeddings() { - int deleteCount = repository.deleteAllEmbeddings(config.getTableName()); + int deleteCount = this.repository.deleteAllEmbeddings(this.config.getTableName()); logger.info("{} embeddings deleted", deleteCount); return deleteCount; } @Override public List similaritySearch(String query) { - return similaritySearch(SearchRequest.query(query).withTopK(config.getTopK())); + return similaritySearch(SearchRequest.query(query).withTopK(this.config.getTopK())); } @Override @@ -141,8 +142,8 @@ public List doSimilaritySearch(SearchRequest request) { } String queryEmbedding = getEmbedding(request); - List searchResult = repository.cosineSimilaritySearch(config.getTableName(), - request.getTopK(), queryEmbedding); + List searchResult = this.repository + .cosineSimilaritySearch(this.config.getTableName(), request.getTopK(), queryEmbedding); logger.info("Hana cosine-similarity for query={}, with topK={} returned {} results", request.getQuery(), request.getTopK(), searchResult.size()); diff --git a/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaCloudVectorStoreConfig.java b/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaCloudVectorStoreConfig.java index b8b8faff00d..bc56fe9a13f 100644 --- a/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaCloudVectorStoreConfig.java +++ b/vector-stores/spring-ai-hanadb-store/src/main/java/org/springframework/ai/vectorstore/HanaCloudVectorStoreConfig.java @@ -24,7 +24,7 @@ * @author Rahul Mittal * @since 1.0.0 */ -public class HanaCloudVectorStoreConfig { +public final class HanaCloudVectorStoreConfig { private String tableName; diff --git a/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/CricketWorldCupHanaController.java b/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/CricketWorldCupHanaController.java index 48d97dd2680..3a10995323b 100644 --- a/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/CricketWorldCupHanaController.java +++ b/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/CricketWorldCupHanaController.java @@ -82,7 +82,7 @@ public ResponseEntity handleFileUpload(@RequestParam("pdf") MultipartFil } @GetMapping("/ai/hana-vector-store/cricket-world-cup") - public Map hanaVectorStoreSearch(@RequestParam(value = "message") String message) { + public Map hanaVectorStoreSearch(@RequestParam("message") String message) { var documents = this.hanaCloudVectorStore.similaritySearch(message); var inlined = documents.stream().map(Document::getContent).collect(Collectors.joining(System.lineSeparator())); var similarDocsMessage = new SystemPromptTemplate("Based on the following: {documents}") diff --git a/vector-stores/spring-ai-hanadb-store/src/test/resources/application.properties b/vector-stores/spring-ai-hanadb-store/src/test/resources/application.properties index f2d9b9274ad..da05975c9c7 100644 --- a/vector-stores/spring-ai-hanadb-store/src/test/resources/application.properties +++ b/vector-stores/spring-ai-hanadb-store/src/test/resources/application.properties @@ -1,28 +1,10 @@ -# -# Copyright 2023-2024 the original author or authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - spring.ai.openai.api-key=${OPENAI_API_KEY} spring.ai.openai.embedding.options.model=text-embedding-ada-002 - - spring.datasource.driver-class-name=com.sap.db.jdbc.Driver spring.datasource.url=${HANA_DATASOURCE_URL} spring.datasource.username=${HANA_DATASOURCE_USERNAME} spring.datasource.password=${HANA_DATASOURCE_PASSWORD} spring.ai.vectorstore.hanadb.tableName=CRICKET_WORLD_CUP -spring.ai.vectorstore.hanadb.topK=3 \ No newline at end of file +spring.ai.vectorstore.hanadb.topK=3 diff --git a/vector-stores/spring-ai-milvus-store/pom.xml b/vector-stores/spring-ai-milvus-store/pom.xml index bdf72f777ab..1a0029d032e 100644 --- a/vector-stores/spring-ai-milvus-store/pom.xml +++ b/vector-stores/spring-ai-milvus-store/pom.xml @@ -36,6 +36,12 @@ git@github.com:spring-projects/spring-ai.git + + 17 + 17 + false + + org.springframework.ai diff --git a/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusFilterExpressionConverter.java b/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusFilterExpressionConverter.java index c7e9a093999..fd8cde2d49a 100644 --- a/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusFilterExpressionConverter.java +++ b/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusFilterExpressionConverter.java @@ -75,4 +75,4 @@ protected void doKey(Key key, StringBuilder context) { context.append("metadata[\"" + identifier + "\"]"); } -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusVectorStore.java b/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusVectorStore.java index fe91142344c..27e743dd133 100644 --- a/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusVectorStore.java +++ b/vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/MilvusVectorStore.java @@ -429,7 +429,7 @@ private String getSimilarityMetric() { /** * Configuration for the Milvus vector store. */ - public static class MilvusVectorStoreConfig { + public static final class MilvusVectorStoreConfig { private final String databaseName; @@ -468,7 +468,7 @@ public static MilvusVectorStoreConfig defaultConfig() { return builder().build(); } - public static class Builder { + public static final class Builder { private String databaseName = DEFAULT_DATABASE_NAME; diff --git a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusEmbeddingDimensionsTests.java b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusEmbeddingDimensionsTests.java index bc60af7fa23..0e0dbf5b586 100644 --- a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusEmbeddingDimensionsTests.java +++ b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusEmbeddingDimensionsTests.java @@ -31,10 +31,10 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.never; import static org.mockito.Mockito.only; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; /** * @author Christian Tzolov @@ -68,7 +68,7 @@ public void explicitlySetDimensions() { @Test public void embeddingModelDimensions() { - when(this.embeddingModel.dimensions()).thenReturn(969); + given(this.embeddingModel.dimensions()).willReturn(969); MilvusVectorStoreConfig config = MilvusVectorStoreConfig.builder().build(); @@ -84,7 +84,7 @@ public void embeddingModelDimensions() { @Test public void fallBackToDefaultDimensions() { - when(this.embeddingModel.dimensions()).thenThrow(new RuntimeException()); + given(this.embeddingModel.dimensions()).willThrow(new RuntimeException()); var dim = new MilvusVectorStore(this.milvusClient, this.embeddingModel, MilvusVectorStoreConfig.builder().build(), true, new TokenCountBatchingStrategy()) diff --git a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusImage.java b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusImage.java index 8212474bd77..f06b6667790 100644 --- a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusImage.java +++ b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusImage.java @@ -21,8 +21,12 @@ /** * @author Thomas Vitale */ -public class MilvusImage { +public final class MilvusImage { public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("milvusdb/milvus:v2.4.9"); + private MilvusImage() { + + } + } diff --git a/vector-stores/spring-ai-mongodb-atlas-store/pom.xml b/vector-stores/spring-ai-mongodb-atlas-store/pom.xml index 48689d36d8e..753b5f9c8e5 100644 --- a/vector-stores/spring-ai-mongodb-atlas-store/pom.xml +++ b/vector-stores/spring-ai-mongodb-atlas-store/pom.xml @@ -35,6 +35,11 @@ git://github.com/spring-projects/spring-ai.git git@github.com:spring-projects/spring-ai.git + + 17 + 17 + false + diff --git a/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStore.java b/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStore.java index 0cb9f974bbd..b2d888ee65c 100644 --- a/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStore.java +++ b/vector-stores/spring-ai-mongodb-atlas-store/src/main/java/org/springframework/ai/vectorstore/MongoDBAtlasVectorStore.java @@ -43,8 +43,6 @@ import org.springframework.data.mongodb.core.query.Query; import org.springframework.util.Assert; -import static org.springframework.data.mongodb.core.query.Criteria.where; - /** * @author Chris Smith * @author Soby Chacko @@ -190,7 +188,7 @@ public void doAdd(List documents) { @Override public Optional doDelete(List idList) { - Query query = new Query(where(ID_FIELD_NAME).in(idList)); + Query query = new Query(org.springframework.data.mongodb.core.query.Criteria.where(ID_FIELD_NAME).in(idList)); var deleteRes = this.mongoTemplate.remove(query, this.config.collectionName); long deleteCount = deleteRes.getDeletedCount(); @@ -236,7 +234,7 @@ public VectorStoreObservationContext.Builder createObservationContextBuilder(Str .withFieldName(this.config.pathName); } - public static class MongoDBVectorStoreConfig { + public static final class MongoDBVectorStoreConfig { private final String collectionName; @@ -264,7 +262,7 @@ public static MongoDBVectorStoreConfig defaultConfig() { return builder().build(); } - public static class Builder { + public static final class Builder { private String collectionName = DEFAULT_VECTOR_COLLECTION_NAME; diff --git a/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDbImage.java b/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDbImage.java index 946e813e4ce..9a07309bcc8 100644 --- a/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDbImage.java +++ b/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/MongoDbImage.java @@ -21,8 +21,12 @@ /** * @author Thomas Vitale */ -public class MongoDbImage { +public final class MongoDbImage { public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("mongodb/mongodb-atlas-local:8.0.0"); + private MongoDbImage() { + + } + } diff --git a/vector-stores/spring-ai-neo4j-store/pom.xml b/vector-stores/spring-ai-neo4j-store/pom.xml index 913277180f1..29697aaeb3b 100644 --- a/vector-stores/spring-ai-neo4j-store/pom.xml +++ b/vector-stores/spring-ai-neo4j-store/pom.xml @@ -36,6 +36,12 @@ git@github.com:spring-projects/spring-ai.git + + 17 + 17 + false + + diff --git a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/Neo4jVectorStore.java b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/Neo4jVectorStore.java index 6d938a80194..21f2e8c8a6e 100644 --- a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/Neo4jVectorStore.java +++ b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/Neo4jVectorStore.java @@ -132,7 +132,7 @@ public Optional doDelete(List idList) { .run(""" MATCH (n:%s) WHERE n.%s IN $ids CALL { WITH n DETACH DELETE n } IN TRANSACTIONS OF $transactionSize ROWS - """.formatted(this.config.label, this.config.idProperty), + """.formatted(this.config.label, this.config.idProperty), Map.of("ids", idList, "transactionSize", 10_000)) .consume(); return Optional.of(idList.size() == summary.counters().nodesDeleted()); @@ -182,8 +182,8 @@ public void afterPropertiesSet() { var statement = """ CREATE VECTOR INDEX %s IF NOT EXISTS FOR (n:%s) ON (n.%s) OPTIONS {indexConfig: { - `vector.dimensions`: %d, - `vector.similarity_function`: '%s' + `vector.dimensions`: %d, + `vector.similarity_function`: '%s' }} """.formatted(this.config.indexName, this.config.label, this.config.embeddingProperty, this.config.embeddingDimension, this.config.distanceType.name); @@ -312,7 +312,7 @@ public static Neo4jVectorStoreConfig defaultConfig() { return builder().build(); } - public static class Builder { + public static final class Builder { private String databaseName; diff --git a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/filter/Neo4jVectorFilterExpressionConverter.java b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/filter/Neo4jVectorFilterExpressionConverter.java index 7321699998d..747c0807a5a 100644 --- a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/filter/Neo4jVectorFilterExpressionConverter.java +++ b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/filter/Neo4jVectorFilterExpressionConverter.java @@ -88,4 +88,4 @@ protected void doEndGroup(Group group, StringBuilder context) { context.append(")"); } -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jImage.java b/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jImage.java index 513fd69433b..1773f4eb891 100644 --- a/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jImage.java +++ b/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jImage.java @@ -21,10 +21,14 @@ /** * @author Thomas Vitale */ -public class Neo4jImage { +public final class Neo4jImage { // Needs to be Neo4j 5.15+ because Neo4j 5.15 deprecated the old vector index creation // function. public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("neo4j:5.24"); + private Neo4jImage() { + + } + } diff --git a/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jVectorStoreIT.java b/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jVectorStoreIT.java index 3aa43823913..a707a30622a 100644 --- a/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jVectorStoreIT.java +++ b/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/Neo4jVectorStoreIT.java @@ -257,32 +257,28 @@ void searchThresholdTest() { @Test void ensureVectorIndexGetsCreated() { - this.contextRunner.run(context -> { - assertThat(context.getBean(Driver.class) - .executableQuery( - "SHOW indexes yield name, type WHERE name = 'spring-ai-document-index' AND type = 'VECTOR' return count(*) > 0") - .execute() - .records() - .get(0) // get first record - .get(0) - .asBoolean()) // get returned result - .isTrue(); - }); + this.contextRunner.run(context -> assertThat(context.getBean(Driver.class) + .executableQuery( + "SHOW indexes yield name, type WHERE name = 'spring-ai-document-index' AND type = 'VECTOR' return count(*) > 0") + .execute() + .records() + .get(0) // get first record + .get(0) + .asBoolean()) // get returned result + .isTrue()); } @Test void ensureIdIndexGetsCreated() { - this.contextRunner.run(context -> { - assertThat(context.getBean(Driver.class) - .executableQuery( - "SHOW indexes yield labelsOrTypes, properties, type WHERE any(x in labelsOrTypes where x = 'Document') AND any(x in properties where x = 'id') AND type = 'RANGE' return count(*) > 0") - .execute() - .records() - .get(0) // get first record - .get(0) - .asBoolean()) // get returned result - .isTrue(); - }); + this.contextRunner.run(context -> assertThat(context.getBean(Driver.class) + .executableQuery( + "SHOW indexes yield labelsOrTypes, properties, type WHERE any(x in labelsOrTypes where x = 'Document') AND any(x in properties where x = 'id') AND type = 'RANGE' return count(*) > 0") + .execute() + .records() + .get(0) // get first record + .get(0) + .asBoolean()) // get returned result + .isTrue()); } @SpringBootConfiguration diff --git a/vector-stores/spring-ai-opensearch-store/pom.xml b/vector-stores/spring-ai-opensearch-store/pom.xml index aa3ddfbbd09..e8c5972f268 100644 --- a/vector-stores/spring-ai-opensearch-store/pom.xml +++ b/vector-stores/spring-ai-opensearch-store/pom.xml @@ -39,6 +39,7 @@ 4.0.3 + false diff --git a/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchAiSearchFilterExpressionConverter.java b/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchAiSearchFilterExpressionConverter.java index 98876f5e645..23330b280a6 100644 --- a/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchAiSearchFilterExpressionConverter.java +++ b/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchAiSearchFilterExpressionConverter.java @@ -148,4 +148,4 @@ public void doEndGroup(Filter.Group group, StringBuilder context) { context.append(")"); } -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java b/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java index f81b108ccbd..35710cf9d80 100644 --- a/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java +++ b/vector-stores/spring-ai-opensearch-store/src/main/java/org/springframework/ai/vectorstore/OpenSearchVectorStore.java @@ -71,12 +71,12 @@ public class OpenSearchVectorStore extends AbstractObservationVectorStore implem public static final String DEFAULT_MAPPING_EMBEDDING_TYPE_KNN_VECTOR_DIMENSION_1536 = """ { - "properties":{ - "embedding":{ - "type":"knn_vector", - "dimension":1536 - } - } + "properties":{ + "embedding":{ + "type":"knn_vector", + "dimension":1536 + } + } } """; @@ -154,8 +154,9 @@ public void doAdd(List documents) { @Override public Optional doDelete(List idList) { BulkRequest.Builder bulkRequestBuilder = new BulkRequest.Builder(); - for (String id : idList) + for (String id : idList) { bulkRequestBuilder.operations(op -> op.delete(idx -> idx.index(this.index).id(id))); + } return Optional.of(bulkRequest(bulkRequestBuilder.build()).errors()); } diff --git a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchImage.java b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchImage.java index 294ed87d5c1..42c8d9b9c39 100644 --- a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchImage.java +++ b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/OpenSearchImage.java @@ -21,8 +21,12 @@ /** * @author Thomas Vitale */ -public class OpenSearchImage { +public final class OpenSearchImage { public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("opensearchproject/opensearch:2.17.1"); + private OpenSearchImage() { + + } + } diff --git a/vector-stores/spring-ai-oracle-store/pom.xml b/vector-stores/spring-ai-oracle-store/pom.xml index 7335b4d8833..3e56973eb57 100644 --- a/vector-stores/spring-ai-oracle-store/pom.xml +++ b/vector-stores/spring-ai-oracle-store/pom.xml @@ -36,6 +36,12 @@ git@github.com:spring-projects/spring-ai.git + + 17 + 17 + false + + org.springframework.ai diff --git a/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/OracleVectorStore.java b/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/OracleVectorStore.java index a32a904eb87..a9e3b63eb79 100644 --- a/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/OracleVectorStore.java +++ b/vector-stores/spring-ai-oracle-store/src/main/java/org/springframework/ai/vectorstore/OracleVectorStore.java @@ -56,9 +56,6 @@ import org.springframework.jdbc.core.RowMapper; import org.springframework.util.StringUtils; -import static org.springframework.ai.vectorstore.OracleVectorStore.OracleVectorStoreDistanceType.DOT; -import static org.springframework.jdbc.core.StatementCreatorUtils.setParameterValue; - /** *

* Integration of Oracle database 23ai as a Vector Store. @@ -217,10 +214,13 @@ public void setValues(PreparedStatement ps, int i) throws SQLException { final byte[] json = toJson(document.getMetadata()); final VECTOR embeddingVector = toVECTOR(document.getEmbedding()); - setParameterValue(ps, 1, Types.VARCHAR, document.getId()); - setParameterValue(ps, 2, Types.VARCHAR, content); - setParameterValue(ps, 3, OracleType.JSON.getVendorTypeNumber(), json); - setParameterValue(ps, 4, OracleType.VECTOR.getVendorTypeNumber(), embeddingVector); + org.springframework.jdbc.core.StatementCreatorUtils.setParameterValue(ps, 1, Types.VARCHAR, + document.getId()); + org.springframework.jdbc.core.StatementCreatorUtils.setParameterValue(ps, 2, Types.VARCHAR, content); + org.springframework.jdbc.core.StatementCreatorUtils.setParameterValue(ps, 3, + OracleType.JSON.getVendorTypeNumber(), json); + org.springframework.jdbc.core.StatementCreatorUtils.setParameterValue(ps, 4, + OracleType.VECTOR.getVendorTypeNumber(), embeddingVector); } @Override @@ -357,7 +357,8 @@ public List doSimilaritySearch(SearchRequest request) { @Override public void setValues(PreparedStatement ps, int i) throws SQLException { - setParameterValue(ps, 1, OracleType.VECTOR.getVendorTypeNumber(), embeddingVector); + org.springframework.jdbc.core.StatementCreatorUtils.setParameterValue(ps, 1, + OracleType.VECTOR.getVendorTypeNumber(), embeddingVector); } @Override @@ -381,17 +382,25 @@ public int getBatchSize() { select id, content, metadata, embedding, %sVECTOR_DISTANCE(embedding, ?, %s)%s as distance from %s %sorder by distance - fetch first %d rows only""", this.distanceType == DOT ? "(1+" : "", this.distanceType.name(), - this.distanceType == DOT ? ")/2" : "", this.tableName, jsonPathFilter, request.getTopK()) + fetch first %d rows only""", + this.distanceType == org.springframework.ai.vectorstore.OracleVectorStore.OracleVectorStoreDistanceType.DOT + ? "(1+" : "", + this.distanceType.name(), + this.distanceType == org.springframework.ai.vectorstore.OracleVectorStore.OracleVectorStoreDistanceType.DOT + ? ")/2" : "", + this.tableName, jsonPathFilter, request.getTopK()) : String.format( """ select id, content, metadata, embedding, %sVECTOR_DISTANCE(embedding, ?, %s)%s as distance from %s %sorder by distance fetch APPROXIMATE first %d rows only WITH TARGET ACCURACY %d""", - this.distanceType == DOT ? "(1+" : "", this.distanceType.name(), - this.distanceType == DOT ? ")/2" : "", this.tableName, jsonPathFilter, - request.getTopK(), this.searchAccuracy); + this.distanceType == org.springframework.ai.vectorstore.OracleVectorStore.OracleVectorStoreDistanceType.DOT + ? "(1+" : "", + this.distanceType.name(), + this.distanceType == org.springframework.ai.vectorstore.OracleVectorStore.OracleVectorStoreDistanceType.DOT + ? ")/2" : "", + this.tableName, jsonPathFilter, request.getTopK(), this.searchAccuracy); logger.debug("SQL query: " + sql); @@ -406,60 +415,69 @@ else if (request.getSimilarityThreshold() == SIMILARITY_THRESHOLD_EXACT_MATCH) { select id, content, metadata, embedding, %sVECTOR_DISTANCE(embedding, ?, %s)%s as distance from %s %sorder by distance - fetch EXACT first %d rows only""", this.distanceType == DOT ? "(1+" : "", - this.distanceType.name(), this.distanceType == DOT ? ")/2" : "", this.tableName, jsonPathFilter, - request.getTopK()); + fetch EXACT first %d rows only""", + this.distanceType == org.springframework.ai.vectorstore.OracleVectorStore.OracleVectorStoreDistanceType.DOT + ? "(1+" : "", + this.distanceType.name(), + this.distanceType == org.springframework.ai.vectorstore.OracleVectorStore.OracleVectorStoreDistanceType.DOT + ? ")/2" : "", + this.tableName, jsonPathFilter, request.getTopK()); logger.debug("SQL query: " + sql); return this.jdbcTemplate.query(sql, new DocumentRowMapper(), embeddingVector); } else { - if (!this.forcedNormalization - || (this.distanceType != OracleVectorStoreDistanceType.COSINE && this.distanceType != DOT)) { + if (!this.forcedNormalization || (this.distanceType != OracleVectorStoreDistanceType.COSINE + && this.distanceType != org.springframework.ai.vectorstore.OracleVectorStore.OracleVectorStoreDistanceType.DOT)) { throw new RuntimeException( "Similarity threshold filtering requires all vectors to be normalized, see the forcedNormalization parameter for this Vector store. Also only COSINE and DOT distance types are supported."); } - final double distance = this.distanceType == DOT ? (1d - request.getSimilarityThreshold()) * 2d - 1d - : 1d - request.getSimilarityThreshold(); + final double distance = this.distanceType == org.springframework.ai.vectorstore.OracleVectorStore.OracleVectorStoreDistanceType.DOT + ? (1d - request.getSimilarityThreshold()) * 2d - 1d : 1d - request.getSimilarityThreshold(); if (StringUtils.hasText(nativeFilterExpression)) { jsonPathFilter = String.format(" and JSON_EXISTS( metadata, '%s' )", nativeFilterExpression); } - final String sql = this.distanceType == DOT ? (this.searchAccuracy == DEFAULT_SEARCH_ACCURACY - ? String.format( - """ - select id, content, metadata, embedding, (1+VECTOR_DISTANCE(embedding, ?, DOT))/2 as distance - from %s - where VECTOR_DISTANCE(embedding, ?, DOT) <= ?%s - order by distance - fetch first %d rows only""", - this.tableName, jsonPathFilter, request.getTopK()) - : String.format( - """ - select id, content, metadata, embedding, (1+VECTOR_DISTANCE(embedding, ?, DOT))/2 as distance - from %s - where VECTOR_DISTANCE(embedding, ?, DOT) <= ?%s - order by distance - fetch APPROXIMATE first %d rows only WITH TARGET ACCURACY %d""", - this.tableName, jsonPathFilter, request.getTopK(), this.searchAccuracy) - - ) : (this.searchAccuracy == DEFAULT_SEARCH_ACCURACY ? String.format(""" - select id, content, metadata, embedding, VECTOR_DISTANCE(embedding, ?, COSINE) as distance - from %s - where VECTOR_DISTANCE(embedding, ?, COSINE) <= ?%s - order by distance - fetch first %d rows only""", this.tableName, jsonPathFilter, request.getTopK()) - : String.format( - """ - select id, content, metadata, embedding, VECTOR_DISTANCE(embedding, ?, COSINE) as distance - from %s - where VECTOR_DISTANCE(embedding, ?, COSINE) <= ?%s - order by distance - fetch APPROXIMATE first %d rows only WITH TARGET ACCURACY %d""", - this.tableName, jsonPathFilter, request.getTopK(), this.searchAccuracy)); + final String sql = this.distanceType == org.springframework.ai.vectorstore.OracleVectorStore.OracleVectorStoreDistanceType.DOT + ? (this.searchAccuracy == DEFAULT_SEARCH_ACCURACY + ? String.format( + """ + select id, content, metadata, embedding, (1+VECTOR_DISTANCE(embedding, ?, DOT))/2 as distance + from %s + where VECTOR_DISTANCE(embedding, ?, DOT) <= ?%s + order by distance + fetch first %d rows only""", + this.tableName, jsonPathFilter, request.getTopK()) + : String.format( + """ + select id, content, metadata, embedding, (1+VECTOR_DISTANCE(embedding, ?, DOT))/2 as distance + from %s + where VECTOR_DISTANCE(embedding, ?, DOT) <= ?%s + order by distance + fetch APPROXIMATE first %d rows only WITH TARGET ACCURACY %d""", + this.tableName, jsonPathFilter, request.getTopK(), this.searchAccuracy) + + ) + : (this.searchAccuracy == DEFAULT_SEARCH_ACCURACY + ? String.format( + """ + select id, content, metadata, embedding, VECTOR_DISTANCE(embedding, ?, COSINE) as distance + from %s + where VECTOR_DISTANCE(embedding, ?, COSINE) <= ?%s + order by distance + fetch first %d rows only""", + this.tableName, jsonPathFilter, request.getTopK()) + : String.format( + """ + select id, content, metadata, embedding, VECTOR_DISTANCE(embedding, ?, COSINE) as distance + from %s + where VECTOR_DISTANCE(embedding, ?, COSINE) <= ?%s + order by distance + fetch APPROXIMATE first %d rows only WITH TARGET ACCURACY %d""", + this.tableName, jsonPathFilter, request.getTopK(), this.searchAccuracy)); logger.debug("SQL query: " + sql); @@ -503,10 +521,10 @@ embedding vector(%s,FLOAT64) annotations(Distance '%s') this.jdbcTemplate.execute(String.format(""" create vector index if not exists vector_index_%s on %s (embedding) organization neighbor partitions - distance %s - with target accuracy %d - parameters (type IVF, neighbor partitions 10)""", this.tableName, - this.tableName, this.distanceType.name(), + distance %s + with target accuracy %d + parameters (type IVF, neighbor partitions 10)""", this.tableName, this.tableName, + this.distanceType.name(), this.searchAccuracy == DEFAULT_SEARCH_ACCURACY ? 95 : this.searchAccuracy)); break; @@ -585,7 +603,7 @@ public enum OracleVectorStoreIndexType { * "https://docs.oracle.com/en/database/oracle/oracle-database/23/vecse/understand-inverted-file-flat-vector-indexes.html">Oracle * Database documentation */ - IVF; + IVF } diff --git a/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleImage.java b/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleImage.java index 6955a660680..85816879f4d 100644 --- a/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleImage.java +++ b/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleImage.java @@ -21,8 +21,12 @@ /** * @author Thomas Vitale */ -public class OracleImage { +public final class OracleImage { public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("gvenzl/oracle-free:23-slim"); + private OracleImage() { + + } + } diff --git a/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleVectorStoreIT.java b/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleVectorStoreIT.java index 638db84b57f..00c9a4302d0 100644 --- a/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleVectorStoreIT.java +++ b/vector-stores/spring-ai-oracle-store/src/test/java/org/springframework/ai/vectorstore/OracleVectorStoreIT.java @@ -55,7 +55,6 @@ import org.springframework.util.CollectionUtils; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.vectorstore.OracleVectorStore.DEFAULT_SEARCH_ACCURACY; @Testcontainers public class OracleVectorStoreIT { @@ -119,7 +118,8 @@ private static boolean isSortedByDistance(final List documents) { @ValueSource(strings = { "COSINE", "DOT", "EUCLIDEAN", "EUCLIDEAN_SQUARED", "MANHATTAN" }) public void addAndSearch(String distanceType) { this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.oracle.distanceType=" + distanceType) - .withPropertyValues("test.spring.ai.vectorstore.oracle.searchAccuracy=" + DEFAULT_SEARCH_ACCURACY) + .withPropertyValues("test.spring.ai.vectorstore.oracle.searchAccuracy=" + + org.springframework.ai.vectorstore.OracleVectorStore.DEFAULT_SEARCH_ACCURACY) .run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); @@ -224,7 +224,8 @@ public void searchWithFilters(String distanceType, int searchAccuracy) { @ValueSource(strings = { "COSINE", "DOT", "EUCLIDEAN", "EUCLIDEAN_SQUARED", "MANHATTAN" }) public void documentUpdate(String distanceType) { this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.oracle.distanceType=" + distanceType) - .withPropertyValues("test.spring.ai.vectorstore.oracle.searchAccuracy=" + DEFAULT_SEARCH_ACCURACY) + .withPropertyValues("test.spring.ai.vectorstore.oracle.searchAccuracy=" + + org.springframework.ai.vectorstore.OracleVectorStore.DEFAULT_SEARCH_ACCURACY) .run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); @@ -263,7 +264,8 @@ public void documentUpdate(String distanceType) { @ValueSource(strings = { "COSINE", "DOT" }) public void searchWithThreshold(String distanceType) { this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.oracle.distanceType=" + distanceType) - .withPropertyValues("test.spring.ai.vectorstore.oracle.searchAccuracy=" + DEFAULT_SEARCH_ACCURACY) + .withPropertyValues("test.spring.ai.vectorstore.oracle.searchAccuracy=" + + org.springframework.ai.vectorstore.OracleVectorStore.DEFAULT_SEARCH_ACCURACY) .run(context -> { VectorStore vectorStore = context.getBean(VectorStore.class); diff --git a/vector-stores/spring-ai-pgvector-store/pom.xml b/vector-stores/spring-ai-pgvector-store/pom.xml index 7f3f5341623..c36221d7e49 100644 --- a/vector-stores/spring-ai-pgvector-store/pom.xml +++ b/vector-stores/spring-ai-pgvector-store/pom.xml @@ -36,6 +36,12 @@ git@github.com:spring-projects/spring-ai.git + + 17 + 17 + false + + org.springframework.ai diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorFilterExpressionConverter.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorFilterExpressionConverter.java index 06db63670c7..814de5b4c49 100644 --- a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorFilterExpressionConverter.java +++ b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorFilterExpressionConverter.java @@ -115,4 +115,4 @@ protected void doEndGroup(Group group, StringBuilder context) { context.append(")"); } -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java index 56bb866e61b..69cd79b41c4 100644 --- a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java +++ b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java @@ -489,7 +489,7 @@ private static class DocumentRowMapper implements RowMapper { private final ObjectMapper objectMapper; - public DocumentRowMapper(ObjectMapper objectMapper) { + DocumentRowMapper(ObjectMapper objectMapper) { this.objectMapper = objectMapper; } diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorEmbeddingDimensionsTests.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorEmbeddingDimensionsTests.java index efef6d9135e..587b2cb8e02 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorEmbeddingDimensionsTests.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorEmbeddingDimensionsTests.java @@ -25,10 +25,10 @@ import org.springframework.jdbc.core.JdbcTemplate; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.never; import static org.mockito.Mockito.only; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; /** * @author Christian Tzolov @@ -55,7 +55,7 @@ public void explicitlySetDimensions() { @Test public void embeddingModelDimensions() { - when(this.embeddingModel.dimensions()).thenReturn(969); + given(this.embeddingModel.dimensions()).willReturn(969); var dim = new PgVectorStore(this.jdbcTemplate, this.embeddingModel).embeddingDimensions(); @@ -67,7 +67,7 @@ public void embeddingModelDimensions() { @Test public void fallBackToDefaultDimensions() { - when(this.embeddingModel.dimensions()).thenThrow(new RuntimeException()); + given(this.embeddingModel.dimensions()).willThrow(new RuntimeException()); var dim = new PgVectorStore(this.jdbcTemplate, this.embeddingModel).embeddingDimensions(); diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorImage.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorImage.java index 0df031e63cd..2dead75a532 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorImage.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorImage.java @@ -21,8 +21,12 @@ /** * @author Thomas Vitale */ -public class PgVectorImage { +public final class PgVectorImage { public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("pgvector/pgvector:pg17"); + private PgVectorImage() { + + } + } diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java index 8405d32337f..3366c88296b 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java @@ -122,7 +122,8 @@ private static boolean isSortedByDistance(List docs) { } Iterator iter = distances.iterator(); - Float current, previous = iter.next(); + Float current; + Float previous = iter.next(); while (iter.hasNext()) { current = iter.next(); if (previous > current) { diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreWithChatMemoryAdvisorIT.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreWithChatMemoryAdvisorIT.java index abde63cfe84..9a66e35518b 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreWithChatMemoryAdvisorIT.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreWithChatMemoryAdvisorIT.java @@ -44,9 +44,9 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; /** * @author Fabian Krüger @@ -71,7 +71,7 @@ class PgVectorStoreWithChatMemoryAdvisorIT { Why don't scientists trust atoms? Because they make up everything! """)))); - when(chatModel.call(argumentCaptor.capture())).thenReturn(chatResponse); + given(chatModel.call(argumentCaptor.capture())).willReturn(chatResponse); return chatModel; } @@ -153,7 +153,7 @@ void advisedChatShouldHaveSimilarMessagesFromVectorStore() throws Exception { documents.forEach(d -> d.setEmbedding(this.embed)); return List.of(this.embed, this.embed); }).when(embeddingModel).embed(ArgumentMatchers.any(), any(), any()); - when(embeddingModel.embed(any(String.class))).thenReturn(this.embed); + given(embeddingModel.embed(any(String.class))).willReturn(this.embed); return embeddingModel; } diff --git a/vector-stores/spring-ai-pinecone-store/pom.xml b/vector-stores/spring-ai-pinecone-store/pom.xml index 87b2722c561..8603595fb88 100644 --- a/vector-stores/spring-ai-pinecone-store/pom.xml +++ b/vector-stores/spring-ai-pinecone-store/pom.xml @@ -38,6 +38,7 @@ 17 17 + false diff --git a/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStore.java b/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStore.java index f4243653b07..42413456095 100644 --- a/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStore.java +++ b/vector-stores/spring-ai-pinecone-store/src/main/java/org/springframework/ai/vectorstore/PineconeVectorStore.java @@ -343,7 +343,7 @@ public static PineconeVectorStoreConfig defaultConfig() { return builder().build(); } - public static class Builder { + public static final class Builder { private String apiKey; diff --git a/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreIT.java b/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreIT.java index 917abaf4745..6e1eddc5f32 100644 --- a/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreIT.java +++ b/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreIT.java @@ -97,9 +97,9 @@ public void addAndSearchTest() { vectorStore.add(this.documents); - Awaitility.await().until(() -> { - return vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1)); - }, hasSize(1)); + Awaitility.await() + .until(() -> vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1)), + hasSize(1)); List results = vectorStore.similaritySearch(SearchRequest.query("Great Depression").withTopK(1)); @@ -114,9 +114,8 @@ public void addAndSearchTest() { // Remove all documents from the store vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); - Awaitility.await().until(() -> { - return vectorStore.similaritySearch(SearchRequest.query("Hello").withTopK(1)); - }, hasSize(0)); + Awaitility.await() + .until(() -> vectorStore.similaritySearch(SearchRequest.query("Hello").withTopK(1)), hasSize(0)); }); } @@ -139,9 +138,7 @@ public void addAndSearchWithFilters() { SearchRequest searchRequest = SearchRequest.query("The World"); - Awaitility.await().until(() -> { - return vectorStore.similaritySearch(searchRequest.withTopK(1)); - }, hasSize(1)); + Awaitility.await().until(() -> vectorStore.similaritySearch(searchRequest.withTopK(1)), hasSize(1)); List results = vectorStore.similaritySearch(searchRequest.withTopK(5)); assertThat(results).hasSize(2); @@ -168,9 +165,7 @@ public void addAndSearchWithFilters() { // Remove all documents from the store vectorStore.delete(List.of(bgDocument, nlDocument).stream().map(doc -> doc.getId()).toList()); - Awaitility.await().until(() -> { - return vectorStore.similaritySearch(searchRequest.withTopK(1)); - }, hasSize(0)); + Awaitility.await().until(() -> vectorStore.similaritySearch(searchRequest.withTopK(1)), hasSize(0)); }); } @@ -189,9 +184,7 @@ public void documentUpdateTest() { SearchRequest springSearchRequest = SearchRequest.query("Spring").withTopK(5); - Awaitility.await().until(() -> { - return vectorStore.similaritySearch(springSearchRequest); - }, hasSize(1)); + Awaitility.await().until(() -> vectorStore.similaritySearch(springSearchRequest), hasSize(1)); List results = vectorStore.similaritySearch(springSearchRequest); @@ -210,9 +203,9 @@ public void documentUpdateTest() { SearchRequest fooBarSearchRequest = SearchRequest.query("FooBar").withTopK(5); - Awaitility.await().until(() -> { - return vectorStore.similaritySearch(fooBarSearchRequest).get(0).getContent(); - }, equalTo("The World is Big and Salvation Lurks Around the Corner")); + Awaitility.await() + .until(() -> vectorStore.similaritySearch(fooBarSearchRequest).get(0).getContent(), + equalTo("The World is Big and Salvation Lurks Around the Corner")); results = vectorStore.similaritySearch(fooBarSearchRequest); @@ -225,9 +218,7 @@ public void documentUpdateTest() { // Remove all documents from the store vectorStore.delete(List.of(document.getId())); - Awaitility.await().until(() -> { - return vectorStore.similaritySearch(fooBarSearchRequest); - }, hasSize(0)); + Awaitility.await().until(() -> vectorStore.similaritySearch(fooBarSearchRequest), hasSize(0)); }); } @@ -241,10 +232,10 @@ public void searchThresholdTest() { vectorStore.add(this.documents); - Awaitility.await().until(() -> { - return vectorStore - .similaritySearch(SearchRequest.query("Depression").withTopK(50).withSimilarityThresholdAll()); - }, hasSize(3)); + Awaitility.await() + .until(() -> vectorStore + .similaritySearch(SearchRequest.query("Depression").withTopK(50).withSimilarityThresholdAll()), + hasSize(3)); List fullResult = vectorStore .similaritySearch(SearchRequest.query("Depression").withTopK(5).withSimilarityThresholdAll()); @@ -267,9 +258,8 @@ public void searchThresholdTest() { // Remove all documents from the store vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); - Awaitility.await().until(() -> { - return vectorStore.similaritySearch(SearchRequest.query("Hello").withTopK(1)); - }, hasSize(0)); + Awaitility.await() + .until(() -> vectorStore.similaritySearch(SearchRequest.query("Hello").withTopK(1)), hasSize(0)); }); } diff --git a/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreObservationIT.java b/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreObservationIT.java index b8a84de02c4..53ced5aa0b1 100644 --- a/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-pinecone-store/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreObservationIT.java @@ -128,9 +128,9 @@ void observationVectorStoreAddAndQueryOperations() { .hasBeenStarted() .hasBeenStopped(); - Awaitility.await().until(() -> { - return vectorStore.similaritySearch(SearchRequest.query("What is Great Depression").withTopK(1)); - }, hasSize(1)); + Awaitility.await() + .until(() -> vectorStore.similaritySearch(SearchRequest.query("What is Great Depression").withTopK(1)), + hasSize(1)); observationRegistry.clear(); @@ -168,9 +168,8 @@ void observationVectorStoreAddAndQueryOperations() { // Remove all documents from the store vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList()); - Awaitility.await().until(() -> { - return vectorStore.similaritySearch(SearchRequest.query("Hello").withTopK(1)); - }, hasSize(0)); + Awaitility.await() + .until(() -> vectorStore.similaritySearch(SearchRequest.query("Hello").withTopK(1)), hasSize(0)); }); } diff --git a/vector-stores/spring-ai-qdrant-store/pom.xml b/vector-stores/spring-ai-qdrant-store/pom.xml index 2d9c598e529..d2938dcb53d 100644 --- a/vector-stores/spring-ai-qdrant-store/pom.xml +++ b/vector-stores/spring-ai-qdrant-store/pom.xml @@ -39,6 +39,7 @@ 17 17 + false diff --git a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantFilterExpressionConverter.java b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantFilterExpressionConverter.java index 1870a37aabe..944b67ede7d 100644 --- a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantFilterExpressionConverter.java +++ b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantFilterExpressionConverter.java @@ -30,15 +30,6 @@ import org.springframework.ai.vectorstore.filter.Filter.Operand; import org.springframework.ai.vectorstore.filter.Filter.Value; -import static io.qdrant.client.ConditionFactory.filter; -import static io.qdrant.client.ConditionFactory.match; -import static io.qdrant.client.ConditionFactory.matchExceptKeywords; -import static io.qdrant.client.ConditionFactory.matchExceptValues; -import static io.qdrant.client.ConditionFactory.matchKeyword; -import static io.qdrant.client.ConditionFactory.matchKeywords; -import static io.qdrant.client.ConditionFactory.matchValues; -import static io.qdrant.client.ConditionFactory.range; - /** * @author Anush Shetty * @since 0.8.1 @@ -57,15 +48,15 @@ protected Filter convertOperand(Operand operand) { if (operand instanceof Expression expression) { if (expression.type() == ExpressionType.NOT && expression.left() instanceof Group group) { - mustNotClauses.add(filter(convertOperand(group.content()))); + mustNotClauses.add(io.qdrant.client.ConditionFactory.filter(convertOperand(group.content()))); } else if (expression.type() == ExpressionType.AND) { - mustClauses.add(filter(convertOperand(expression.left()))); - mustClauses.add(filter(convertOperand(expression.right()))); + mustClauses.add(io.qdrant.client.ConditionFactory.filter(convertOperand(expression.left()))); + mustClauses.add(io.qdrant.client.ConditionFactory.filter(convertOperand(expression.right()))); } else if (expression.type() == ExpressionType.OR) { - shouldClauses.add(filter(convertOperand(expression.left()))); - shouldClauses.add(filter(convertOperand(expression.right()))); + shouldClauses.add(io.qdrant.client.ConditionFactory.filter(convertOperand(expression.left()))); + shouldClauses.add(io.qdrant.client.ConditionFactory.filter(convertOperand(expression.right()))); } else { if (!(expression.right() instanceof Value)) { @@ -83,44 +74,35 @@ protected Condition parseComparison(Key key, Value value, Expression exp) { ExpressionType type = exp.type(); switch (type) { - case EQ: { + case EQ: return buildEqCondition(key, value); - } - case NE: { + case NE: return buildNeCondition(key, value); - } - case GT: { + case GT: return buildGtCondition(key, value); - } - case GTE: { + case GTE: return buildGteCondition(key, value); - } - case LT: { + case LT: return buildLtCondition(key, value); - } - case LTE: { + case LTE: return buildLteCondition(key, value); - } - case IN: { + case IN: return buildInCondition(key, value); - } - case NIN: { + case NIN: return buildNInCondition(key, value); - } - default: { + default: throw new RuntimeException("Unsupported expression type: " + type); - } } } protected Condition buildEqCondition(Key key, Value value) { String identifier = doKey(key); if (value.value() instanceof String valueStr) { - return matchKeyword(identifier, valueStr); + return io.qdrant.client.ConditionFactory.matchKeyword(identifier, valueStr); } else if (value.value() instanceof Number valueNum) { long lValue = Long.parseLong(valueNum.toString()); - return match(identifier, lValue); + return io.qdrant.client.ConditionFactory.match(identifier, lValue); } throw new IllegalArgumentException("Invalid value type for EQ. Can either be a string or Number"); @@ -130,12 +112,14 @@ else if (value.value() instanceof Number valueNum) { protected Condition buildNeCondition(Key key, Value value) { String identifier = doKey(key); if (value.value() instanceof String valueStr) { - return filter(Filter.newBuilder().addMustNot(matchKeyword(identifier, valueStr)).build()); + return io.qdrant.client.ConditionFactory.filter(Filter.newBuilder() + .addMustNot(io.qdrant.client.ConditionFactory.matchKeyword(identifier, valueStr)) + .build()); } else if (value.value() instanceof Number valueNum) { long lValue = Long.parseLong(valueNum.toString()); - Condition condition = match(identifier, lValue); - return filter(Filter.newBuilder().addMustNot(condition).build()); + Condition condition = io.qdrant.client.ConditionFactory.match(identifier, lValue); + return io.qdrant.client.ConditionFactory.filter(Filter.newBuilder().addMustNot(condition).build()); } throw new IllegalArgumentException("Invalid value type for NEQ. Can either be a string or Number"); @@ -146,7 +130,7 @@ protected Condition buildGtCondition(Key key, Value value) { String identifier = doKey(key); if (value.value() instanceof Number valueNum) { Double dvalue = Double.parseDouble(valueNum.toString()); - return range(identifier, Range.newBuilder().setGt(dvalue).build()); + return io.qdrant.client.ConditionFactory.range(identifier, Range.newBuilder().setGt(dvalue).build()); } throw new RuntimeException("Unsupported value type for GT condition. Only supports Number"); @@ -156,7 +140,7 @@ protected Condition buildLtCondition(Key key, Value value) { String identifier = doKey(key); if (value.value() instanceof Number valueNum) { Double dvalue = Double.parseDouble(valueNum.toString()); - return range(identifier, Range.newBuilder().setLt(dvalue).build()); + return io.qdrant.client.ConditionFactory.range(identifier, Range.newBuilder().setLt(dvalue).build()); } throw new RuntimeException("Unsupported value type for LT condition. Only supports Number"); @@ -166,7 +150,7 @@ protected Condition buildGteCondition(Key key, Value value) { String identifier = doKey(key); if (value.value() instanceof Number valueNum) { Double dvalue = Double.parseDouble(valueNum.toString()); - return range(identifier, Range.newBuilder().setGte(dvalue).build()); + return io.qdrant.client.ConditionFactory.range(identifier, Range.newBuilder().setGte(dvalue).build()); } throw new RuntimeException("Unsupported value type for GTE condition. Only supports Number"); @@ -176,7 +160,7 @@ protected Condition buildLteCondition(Key key, Value value) { String identifier = doKey(key); if (value.value() instanceof Number valueNum) { Double dvalue = Double.parseDouble(valueNum.toString()); - return range(identifier, Range.newBuilder().setLte(dvalue).build()); + return io.qdrant.client.ConditionFactory.range(identifier, Range.newBuilder().setLte(dvalue).build()); } throw new RuntimeException("Unsupported value type for LTE condition. Only supports Number"); @@ -193,7 +177,7 @@ protected Condition buildInCondition(Key key, Value value) { for (Object valueObj : valueList) { stringValues.add(valueObj.toString()); } - return matchKeywords(identifier, stringValues); + return io.qdrant.client.ConditionFactory.matchKeywords(identifier, stringValues); } else if (firstValue instanceof Number) { // If the first value is a number, then all values should be numbers @@ -202,7 +186,7 @@ else if (firstValue instanceof Number) { Long longValue = Long.parseLong(valueObj.toString()); longValues.add(longValue); } - return matchValues(identifier, longValues); + return io.qdrant.client.ConditionFactory.matchValues(identifier, longValues); } else { throw new RuntimeException("Unsupported value in IN value list. Only supports String or Number"); @@ -224,7 +208,7 @@ protected Condition buildNInCondition(Key key, Value value) { for (Object valueObj : valueList) { stringValues.add(valueObj.toString()); } - return matchExceptKeywords(identifier, stringValues); + return io.qdrant.client.ConditionFactory.matchExceptKeywords(identifier, stringValues); } else if (firstValue instanceof Number) { // If the first value is a number, then all values should be numbers @@ -233,7 +217,7 @@ else if (firstValue instanceof Number) { Long longValue = Long.parseLong(valueObj.toString()); longValues.add(longValue); } - return matchExceptValues(identifier, longValues); + return io.qdrant.client.ConditionFactory.matchExceptValues(identifier, longValues); } else { throw new RuntimeException("Unsupported value in NIN value list. Only supports String or Number"); diff --git a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantObjectFactory.java b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantObjectFactory.java index 00ad1e518cc..08220a53c19 100644 --- a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantObjectFactory.java +++ b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantObjectFactory.java @@ -32,7 +32,7 @@ * @author Anush Shetty * @since 0.8.1 */ -class QdrantObjectFactory { +final class QdrantObjectFactory { private static final Log logger = LogFactory.getLog(QdrantObjectFactory.class); diff --git a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantValueFactory.java b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantValueFactory.java index 13862abc068..336384ae66b 100644 --- a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantValueFactory.java +++ b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantValueFactory.java @@ -34,7 +34,7 @@ * @author Anush Shetty * @since 0.8.1 */ -class QdrantValueFactory { +final class QdrantValueFactory { private QdrantValueFactory() { } @@ -98,4 +98,4 @@ private static Value value(Map inputMap) { return Value.newBuilder().setStructValue(structBuilder).build(); } -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStore.java b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStore.java index 8fadc237ac6..679707038ac 100644 --- a/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStore.java +++ b/vector-stores/spring-ai-qdrant-store/src/main/java/org/springframework/ai/vectorstore/qdrant/QdrantVectorStore.java @@ -48,11 +48,6 @@ import org.springframework.beans.factory.InitializingBean; import org.springframework.util.Assert; -import static io.qdrant.client.PointIdFactory.id; -import static io.qdrant.client.ValueFactory.value; -import static io.qdrant.client.VectorsFactory.vectors; -import static io.qdrant.client.WithPayloadSelectorFactory.enable; - /** * Qdrant vectorStore implementation. This store supports creating, updating, deleting, * and similarity searching of documents in a Qdrant collection. @@ -148,8 +143,8 @@ public void doAdd(List documents) { List points = documents.stream() .map(document -> PointStruct.newBuilder() - .setId(id(UUID.fromString(document.getId()))) - .setVectors(vectors(document.getEmbedding())) + .setId(io.qdrant.client.PointIdFactory.id(UUID.fromString(document.getId()))) + .setVectors(io.qdrant.client.VectorsFactory.vectors(document.getEmbedding())) .putAllPayload(toPayload(document)) .build()) .toList(); @@ -169,7 +164,9 @@ public void doAdd(List documents) { @Override public Optional doDelete(List documentIds) { try { - List ids = documentIds.stream().map(id -> id(UUID.fromString(id))).toList(); + List ids = documentIds.stream() + .map(id -> io.qdrant.client.PointIdFactory.id(UUID.fromString(id))) + .toList(); var result = this.qdrantClient.deleteAsync(this.collectionName, ids) .get() .getStatus() == UpdateStatus.Completed; @@ -198,7 +195,7 @@ public List doSimilaritySearch(SearchRequest request) { var searchPoints = SearchPoints.newBuilder() .setCollectionName(this.collectionName) .setLimit(request.getTopK()) - .setWithPayload(enable(true)) + .setWithPayload(io.qdrant.client.WithPayloadSelectorFactory.enable(true)) .addAllVector(EmbeddingUtils.toList(queryEmbedding)) .setFilter(filter) .setScoreThreshold((float) request.getSimilarityThreshold()) @@ -206,9 +203,7 @@ public List doSimilaritySearch(SearchRequest request) { var queryResponse = this.qdrantClient.searchAsync(searchPoints).get(); - return queryResponse.stream().map(scoredPoint -> { - return toDocument(scoredPoint); - }).toList(); + return queryResponse.stream().map(this::toDocument).toList(); } catch (InterruptedException | ExecutionException | IllegalArgumentException e) { @@ -245,7 +240,7 @@ private Document toDocument(ScoredPoint point) { private Map toPayload(Document document) { try { var payload = QdrantValueFactory.toValueMap(document.getMetadata()); - payload.put(CONTENT_FIELD_NAME, value(document.getContent())); + payload.put(CONTENT_FIELD_NAME, io.qdrant.client.ValueFactory.value(document.getContent())); return payload; } catch (Exception e) { @@ -323,7 +318,7 @@ public static QdrantVectorStoreConfig defaultConfig() { return builder().build(); } - public static class Builder { + public final static class Builder { private String collectionName; diff --git a/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantImage.java b/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantImage.java index 2045be309f5..a50241bc843 100644 --- a/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantImage.java +++ b/vector-stores/spring-ai-qdrant-store/src/test/java/org/springframework/ai/vectorstore/qdrant/QdrantImage.java @@ -21,8 +21,12 @@ /** * @author Thomas Vitale */ -public class QdrantImage { +public final class QdrantImage { public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("qdrant/qdrant:v1.9.7"); + private QdrantImage() { + + } + } diff --git a/vector-stores/spring-ai-redis-store/pom.xml b/vector-stores/spring-ai-redis-store/pom.xml index cf0c0e4d2ed..394f00aac60 100644 --- a/vector-stores/spring-ai-redis-store/pom.xml +++ b/vector-stores/spring-ai-redis-store/pom.xml @@ -41,6 +41,7 @@ 5.1.0 17 17 + false @@ -103,4 +104,4 @@ - \ No newline at end of file + diff --git a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java index 3a850f1dc88..a2fc25f9745 100644 --- a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java +++ b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/RedisVectorStore.java @@ -109,7 +109,7 @@ public class RedisVectorStore extends AbstractObservationVectorStore implements private static final Predicate RESPONSE_OK = Predicate.isEqual("OK"); - private static final Predicate RESPONSE_DEL_OK = Predicate.isEqual(1l); + private static final Predicate RESPONSE_DEL_OK = Predicate.isEqual(1L); private static final String VECTOR_TYPE_FLOAT32 = "FLOAT32"; @@ -409,7 +409,7 @@ public static RedisVectorStoreConfig defaultConfig() { return builder().build(); } - public static class Builder { + public static final class Builder { private String indexName = DEFAULT_INDEX_NAME; diff --git a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisFilterExpressionConverterTests.java b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisFilterExpressionConverterTests.java index c2c0901ba0c..07e8ef0e928 100644 --- a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisFilterExpressionConverterTests.java +++ b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisFilterExpressionConverterTests.java @@ -28,8 +28,6 @@ import org.springframework.ai.vectorstore.filter.Filter.Value; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.numeric; -import static org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.tag; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.AND; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.EQ; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.GTE; @@ -51,7 +49,7 @@ private static RedisFilterExpressionConverter converter(MetadataField... fields) @Test void testEQ() { // country == "BG" - String vectorExpr = converter(tag("country")) + String vectorExpr = converter(org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.tag("country")) .convertExpression(new Expression(EQ, new Key("country"), new Value("BG"))); assertThat(vectorExpr).isEqualTo("@country:{BG}"); } @@ -59,7 +57,8 @@ void testEQ() { @Test void tesEqAndGte() { // genre == "drama" AND year >= 2020 - String vectorExpr = converter(tag("genre"), numeric("year")) + String vectorExpr = converter(org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.tag("genre"), + org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.numeric("year")) .convertExpression(new Expression(AND, new Expression(EQ, new Key("genre"), new Value("drama")), new Expression(GTE, new Key("year"), new Value(2020)))); assertThat(vectorExpr).isEqualTo("@genre:{drama} @year:[2020 inf]"); @@ -68,15 +67,18 @@ void tesEqAndGte() { @Test void tesIn() { // genre in ["comedy", "documentary", "drama"] - String vectorExpr = converter(tag("genre")).convertExpression( - new Expression(IN, new Key("genre"), new Value(List.of("comedy", "documentary", "drama")))); + String vectorExpr = converter(org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.tag("genre")) + .convertExpression( + new Expression(IN, new Key("genre"), new Value(List.of("comedy", "documentary", "drama")))); assertThat(vectorExpr).isEqualTo("@genre:{comedy | documentary | drama}"); } @Test void testNe() { // year >= 2020 OR country == "BG" AND city != "Sofia" - String vectorExpr = converter(numeric("year"), tag("country"), tag("city")) + String vectorExpr = converter(org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.numeric("year"), + org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.tag("country"), + org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.tag("city")) .convertExpression(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), new Group(new Expression(AND, new Expression(EQ, new Key("country"), new Value("BG")), new Expression(NE, new Key("city"), new Value("Sofia")))))); @@ -86,7 +88,9 @@ void testNe() { @Test void testGroup() { // (year >= 2020 OR country == "BG") AND city NIN ["Sofia", "Plovdiv"] - String vectorExpr = converter(numeric("year"), tag("country"), tag("city")) + String vectorExpr = converter(org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.numeric("year"), + org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.tag("country"), + org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.tag("city")) .convertExpression(new Expression(AND, new Group(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)), new Expression(EQ, new Key("country"), new Value("BG")))), @@ -97,7 +101,9 @@ void testGroup() { @Test void tesBoolean() { // isOpen == true AND year >= 2020 AND country IN ["BG", "NL", "US"] - String vectorExpr = converter(numeric("year"), tag("country"), tag("isOpen")) + String vectorExpr = converter(org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.numeric("year"), + org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.tag("country"), + org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.tag("isOpen")) .convertExpression(new Expression(AND, new Expression(AND, new Expression(EQ, new Key("isOpen"), new Value(true)), new Expression(GTE, new Key("year"), new Value(2020))), @@ -109,7 +115,8 @@ void tesBoolean() { @Test void testDecimal() { // temperature >= -15.6 && temperature <= +20.13 - String vectorExpr = converter(numeric("temperature")) + String vectorExpr = converter( + org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.numeric("temperature")) .convertExpression(new Expression(AND, new Expression(GTE, new Key("temperature"), new Value(-15.6)), new Expression(LTE, new Key("temperature"), new Value(20.13)))); @@ -118,11 +125,12 @@ void testDecimal() { @Test void testComplexIdentifiers() { - String vectorExpr = converter(tag("country 1 2 3")) + String vectorExpr = converter( + org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.tag("country 1 2 3")) .convertExpression(new Expression(EQ, new Key("\"country 1 2 3\""), new Value("BG"))); assertThat(vectorExpr).isEqualTo("@\"country 1 2 3\":{BG}"); - vectorExpr = converter(tag("country 1 2 3")) + vectorExpr = converter(org.springframework.ai.vectorstore.RedisVectorStore.MetadataField.tag("country 1 2 3")) .convertExpression(new Expression(EQ, new Key("'country 1 2 3'"), new Value("BG"))); assertThat(vectorExpr).isEqualTo("@'country 1 2 3':{BG}"); } diff --git a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreIT.java b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreIT.java index 124d76d5885..4153a82ea6e 100644 --- a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreIT.java +++ b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/RedisVectorStoreIT.java @@ -85,10 +85,8 @@ void cleanDatabase() { @Test void ensureIndexGetsCreated() { - this.contextRunner.run(context -> { - assertThat(context.getBean(RedisVectorStore.class).getJedis().ftList()) - .contains(RedisVectorStore.DEFAULT_INDEX_NAME); - }); + this.contextRunner.run(context -> assertThat(context.getBean(RedisVectorStore.class).getJedis().ftList()) + .contains(RedisVectorStore.DEFAULT_INDEX_NAME)); } @Test diff --git a/vector-stores/spring-ai-typesense-store/pom.xml b/vector-stores/spring-ai-typesense-store/pom.xml index 14087b76be1..860758ec40a 100644 --- a/vector-stores/spring-ai-typesense-store/pom.xml +++ b/vector-stores/spring-ai-typesense-store/pom.xml @@ -38,6 +38,11 @@ git@github.com:spring-projects/spring-ai.git + + false + + + org.springframework.ai @@ -87,4 +92,4 @@ - \ No newline at end of file + diff --git a/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseFilterExpressionConverter.java b/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseFilterExpressionConverter.java index 5706dd15416..e3decb87020 100644 --- a/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseFilterExpressionConverter.java +++ b/vector-stores/spring-ai-typesense-store/src/main/java/org/springframework/ai/vectorstore/TypesenseFilterExpressionConverter.java @@ -73,4 +73,4 @@ protected void doKey(Filter.Key key, StringBuilder context) { context.append("metadata." + key.key() + ":"); } -} \ No newline at end of file +} diff --git a/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseImage.java b/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseImage.java index ea06769bc9e..8e7bc24d59a 100644 --- a/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseImage.java +++ b/vector-stores/spring-ai-typesense-store/src/test/java/org/springframework/ai/vectorstore/TypesenseImage.java @@ -21,8 +21,12 @@ /** * @author Thomas Vitale */ -public class TypesenseImage { +public final class TypesenseImage { public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("typesense/typesense:27.1"); + private TypesenseImage() { + + } + } diff --git a/vector-stores/spring-ai-weaviate-store/pom.xml b/vector-stores/spring-ai-weaviate-store/pom.xml index c1ea5ff95be..7a1b19ef585 100644 --- a/vector-stores/spring-ai-weaviate-store/pom.xml +++ b/vector-stores/spring-ai-weaviate-store/pom.xml @@ -38,6 +38,7 @@ 17 17 + false diff --git a/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateVectorStore.java b/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateVectorStore.java index 58a831eb10d..9ce01c9f059 100644 --- a/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateVectorStore.java +++ b/vector-stores/spring-ai-weaviate-store/src/main/java/org/springframework/ai/vectorstore/WeaviateVectorStore.java @@ -487,7 +487,7 @@ public enum Type { } - public static class Builder { + public static final class Builder { private String objectClass = "SpringAiWeaviate"; diff --git a/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverterTests.java b/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverterTests.java index 53275f567e7..f06f2ad5de1 100644 --- a/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverterTests.java +++ b/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverterTests.java @@ -51,9 +51,8 @@ public void testMissingFilterName() { FilterExpressionConverter converter = new WeaviateFilterExpressionConverter(List.of()); - assertThatThrownBy(() -> { - converter.convertExpression(new Expression(EQ, new Key("country"), new Value("BG"))); - }).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> converter.convertExpression(new Expression(EQ, new Key("country"), new Value("BG")))) + .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining( "Not allowed filter identifier name: country. Consider adding it to WeaviateVectorStore#filterMetadataKeys."); } diff --git a/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateImage.java b/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateImage.java index 3dbfcdd930c..8d13d9b302b 100644 --- a/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateImage.java +++ b/vector-stores/spring-ai-weaviate-store/src/test/java/org/springframework/ai/vectorstore/WeaviateImage.java @@ -21,8 +21,12 @@ /** * @author Thomas Vitale */ -public class WeaviateImage { +public final class WeaviateImage { public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("semitechnologies/weaviate:1.25.9"); + private WeaviateImage() { + + } + }