Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@

package org.springframework.ai.azure.openai;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import com.azure.ai.openai.OpenAIClient;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.BeforeEach;
import org.mockito.Mockito;

import org.springframework.ai.document.MetadataMode;
Expand All @@ -33,27 +36,126 @@
*/
public class AzureEmbeddingsOptionsTests {

@Test
public void createRequestWithChatOptions() {
private OpenAIClient mockClient;

OpenAIClient mockClient = Mockito.mock(OpenAIClient.class);
var client = new AzureOpenAiEmbeddingModel(mockClient, MetadataMode.EMBED,
private AzureOpenAiEmbeddingModel client;

@BeforeEach
void setUp() {
mockClient = Mockito.mock(OpenAIClient.class);
client = new AzureOpenAiEmbeddingModel(mockClient, MetadataMode.EMBED,
AzureOpenAiEmbeddingOptions.builder().deploymentName("DEFAULT_MODEL").user("USER_TEST").build());
}

@Test
public void createRequestWithChatOptions() {
var requestOptions = client.toEmbeddingOptions(new EmbeddingRequest(List.of("Test message content"), null));

assertThat(requestOptions.getInput()).hasSize(1);

assertThat(requestOptions.getModel()).isEqualTo("DEFAULT_MODEL");
assertThat(requestOptions.getUser()).isEqualTo("USER_TEST");

requestOptions = client.toEmbeddingOptions(new EmbeddingRequest(List.of("Test message content"),
AzureOpenAiEmbeddingOptions.builder().deploymentName("PROMPT_MODEL").user("PROMPT_USER").build()));

assertThat(requestOptions.getInput()).hasSize(1);

assertThat(requestOptions.getModel()).isEqualTo("PROMPT_MODEL");
assertThat(requestOptions.getUser()).isEqualTo("PROMPT_USER");
}

@Test
public void createRequestWithMultipleInputs() {
List<String> inputs = Arrays.asList("First text", "Second text", "Third text");
var requestOptions = client.toEmbeddingOptions(new EmbeddingRequest(inputs, null));

assertThat(requestOptions.getInput()).hasSize(3);
assertThat(requestOptions.getModel()).isEqualTo("DEFAULT_MODEL");
assertThat(requestOptions.getUser()).isEqualTo("USER_TEST");
}

@Test
public void createRequestWithEmptyInputs() {
var requestOptions = client.toEmbeddingOptions(new EmbeddingRequest(Collections.emptyList(), null));

assertThat(requestOptions.getInput()).isEmpty();
assertThat(requestOptions.getModel()).isEqualTo("DEFAULT_MODEL");
}

@Test
public void createRequestWithNullOptions() {
var requestOptions = client.toEmbeddingOptions(new EmbeddingRequest(List.of("Test content"), null));

assertThat(requestOptions.getInput()).hasSize(1);
assertThat(requestOptions.getModel()).isEqualTo("DEFAULT_MODEL");
assertThat(requestOptions.getUser()).isEqualTo("USER_TEST");
}

@Test
public void requestOptionsShouldOverrideDefaults() {
var customOptions = AzureOpenAiEmbeddingOptions.builder()
.deploymentName("CUSTOM_MODEL")
.user("CUSTOM_USER")
.build();

var requestOptions = client.toEmbeddingOptions(new EmbeddingRequest(List.of("Test content"), customOptions));

assertThat(requestOptions.getModel()).isEqualTo("CUSTOM_MODEL");
assertThat(requestOptions.getUser()).isEqualTo("CUSTOM_USER");
}

@Test
public void shouldPreserveInputOrder() {
List<String> orderedInputs = Arrays.asList("First", "Second", "Third", "Fourth");
var requestOptions = client.toEmbeddingOptions(new EmbeddingRequest(orderedInputs, null));

assertThat(requestOptions.getInput()).containsExactly("First", "Second", "Third", "Fourth");
}

@Test
public void shouldHandleDifferentMetadataModes() {
var clientWithNoneMode = new AzureOpenAiEmbeddingModel(mockClient, MetadataMode.NONE,
AzureOpenAiEmbeddingOptions.builder().deploymentName("TEST_MODEL").build());

var requestOptions = clientWithNoneMode.toEmbeddingOptions(new EmbeddingRequest(List.of("Test content"), null));

assertThat(requestOptions.getModel()).isEqualTo("TEST_MODEL");
assertThat(requestOptions.getInput()).hasSize(1);
}

@Test
public void shouldCreateOptionsBuilderWithAllParameters() {
var options = AzureOpenAiEmbeddingOptions.builder().deploymentName("test-deployment").user("test-user").build();

assertThat(options.getDeploymentName()).isEqualTo("test-deployment");
assertThat(options.getUser()).isEqualTo("test-user");
}

@Test
public void shouldValidateDeploymentNameNotNull() {
// This test assumes that the builder or model validates deployment name
// Adjust based on actual validation logic in your implementation
var optionsWithoutDeployment = AzureOpenAiEmbeddingOptions.builder().user("test-user").build();

// If there's validation, this should throw an exception
// Otherwise, adjust the test based on expected behavior
assertThat(optionsWithoutDeployment.getUser()).isEqualTo("test-user");
}

@Test
public void shouldHandleConcurrentRequests() {
// Test that multiple concurrent requests don't interfere with each other
var request1 = new EmbeddingRequest(List.of("First request"),
AzureOpenAiEmbeddingOptions.builder().deploymentName("MODEL1").user("USER1").build());
var request2 = new EmbeddingRequest(List.of("Second request"),
AzureOpenAiEmbeddingOptions.builder().deploymentName("MODEL2").user("USER2").build());

var options1 = client.toEmbeddingOptions(request1);
var options2 = client.toEmbeddingOptions(request2);

assertThat(options1.getModel()).isEqualTo("MODEL1");
assertThat(options1.getUser()).isEqualTo("USER1");
assertThat(options2.getModel()).isEqualTo("MODEL2");
assertThat(options2.getUser()).isEqualTo("USER2");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@

package org.springframework.ai.oci.cohere;

import java.util.Collections;
import java.util.List;
import java.util.Map;

import com.oracle.bmc.generativeaiinference.model.CohereTool;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.BeforeEach;

import static org.assertj.core.api.Assertions.assertThat;

Expand All @@ -31,6 +33,13 @@
*/
class OCICohereChatOptionsTests {

private OCICohereChatOptions options;

@BeforeEach
void setUp() {
options = new OCICohereChatOptions();
}

@Test
void testBuilderWithAllFields() {
OCICohereChatOptions options = OCICohereChatOptions.builder()
Expand All @@ -55,6 +64,34 @@ void testBuilderWithAllFields() {
0.6, 50, List.of("test"), 0.5, 0.5, List.of("doc1", "doc2"));
}

@Test
void testBuilderWithMinimalFields() {
OCICohereChatOptions options = OCICohereChatOptions.builder().model("minimal-model").build();

assertThat(options.getModel()).isEqualTo("minimal-model");
assertThat(options.getMaxTokens()).isNull();
assertThat(options.getTemperature()).isNull();
}

@Test
void testBuilderWithNullValues() {
OCICohereChatOptions options = OCICohereChatOptions.builder()
.model(null)
.maxTokens(null)
.temperature(null)
.stop(null)
.documents(null)
.tools(null)
.build();

assertThat(options.getModel()).isNull();
assertThat(options.getMaxTokens()).isNull();
assertThat(options.getTemperature()).isNull();
assertThat(options.getStop()).isNull();
assertThat(options.getDocuments()).isNull();
assertThat(options.getTools()).isNull();
}

@Test
void testCopy() {
OCICohereChatOptions original = OCICohereChatOptions.builder()
Expand Down Expand Up @@ -82,9 +119,20 @@ void testCopy() {
assertThat(copied.getTools()).isNotSameAs(original.getTools());
}

@Test
void testCopyWithNullValues() {
OCICohereChatOptions original = new OCICohereChatOptions();
OCICohereChatOptions copied = (OCICohereChatOptions) original.copy();

assertThat(copied).isNotSameAs(original).isEqualTo(original);
assertThat(copied.getModel()).isNull();
assertThat(copied.getStop()).isNull();
assertThat(copied.getDocuments()).isNull();
assertThat(copied.getTools()).isNull();
}

@Test
void testSetters() {
OCICohereChatOptions options = new OCICohereChatOptions();
options.setModel("test-model");
options.setMaxTokens(10);
options.setCompartment("test-compartment");
Expand Down Expand Up @@ -114,7 +162,6 @@ void testSetters() {

@Test
void testDefaultValues() {
OCICohereChatOptions options = new OCICohereChatOptions();
assertThat(options.getModel()).isNull();
assertThat(options.getMaxTokens()).isNull();
assertThat(options.getCompartment()).isNull();
Expand All @@ -130,4 +177,77 @@ void testDefaultValues() {
assertThat(options.getTools()).isNull();
}

@Test
void testBoundaryValues() {
options.setMaxTokens(0);
options.setTemperature(0.0);
options.setTopP(0.0);
options.setTopK(1);
options.setFrequencyPenalty(0.0);
options.setPresencePenalty(0.0);

assertThat(options.getMaxTokens()).isEqualTo(0);
assertThat(options.getTemperature()).isEqualTo(0.0);
assertThat(options.getTopP()).isEqualTo(0.0);
assertThat(options.getTopK()).isEqualTo(1);
assertThat(options.getFrequencyPenalty()).isEqualTo(0.0);
assertThat(options.getPresencePenalty()).isEqualTo(0.0);
}

@Test
void testMaximumBoundaryValues() {
options.setMaxTokens(Integer.MAX_VALUE);
options.setTemperature(1.0);
options.setTopP(1.0);
options.setTopK(Integer.MAX_VALUE);
options.setFrequencyPenalty(1.0);
options.setPresencePenalty(1.0);

assertThat(options.getMaxTokens()).isEqualTo(Integer.MAX_VALUE);
assertThat(options.getTemperature()).isEqualTo(1.0);
assertThat(options.getTopP()).isEqualTo(1.0);
assertThat(options.getTopK()).isEqualTo(Integer.MAX_VALUE);
assertThat(options.getFrequencyPenalty()).isEqualTo(1.0);
assertThat(options.getPresencePenalty()).isEqualTo(1.0);
}

@Test
void testEmptyCollections() {
options.setStop(Collections.emptyList());
options.setDocuments(Collections.emptyList());
options.setTools(Collections.emptyList());

assertThat(options.getStop()).isEmpty();
assertThat(options.getDocuments()).isEmpty();
assertThat(options.getTools()).isEmpty();
}

@Test
void testMultipleSetterCalls() {
options.setModel("first-model");
options.setModel("second-model");
options.setMaxTokens(50);
options.setMaxTokens(100);

assertThat(options.getModel()).isEqualTo("second-model");
assertThat(options.getMaxTokens()).isEqualTo(100);
}

@Test
void testNullSetters() {
// Set values first
options.setModel("test-model");
options.setMaxTokens(100);
options.setStop(List.of("test"));

// Then set to null
options.setModel(null);
options.setMaxTokens(null);
options.setStop(null);

assertThat(options.getModel()).isNull();
assertThat(options.getMaxTokens()).isNull();
assertThat(options.getStop()).isNull();
}

}
Loading