Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ private void validateChatOptions(OCICohereChatOptions options) {
}

private List<Generation> getGenerations(Prompt prompt, OCICohereChatOptions options) {
com.oracle.bmc.generativeaiinference.responses.ChatResponse cr = genAi
com.oracle.bmc.generativeaiinference.responses.ChatResponse cr = this.genAi
.chat(toCohereChatRequest(prompt, options));
return toGenerations(cr, options);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
/*
* Copyright 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.oci;

import java.io.IOException;
Expand All @@ -22,6 +23,7 @@
import com.oracle.bmc.auth.ConfigFileAuthenticationDetailsProvider;
import com.oracle.bmc.generativeaiinference.GenerativeAiInference;
import com.oracle.bmc.generativeaiinference.GenerativeAiInferenceClient;

import org.springframework.ai.oci.cohere.OCICohereChatOptions;

public class BaseOCIGenAITest {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
/*
* Copyright 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.oci.cohere;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;

import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
Expand All @@ -25,11 +27,10 @@
import org.springframework.ai.oci.BaseOCIGenAITest;

import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.ai.oci.BaseOCIGenAITest.OCI_CHAT_MODEL_ID_KEY;
import static org.springframework.ai.oci.BaseOCIGenAITest.OCI_COMPARTMENT_ID_KEY;

@EnabledIfEnvironmentVariable(named = OCI_COMPARTMENT_ID_KEY, matches = ".+")
@EnabledIfEnvironmentVariable(named = OCI_CHAT_MODEL_ID_KEY, matches = ".+")
@EnabledIfEnvironmentVariable(named = org.springframework.ai.oci.BaseOCIGenAITest.OCI_COMPARTMENT_ID_KEY,
matches = ".+")
@EnabledIfEnvironmentVariable(named = org.springframework.ai.oci.BaseOCIGenAITest.OCI_CHAT_MODEL_ID_KEY, matches = ".+")
public class OCICohereChatModelIT extends BaseOCIGenAITest {

private static final ChatModel chatModel = new OCICohereChatModel(getGenerativeAIClient(), options().build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public class OCICohereChatModelProperties {
.build();

public boolean isEnabled() {
return enabled;
return this.enabled;
}

public void setEnabled(boolean enabled) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
@AutoConfiguration
@ConditionalOnClass({ GenerativeAiInferenceClient.class, OCIEmbeddingModel.class })
@EnableConfigurationProperties({ OCIConnectionProperties.class, OCIEmbeddingModelProperties.class,
OCICohereChatModelProperties.class, })
OCICohereChatModelProperties.class })
public class OCIGenAiAutoConfiguration {

private static BasicAuthenticationDetailsProvider authenticationProvider(OCIConnectionProperties properties)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
/*
* Copyright 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.autoconfigure.oci.genai;

import java.nio.file.Files;
Expand All @@ -23,6 +24,7 @@
import com.oracle.bmc.http.client.pki.Pem;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;

import org.springframework.ai.oci.cohere.OCICohereChatModel;
import org.springframework.ai.oci.cohere.OCICohereChatOptions;
import org.springframework.boot.autoconfigure.AutoConfigurations;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ public class OCIGenAiAutoConfigurationIT {
private final ApplicationContextRunner cohereChatContextRunner = new ApplicationContextRunner().withPropertyValues(
// @formatter:off
"spring.ai.oci.genai.authenticationType=file",
"spring.ai.oci.genai.file=" + CONFIG_FILE,
"spring.ai.oci.genai.cohere.chat.options.compartment=" + COMPARTMENT_ID,
"spring.ai.oci.genai.file=" + this.CONFIG_FILE,
"spring.ai.oci.genai.cohere.chat.options.compartment=" + this.COMPARTMENT_ID,
"spring.ai.oci.genai.cohere.chat.options.servingMode=on-demand",
"spring.ai.oci.genai.cohere.chat.options.model=" + CHAT_MODEL_ID
"spring.ai.oci.genai.cohere.chat.options.model=" + this.CHAT_MODEL_ID
// @formatter:on
).withConfiguration(AutoConfigurations.of(OCIGenAiAutoConfiguration.class));

Expand Down
1 change: 1 addition & 0 deletions src/checkstyle/checkstyle-suppressions.xml
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,6 @@
<suppress files="FiltersParser\.java" checks="MultipleVariableDeclarations"/>
<suppress files="FiltersLexer\.java" checks="MultipleVariableDeclarations"/>
<suppress files="BaseOllamaIT.java" checks="HideUtilityClassConstructor"/>
<suppress files="BaseOCIGenAITest.java" checks="HideUtilityClassConstructor"/>

</suppressions>
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ public void afterPropertiesSet() throws Exception {
this.documentChunks = this.session.getMap(this.mapName);
switch (this.indexType) {
case HNSW -> this.documentChunks
.addIndex(new HnswIndex<>(DocumentChunk::vector, this.distanceType.name(), dimensions));
.addIndex(new HnswIndex<>(DocumentChunk::vector, this.distanceType.name(), this.dimensions));
case BINARY -> this.documentChunks.addIndex(new BinaryQuantIndex<>(DocumentChunk::vector));
}
}
Expand All @@ -255,7 +255,7 @@ private Float32Vector toFloat32Vector(final float[] floats) {
}

String getMapName() {
return mapName;
return this.mapName;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -92,48 +92,6 @@ public static String getText(final String uri) {
.withPropertyValues("test.spring.ai.vectorstore.coherence.distanceType=COSINE",
"test.spring.ai.vectorstore.coherence.indexType=NONE");

@SpringBootConfiguration
@EnableAutoConfiguration
public static class TestClient {

@Value("${test.spring.ai.vectorstore.coherence.distanceType}")
CoherenceVectorStore.DistanceType distanceType;

@Value("${test.spring.ai.vectorstore.coherence.indexType}")
CoherenceVectorStore.IndexType indexType;

@Bean
public VectorStore vectorStore(EmbeddingModel embeddingModel, Session session) {
return new CoherenceVectorStore(embeddingModel, session).setDistanceType(distanceType)
.setIndexType(indexType)
.setForcedNormalization(distanceType == CoherenceVectorStore.DistanceType.COSINE
|| distanceType == CoherenceVectorStore.DistanceType.IP);
}

@Bean
public Session session(Coherence coherence) {
return coherence.getSession();
}

@Bean
public Coherence coherence() {
return Coherence.clusterMember().start().join();
}

@Bean
public EmbeddingModel embeddingModel() {
try {
TransformersEmbeddingModel tem = new TransformersEmbeddingModel();
tem.afterPropertiesSet();
return tem;
}
catch (Exception e) {
throw new RuntimeException("Failed initializing embedding model", e);
}
}

}

private static void truncateMap(ApplicationContext context, String mapName) {
Session session = context.getBean(Session.class);
session.getMap(mapName).truncate();
Expand All @@ -153,24 +111,24 @@ public static Stream<Arguments> distanceAndIndex() {
@ParameterizedTest(name = "Distance {0}, Index {1} : {displayName}")
@MethodSource("distanceAndIndex")
public void addAndSearch(CoherenceVectorStore.DistanceType distanceType, CoherenceVectorStore.IndexType indexType) {
contextRunner.withPropertyValues("test.spring.ai.vectorstore.coherence.distanceType=" + distanceType)
this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.coherence.distanceType=" + distanceType)
.withPropertyValues("test.spring.ai.vectorstore.coherence.indexType=" + indexType)
.run(context -> {

VectorStore vectorStore = context.getBean(VectorStore.class);

vectorStore.add(documents);
vectorStore.add(this.documents);

List<Document> results = vectorStore
.similaritySearch(SearchRequest.query("What is Great Depression").withTopK(1));

assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId());
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(2).getId());
assertThat(resultDoc.getMetadata()).containsKeys("meta2", "distance");

// Remove all documents from the store
vectorStore.delete(documents.stream().map(doc -> doc.getId()).toList());
vectorStore.delete(this.documents.stream().map(doc -> doc.getId()).toList());

List<Document> results2 = vectorStore
.similaritySearch(SearchRequest.query("Great Depression").withTopK(1));
Expand All @@ -184,7 +142,7 @@ public void addAndSearch(CoherenceVectorStore.DistanceType distanceType, Coheren
@MethodSource("distanceAndIndex")
public void searchWithFilters(CoherenceVectorStore.DistanceType distanceType,
CoherenceVectorStore.IndexType indexType) {
contextRunner.withPropertyValues("test.spring.ai.vectorstore.coherence.distanceType=" + distanceType)
this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.coherence.distanceType=" + distanceType)
.withPropertyValues("test.spring.ai.vectorstore.coherence.indexType=" + indexType)
.run(context -> {

Expand Down Expand Up @@ -250,7 +208,7 @@ public void searchWithFilters(CoherenceVectorStore.DistanceType distanceType,

@Test
public void documentUpdate() {
contextRunner.run(context -> {
this.contextRunner.run(context -> {
VectorStore vectorStore = context.getBean(VectorStore.class);

Document document = new Document(UUID.randomUUID().toString(), "Spring AI rocks!!",
Expand Down Expand Up @@ -286,11 +244,11 @@ public void documentUpdate() {

@Test
public void searchWithThreshold() {
contextRunner.run(context -> {
this.contextRunner.run(context -> {

VectorStore vectorStore = context.getBean(VectorStore.class);

vectorStore.add(documents);
vectorStore.add(this.documents);

List<Document> fullResult = vectorStore
.similaritySearch(SearchRequest.query("Time Shelter").withTopK(5).withSimilarityThresholdAll());
Expand All @@ -310,7 +268,7 @@ public void searchWithThreshold() {

assertThat(results).hasSize(1);
Document resultDoc = results.get(0);
assertThat(resultDoc.getId()).isEqualTo(documents.get(1).getId());
assertThat(resultDoc.getId()).isEqualTo(this.documents.get(1).getId());

truncateMap(context, ((CoherenceVectorStore) vectorStore).getMapName());
});
Expand Down Expand Up @@ -338,4 +296,46 @@ private static boolean isSortedByDistance(final List<Document> documents) {
return true;
}

@SpringBootConfiguration
@EnableAutoConfiguration
public static class TestClient {

@Value("${test.spring.ai.vectorstore.coherence.distanceType}")
CoherenceVectorStore.DistanceType distanceType;

@Value("${test.spring.ai.vectorstore.coherence.indexType}")
CoherenceVectorStore.IndexType indexType;

@Bean
public VectorStore vectorStore(EmbeddingModel embeddingModel, Session session) {
return new CoherenceVectorStore(embeddingModel, session).setDistanceType(this.distanceType)
.setIndexType(this.indexType)
.setForcedNormalization(this.distanceType == CoherenceVectorStore.DistanceType.COSINE
|| this.distanceType == CoherenceVectorStore.DistanceType.IP);
}

@Bean
public Session session(Coherence coherence) {
return coherence.getSession();
}

@Bean
public Coherence coherence() {
return Coherence.clusterMember().start().join();
}

@Bean
public EmbeddingModel embeddingModel() {
try {
TransformersEmbeddingModel tem = new TransformersEmbeddingModel();
tem.afterPropertiesSet();
return tem;
}
catch (Exception e) {
throw new RuntimeException("Failed initializing embedding model", e);
}
}

}

}