Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions models/spring-ai-openai/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@
<developerConnection>[email protected]:spring-projects/spring-ai.git</developerConnection>
</scm>

<properties>
<disable.checks>false</disable.checks>
</properties>

<dependencies>

<!-- production dependencies -->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ public SpeechResponse call(SpeechPrompt speechPrompt) {
var speech = speechEntity.getBody();

if (speech == null) {
this.logger.warn("No speech response returned for speechRequest: {}", speechRequest);
logger.warn("No speech response returned for speechRequest: {}", speechRequest);
return new SpeechResponse(new Speech(new byte[0]));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ public AudioTranscriptionResponse call(AudioTranscriptionPrompt transcriptionPro
var transcription = transcriptionEntity.getBody();

if (transcription == null) {
this.logger.warn("No transcription returned for request: {}", audioResource);
logger.warn("No transcription returned for request: {}", audioResource);
return new AudioTranscriptionResponse(null);
}

Expand All @@ -139,7 +139,7 @@ public AudioTranscriptionResponse call(AudioTranscriptionPrompt transcriptionPro
var transcription = transcriptionEntity.getBody();

if (transcription == null) {
this.logger.warn("No transcription returned for request: {}", audioResource);
logger.warn("No transcription returned for request: {}", audioResource);
return new AudioTranscriptionResponse(null);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,37 +122,48 @@ public int hashCode() {

@Override
public boolean equals(Object obj) {
if (this == obj)
if (this == obj) {
return true;
if (obj == null)
}
if (obj == null) {
return false;
if (getClass() != obj.getClass())
}
if (getClass() != obj.getClass()) {
return false;
}
OpenAiAudioTranscriptionOptions other = (OpenAiAudioTranscriptionOptions) obj;
if (this.model == null) {
if (other.model != null)
if (other.model != null) {
return false;
}
}
else if (!this.model.equals(other.model))
else if (!this.model.equals(other.model)) {
return false;
}
if (this.prompt == null) {
if (other.prompt != null)
if (other.prompt != null) {
return false;
}
}
else if (!this.prompt.equals(other.prompt))
else if (!this.prompt.equals(other.prompt)) {
return false;
}
if (this.language == null) {
if (other.language != null)
if (other.language != null) {
return false;
}
}
else if (!this.language.equals(other.language))
else if (!this.language.equals(other.language)) {
return false;
}
if (this.responseFormat == null) {
if (other.responseFormat != null)
if (other.responseFormat != null) {
return false;
}
}
else if (!this.responseFormat.equals(other.responseFormat))
else if (!this.responseFormat.equals(other.responseFormat)) {
return false;
}
return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
@SuppressWarnings("null")
String id = chatCompletion2.id();

List<Generation> generations = chatCompletion2.choices().stream().map(choice -> {// @formatter:off

List<Generation> generations = chatCompletion2.choices().stream().map(choice -> { // @formatter:off
if (choice.message().role() != null) {
roleMap.putIfAbsent(id, choice.message().role().name());
}
Expand Down Expand Up @@ -347,9 +346,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
}
})
.doOnError(observation::error)
.doFinally(s -> {
observation.stop();
})
.doFinally(s -> observation.stop())
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
// @formatter:on

Expand Down Expand Up @@ -454,10 +451,8 @@ else if (message.getMessageType() == MessageType.ASSISTANT) {
else if (message.getMessageType() == MessageType.TOOL) {
ToolResponseMessage toolMessage = (ToolResponseMessage) message;

toolMessage.getResponses().forEach(response -> {
Assert.isTrue(response.id() != null, "ToolResponseMessage must have an id");
});

toolMessage.getResponses()
.forEach(response -> Assert.isTrue(response.id() != null, "ToolResponseMessage must have an id"));
return toolMessage.getResponses()
.stream()
.map(tr -> new ChatCompletionMessage(tr.responseData(), ChatCompletionMessage.Role.TOOL, tr.name(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,11 @@ private ImageResponse convertResponse(ResponseEntity<OpenAiImageApi.OpenAiImageR
return new ImageResponse(List.of());
}

List<ImageGeneration> imageGenerationList = imageApiResponse.data().stream().map(entry -> {
return new ImageGeneration(new Image(entry.url(), entry.b64Json()),
new OpenAiImageGenerationMetadata(entry.revisedPrompt()));
}).toList();
List<ImageGeneration> imageGenerationList = imageApiResponse.data()
.stream()
.map(entry -> new ImageGeneration(new Image(entry.url(), entry.b64Json()),
new OpenAiImageGenerationMetadata(entry.revisedPrompt())))
.toList();

ImageResponseMetadata openAiImageResponseMetadata = new ImageResponseMetadata(imageApiResponse.created());
return new ImageResponse(imageGenerationList, openAiImageResponseMetadata);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ public String toString() {
+ ", user='" + this.user + '\'' + '}';
}

public static class Builder {
public static final class Builder {

private final OpenAiImageOptions options;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ private ModerationResponse convertResponse(
OpenAiModerationApi.OpenAiModerationRequest openAiModerationRequest) {
OpenAiModerationApi.OpenAiModerationResponse moderationApiResponse = moderationResponseEntity.getBody();
if (moderationApiResponse == null) {
this.logger.warn("No moderation response returned for request: {}", openAiModerationRequest);
logger.warn("No moderation response returned for request: {}", openAiModerationRequest);
return new ModerationResponse(new Generation());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public void setModel(String model) {
this.model = model;
}

public static class Builder {
public static final class Builder {

private final OpenAiModerationOptions options;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,15 @@ private static Set<TypeReference> eval(Set<TypeReference> referenceSet) {
@Override
public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) {
var mcs = MemberCategory.values();
for (var tr : eval(findJsonAnnotatedClassesInPackage(OpenAiApi.class)))
for (var tr : eval(findJsonAnnotatedClassesInPackage(OpenAiApi.class))) {
hints.reflection().registerType(tr, mcs);
for (var tr : eval(findJsonAnnotatedClassesInPackage(OpenAiAudioApi.class)))
}
for (var tr : eval(findJsonAnnotatedClassesInPackage(OpenAiAudioApi.class))) {
hints.reflection().registerType(tr, mcs);
for (var tr : eval(findJsonAnnotatedClassesInPackage(OpenAiImageApi.class)))
}
for (var tr : eval(findJsonAnnotatedClassesInPackage(OpenAiImageApi.class))) {
hints.reflection().registerType(tr, mcs);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ public OpenAiApi(String baseUrl, String apiKey, MultiValueMap<String, String> he
this.webClient = webClientBuilder
.baseUrl(baseUrl)
.defaultHeaders(finalHeaders)
.build();// @formatter:on
.build(); // @formatter:on
}

public static String getTextContent(List<ChatCompletionMessage.MediaContent> content) {
Expand Down Expand Up @@ -558,7 +558,8 @@ public enum Type {
/**
* Function tool type.
*/
@JsonProperty("function") FUNCTION
@JsonProperty("function")
FUNCTION
}

/**
Expand Down Expand Up @@ -588,7 +589,7 @@ public Function(String description, String name, String jsonSchema) {
this(description, name, ModelOptionsUtils.jsonToMap(jsonSchema));
}
}
}// @formatter:on
} // @formatter:on

/**
* Creates a model response for the given chat conversation.
Expand Down Expand Up @@ -782,7 +783,7 @@ public static Object FUNCTION(String functionName) {
@JsonInclude(Include.NON_NULL)
public record ResponseFormat(
@JsonProperty("type") Type type,
@JsonProperty("json_schema") JsonSchema jsonSchema ) {
@JsonProperty("json_schema") JsonSchema jsonSchema) {

public ResponseFormat(Type type) {
this(type, (JsonSchema) null);
Expand All @@ -794,7 +795,7 @@ public ResponseFormat(Type type, String schema) {

@ConstructorBinding
public ResponseFormat(Type type, String name, String schema, Boolean strict) {
this(type, StringUtils.hasText(schema)? new JsonSchema(name, schema, strict): null);
this(type, StringUtils.hasText(schema) ? new JsonSchema(name, schema, strict) : null);
}

public enum Type {
Expand Down Expand Up @@ -837,7 +838,7 @@ public JsonSchema(String name, String schema) {
}

public JsonSchema(String name, String schema, Boolean strict) {
this(StringUtils.hasText(name)? name : "custom_schema", ModelOptionsUtils.jsonToMap(schema), strict);
this(StringUtils.hasText(name) ? name : "custom_schema", ModelOptionsUtils.jsonToMap(schema), strict);
}
}

Expand All @@ -856,7 +857,7 @@ public record StreamOptions(

public static StreamOptions INCLUDE_USAGE = new StreamOptions(true);
}
}// @formatter:on
} // @formatter:on

/**
* Message comprising the conversation.
Expand All @@ -880,7 +881,7 @@ public record ChatCompletionMessage(// @formatter:off
@JsonProperty("name") String name,
@JsonProperty("tool_call_id") String toolCallId,
@JsonProperty("tool_calls") List<ToolCall> toolCalls,
@JsonProperty("refusal") String refusal) {// @formatter:on
@JsonProperty("refusal") String refusal) { // @formatter:on

/**
* Create a chat completion message with the given content and role. All other
Expand Down Expand Up @@ -999,7 +1000,7 @@ public record ToolCall(// @formatter:off
@JsonProperty("index") Integer index,
@JsonProperty("id") String id,
@JsonProperty("type") String type,
@JsonProperty("function") ChatCompletionFunction function) {// @formatter:on
@JsonProperty("function") ChatCompletionFunction function) { // @formatter:on

public ToolCall(String id, String type, ChatCompletionFunction function) {
this(null, id, type, function);
Expand All @@ -1017,7 +1018,7 @@ public ToolCall(String id, String type, ChatCompletionFunction function) {
@JsonInclude(Include.NON_NULL)
public record ChatCompletionFunction(// @formatter:off
@JsonProperty("name") String name,
@JsonProperty("arguments") String arguments) {// @formatter:on
@JsonProperty("arguments") String arguments) { // @formatter:on
}

}
Expand Down Expand Up @@ -1046,7 +1047,7 @@ public record ChatCompletion(// @formatter:off
@JsonProperty("model") String model,
@JsonProperty("system_fingerprint") String systemFingerprint,
@JsonProperty("object") String object,
@JsonProperty("usage") Usage usage) {// @formatter:on
@JsonProperty("usage") Usage usage) { // @formatter:on

/**
* Chat completion choice.
Expand All @@ -1061,7 +1062,7 @@ public record Choice(// @formatter:off
@JsonProperty("finish_reason") ChatCompletionFinishReason finishReason,
@JsonProperty("index") Integer index,
@JsonProperty("message") ChatCompletionMessage message,
@JsonProperty("logprobs") LogProbs logprobs) {// @formatter:on
@JsonProperty("logprobs") LogProbs logprobs) { // @formatter:on

}

Expand Down Expand Up @@ -1094,7 +1095,7 @@ public record Content(// @formatter:off
@JsonProperty("token") String token,
@JsonProperty("logprob") Float logprob,
@JsonProperty("bytes") List<Integer> probBytes,
@JsonProperty("top_logprobs") List<TopLogProbs> topLogprobs) {// @formatter:on
@JsonProperty("top_logprobs") List<TopLogProbs> topLogprobs) { // @formatter:on

/**
* The most likely tokens and their log probability, at this token position.
Expand All @@ -1111,7 +1112,7 @@ public record Content(// @formatter:off
public record TopLogProbs(// @formatter:off
@JsonProperty("token") String token,
@JsonProperty("logprob") Float logprob,
@JsonProperty("bytes") List<Integer> probBytes) {// @formatter:on
@JsonProperty("bytes") List<Integer> probBytes) { // @formatter:on
}

}
Expand All @@ -1137,7 +1138,7 @@ public record Usage(// @formatter:off
@JsonProperty("prompt_tokens") Integer promptTokens,
@JsonProperty("total_tokens") Integer totalTokens,
@JsonProperty("prompt_tokens_details") PromptTokensDetails promptTokensDetails,
@JsonProperty("completion_tokens_details") CompletionTokenDetails completionTokenDetails) {// @formatter:on
@JsonProperty("completion_tokens_details") CompletionTokenDetails completionTokenDetails) { // @formatter:on

public Usage(Integer completionTokens, Integer promptTokens, Integer totalTokens) {
this(completionTokens, promptTokens, totalTokens, null, null);
Expand All @@ -1150,7 +1151,7 @@ public Usage(Integer completionTokens, Integer promptTokens, Integer totalTokens
*/
@JsonInclude(Include.NON_NULL)
public record PromptTokensDetails(// @formatter:off
@JsonProperty("cached_tokens") Integer cachedTokens) {// @formatter:on
@JsonProperty("cached_tokens") Integer cachedTokens) { // @formatter:on
}

/**
Expand All @@ -1160,7 +1161,7 @@ public record PromptTokensDetails(// @formatter:off
*/
@JsonInclude(Include.NON_NULL)
public record CompletionTokenDetails(// @formatter:off
@JsonProperty("reasoning_tokens") Integer reasoningTokens) {// @formatter:on
@JsonProperty("reasoning_tokens") Integer reasoningTokens) { // @formatter:on
}

}
Expand Down Expand Up @@ -1190,7 +1191,7 @@ public record ChatCompletionChunk(// @formatter:off
@JsonProperty("model") String model,
@JsonProperty("system_fingerprint") String systemFingerprint,
@JsonProperty("object") String object,
@JsonProperty("usage") Usage usage) {// @formatter:on
@JsonProperty("usage") Usage usage) { // @formatter:on

/**
* Chat completion choice.
Expand All @@ -1205,7 +1206,7 @@ public record ChunkChoice(// @formatter:off
@JsonProperty("finish_reason") ChatCompletionFinishReason finishReason,
@JsonProperty("index") Integer index,
@JsonProperty("delta") ChatCompletionMessage delta,
@JsonProperty("logprobs") LogProbs logprobs) {// @formatter:on
@JsonProperty("logprobs") LogProbs logprobs) { // @formatter:on
}

}
Expand All @@ -1222,7 +1223,7 @@ public record ChunkChoice(// @formatter:off
public record Embedding(// @formatter:off
@JsonProperty("index") Integer index,
@JsonProperty("embedding") float[] embedding,
@JsonProperty("object") String object) {// @formatter:on
@JsonProperty("object") String object) { // @formatter:on

/**
* Create an embedding with the given index, embedding and object type set to
Expand Down Expand Up @@ -1259,7 +1260,7 @@ public record EmbeddingRequest<T>(// @formatter:off
@JsonProperty("model") String model,
@JsonProperty("encoding_format") String encodingFormat,
@JsonProperty("dimensions") Integer dimensions,
@JsonProperty("user") String user) {// @formatter:on
@JsonProperty("user") String user) { // @formatter:on

/**
* Create an embedding request with the given input, model and encoding format set
Expand Down Expand Up @@ -1296,7 +1297,7 @@ public record EmbeddingList<T>(// @formatter:off
@JsonProperty("object") String object,
@JsonProperty("data") List<T> data,
@JsonProperty("model") String model,
@JsonProperty("usage") Usage usage) {// @formatter:on
@JsonProperty("usage") Usage usage) { // @formatter:on
}

}
Loading