Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

Expand Down Expand Up @@ -78,12 +77,11 @@ public void processPage(PDPage page) throws IOException {
@Override
protected void writePage() throws IOException {
List<List<TextPosition>> charactersByArticle = super.getCharactersByArticle();
for (int i = 0; i < charactersByArticle.size(); i++) {
List<TextPosition> textList = charactersByArticle.get(i);
for (List<TextPosition> textList : charactersByArticle) {
try {
this.sortTextPositionList(textList);
}
catch (java.lang.IllegalArgumentException e) {
catch (IllegalArgumentException e) {
logger.error("Error sorting text positions", e);
}
this.iterateThroughTextList(textList.iterator());
Expand All @@ -106,7 +104,7 @@ private void writeToOutputStream(final List<TextLine> textLineList) throws IOExc
*/
private void sortTextPositionList(final List<TextPosition> textList) {
TextPositionComparator comparator = new TextPositionComparator();
Collections.sort(textList, comparator);
textList.sort(comparator);
}

private void writeLine(final List<TextPosition> textPositionList) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,20 +207,16 @@ private PreparedStatement prepareGetStatement() {
private Message getMessage(UdtValue udt) {
String content = udt.getString(this.conf.messageUdtContentColumn);
Map<String, Object> props = Map.of(CONVERSATION_TS, udt.getInstant(this.conf.messageUdtTimestampColumn));
switch (MessageType.valueOf(udt.getString(this.conf.messageUdtTypeColumn))) {
case ASSISTANT:
return AssistantMessage.builder().content(content).properties(props).build();
case USER:
return UserMessage.builder().text(content).metadata(props).build();
case SYSTEM:
return SystemMessage.builder().text(content).metadata(props).build();
case TOOL:
return switch (MessageType.valueOf(udt.getString(this.conf.messageUdtTypeColumn))) {
case ASSISTANT -> AssistantMessage.builder().content(content).properties(props).build();
case USER -> UserMessage.builder().text(content).metadata(props).build();
case SYSTEM -> SystemMessage.builder().text(content).metadata(props).build();
case TOOL ->
// todo – persist ToolResponse somehow
return ToolResponseMessage.builder().responses(List.of()).metadata(props).build();
default:
throw new IllegalStateException(
String.format("unknown message type %s", udt.getString(this.conf.messageUdtTypeColumn)));
}
ToolResponseMessage.builder().responses(List.of()).metadata(props).build();
default -> throw new IllegalStateException(
String.format("unknown message type %s", udt.getString(this.conf.messageUdtTypeColumn)));
};
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,7 @@ ConverseRequest createRequest(Prompt prompt) {
.map(message -> {
if (message.getMessageType() == MessageType.USER) {
List<ContentBlock> contents = new ArrayList<>();
if (message instanceof UserMessage) {
var userMessage = (UserMessage) message;
if (message instanceof UserMessage userMessage) {
contents.add(ContentBlock.fromText(userMessage.getText()));

if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,18 +81,14 @@ public String toPrompt(List<Message> messages) {
}

protected String messageToString(Message message) {
switch (message.getMessageType()) {
case SYSTEM:
return message.getText();
case USER:
return this.humanPrompt + " " + message.getText();
case ASSISTANT:
return this.assistantPrompt + " " + message.getText();
case TOOL:
return switch (message.getMessageType()) {
case SYSTEM -> message.getText();
case USER -> this.humanPrompt + " " + message.getText();
case ASSISTANT -> this.assistantPrompt + " " + message.getText();
case TOOL ->
throw new IllegalArgumentException("Tool execution results are not supported for Bedrock models");
}
};

throw new IllegalArgumentException("Unknown message type: " + message.getMessageType());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -250,16 +250,11 @@ private static GeminiMessageType toGeminiMessageType(@NonNull MessageType type)

Assert.notNull(type, "Message type must not be null");

switch (type) {
case SYSTEM:
case USER:
case TOOL:
return GeminiMessageType.USER;
case ASSISTANT:
return GeminiMessageType.MODEL;
default:
throw new IllegalArgumentException("Unsupported message type: " + type);
}
return switch (type) {
case SYSTEM, USER, TOOL -> GeminiMessageType.USER;
case ASSISTANT -> GeminiMessageType.MODEL;
default -> throw new IllegalArgumentException("Unsupported message type: " + type);
};
}

static List<Part> messageToGeminiParts(Message message) {
Expand Down Expand Up @@ -780,51 +775,38 @@ GeminiRequest createGeminiRequest(Prompt prompt) {
// Helper methods for mapping safety settings enums
private static com.google.genai.types.HarmCategory mapToGenAiHarmCategory(
GoogleGenAiSafetySetting.HarmCategory category) {
switch (category) {
case HARM_CATEGORY_UNSPECIFIED:
return new com.google.genai.types.HarmCategory(
com.google.genai.types.HarmCategory.Known.HARM_CATEGORY_UNSPECIFIED);
case HARM_CATEGORY_HATE_SPEECH:
return new com.google.genai.types.HarmCategory(
com.google.genai.types.HarmCategory.Known.HARM_CATEGORY_HATE_SPEECH);
case HARM_CATEGORY_DANGEROUS_CONTENT:
return new com.google.genai.types.HarmCategory(
com.google.genai.types.HarmCategory.Known.HARM_CATEGORY_DANGEROUS_CONTENT);
case HARM_CATEGORY_HARASSMENT:
return new com.google.genai.types.HarmCategory(
com.google.genai.types.HarmCategory.Known.HARM_CATEGORY_HARASSMENT);
case HARM_CATEGORY_SEXUALLY_EXPLICIT:
return new com.google.genai.types.HarmCategory(
com.google.genai.types.HarmCategory.Known.HARM_CATEGORY_SEXUALLY_EXPLICIT);
default:
throw new IllegalArgumentException("Unknown HarmCategory: " + category);
}
return switch (category) {
case HARM_CATEGORY_UNSPECIFIED -> new com.google.genai.types.HarmCategory(
com.google.genai.types.HarmCategory.Known.HARM_CATEGORY_UNSPECIFIED);
case HARM_CATEGORY_HATE_SPEECH -> new com.google.genai.types.HarmCategory(
com.google.genai.types.HarmCategory.Known.HARM_CATEGORY_HATE_SPEECH);
case HARM_CATEGORY_DANGEROUS_CONTENT -> new com.google.genai.types.HarmCategory(
com.google.genai.types.HarmCategory.Known.HARM_CATEGORY_DANGEROUS_CONTENT);
case HARM_CATEGORY_HARASSMENT -> new com.google.genai.types.HarmCategory(
com.google.genai.types.HarmCategory.Known.HARM_CATEGORY_HARASSMENT);
case HARM_CATEGORY_SEXUALLY_EXPLICIT -> new com.google.genai.types.HarmCategory(
com.google.genai.types.HarmCategory.Known.HARM_CATEGORY_SEXUALLY_EXPLICIT);
default -> throw new IllegalArgumentException("Unknown HarmCategory: " + category);
};
}

private static com.google.genai.types.HarmBlockThreshold mapToGenAiHarmBlockThreshold(
GoogleGenAiSafetySetting.HarmBlockThreshold threshold) {
switch (threshold) {
case HARM_BLOCK_THRESHOLD_UNSPECIFIED:
return new com.google.genai.types.HarmBlockThreshold(
com.google.genai.types.HarmBlockThreshold.Known.HARM_BLOCK_THRESHOLD_UNSPECIFIED);
case BLOCK_LOW_AND_ABOVE:
return new com.google.genai.types.HarmBlockThreshold(
com.google.genai.types.HarmBlockThreshold.Known.BLOCK_LOW_AND_ABOVE);
case BLOCK_MEDIUM_AND_ABOVE:
return new com.google.genai.types.HarmBlockThreshold(
com.google.genai.types.HarmBlockThreshold.Known.BLOCK_MEDIUM_AND_ABOVE);
case BLOCK_ONLY_HIGH:
return new com.google.genai.types.HarmBlockThreshold(
com.google.genai.types.HarmBlockThreshold.Known.BLOCK_ONLY_HIGH);
case BLOCK_NONE:
return new com.google.genai.types.HarmBlockThreshold(
com.google.genai.types.HarmBlockThreshold.Known.BLOCK_NONE);
case OFF:
return new com.google.genai.types.HarmBlockThreshold(
com.google.genai.types.HarmBlockThreshold.Known.OFF);
default:
throw new IllegalArgumentException("Unknown HarmBlockThreshold: " + threshold);
}
return switch (threshold) {
case HARM_BLOCK_THRESHOLD_UNSPECIFIED -> new com.google.genai.types.HarmBlockThreshold(
com.google.genai.types.HarmBlockThreshold.Known.HARM_BLOCK_THRESHOLD_UNSPECIFIED);
case BLOCK_LOW_AND_ABOVE -> new com.google.genai.types.HarmBlockThreshold(
com.google.genai.types.HarmBlockThreshold.Known.BLOCK_LOW_AND_ABOVE);
case BLOCK_MEDIUM_AND_ABOVE -> new com.google.genai.types.HarmBlockThreshold(
com.google.genai.types.HarmBlockThreshold.Known.BLOCK_MEDIUM_AND_ABOVE);
case BLOCK_ONLY_HIGH -> new com.google.genai.types.HarmBlockThreshold(
com.google.genai.types.HarmBlockThreshold.Known.BLOCK_ONLY_HIGH);
case BLOCK_NONE -> new com.google.genai.types.HarmBlockThreshold(
com.google.genai.types.HarmBlockThreshold.Known.BLOCK_NONE);
case OFF ->
new com.google.genai.types.HarmBlockThreshold(com.google.genai.types.HarmBlockThreshold.Known.OFF);
default -> throw new IllegalArgumentException("Unknown HarmBlockThreshold: " + threshold);
};
}

private List<Content> toGeminiContent(List<Message> instructions) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,19 +70,11 @@ private static String convertModality(MediaModality modality) {
String modalityStr = modality.toString().toUpperCase();

// Map SDK values to cleaner names
switch (modalityStr) {
case "TEXT":
case "IMAGE":
case "VIDEO":
case "AUDIO":
case "DOCUMENT":
return modalityStr;
case "MODALITY_UNSPECIFIED":
case "MEDIA_MODALITY_UNSPECIFIED":
return "UNKNOWN";
default:
return modalityStr;
}
return switch (modalityStr) {
case "TEXT", "IMAGE", "VIDEO", "AUDIO", "DOCUMENT" -> modalityStr;
case "MODALITY_UNSPECIFIED", "MEDIA_MODALITY_UNSPECIFIED" -> "UNKNOWN";
default -> modalityStr;
};
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,22 +63,20 @@ public static GoogleGenAiTrafficType from(TrafficType trafficType) {
String typeStr = trafficType.toString().toUpperCase();

// Map SDK values to our enum values
switch (typeStr) {
case "ON_DEMAND":
return ON_DEMAND;
case "PROVISIONED_THROUGHPUT":
return PROVISIONED_THROUGHPUT;
case "TRAFFIC_TYPE_UNSPECIFIED":
return UNKNOWN;
default:
return switch (typeStr) {
case "ON_DEMAND" -> ON_DEMAND;
case "PROVISIONED_THROUGHPUT" -> PROVISIONED_THROUGHPUT;
case "TRAFFIC_TYPE_UNSPECIFIED" -> UNKNOWN;
default -> {
// Try exact match
for (GoogleGenAiTrafficType type : values()) {
if (type.value.equals(typeStr)) {
return type;
yield type;
}
}
return UNKNOWN;
}
yield UNKNOWN;
}
};
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -457,18 +457,13 @@ MistralAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream)
}

private Stream<ChatCompletionMessage> createChatCompletionMessages(Message message) {
switch (message.getMessageType()) {
case USER:
return Stream.of(createUserChatCompletionMessage(message));
case SYSTEM:
return Stream.of(createSystemChatCompletionMessage(message));
case ASSISTANT:
return Stream.of(createAssistantChatCompletionMessage(message));
case TOOL:
return createToolChatCompletionMessages(message);
default:
throw new IllegalStateException("Unknown message type: " + message.getMessageType());
}
return switch (message.getMessageType()) {
case USER -> Stream.of(createUserChatCompletionMessage(message));
case SYSTEM -> Stream.of(createSystemChatCompletionMessage(message));
case ASSISTANT -> Stream.of(createAssistantChatCompletionMessage(message));
case TOOL -> createToolChatCompletionMessages(message);
default -> throw new IllegalStateException("Unknown message type: " + message.getMessageType());
};
}

private Stream<ChatCompletionMessage> createToolChatCompletionMessages(Message message) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,7 @@ StabilityAiImageOptions mergeOptions(ImageOptions runtimeOptions, StabilityAiIma
.seed(defaultOptions.getSeed())
.steps(defaultOptions.getSteps())
.stylePreset(ModelOptionsUtils.mergeOption(runtimeOptions.getStyle(), defaultOptions.getStylePreset()));
if (runtimeOptions instanceof StabilityAiImageOptions) {
StabilityAiImageOptions stabilityOptions = (StabilityAiImageOptions) runtimeOptions;
if (runtimeOptions instanceof StabilityAiImageOptions stabilityOptions) {
// Handle Stability AI specific image options
builder
.cfgScale(ModelOptionsUtils.mergeOption(stabilityOptions.getCfgScale(), defaultOptions.getCfgScale()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,16 +244,11 @@ private static GeminiMessageType toGeminiMessageType(@NonNull MessageType type)

Assert.notNull(type, "Message type must not be null");

switch (type) {
case SYSTEM:
case USER:
case TOOL:
return GeminiMessageType.USER;
case ASSISTANT:
return GeminiMessageType.MODEL;
default:
throw new IllegalArgumentException("Unsupported message type: " + type);
}
return switch (type) {
case SYSTEM, USER, TOOL -> GeminiMessageType.USER;
case ASSISTANT -> GeminiMessageType.MODEL;
default -> throw new IllegalArgumentException("Unsupported message type: " + type);
};
}

static List<Part> messageToGeminiParts(Message message) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,9 @@ public boolean equals(Object o) {
if (this == o) {
return true;
}
if (!(o instanceof Categories)) {
if (!(o instanceof Categories that)) {
return false;
}
Categories that = (Categories) o;
return this.sexual == that.sexual && this.hate == that.hate && this.harassment == that.harassment
&& this.selfHarm == that.selfHarm && this.sexualMinors == that.sexualMinors
&& this.hateThreatening == that.hateThreatening && this.violenceGraphic == that.violenceGraphic
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,9 @@ public boolean equals(Object o) {
if (this == o) {
return true;
}
if (!(o instanceof CategoryScores)) {
if (!(o instanceof CategoryScores that)) {
return false;
}
CategoryScores that = (CategoryScores) o;
return Double.compare(that.sexual, this.sexual) == 0 && Double.compare(that.hate, this.hate) == 0
&& Double.compare(that.harassment, this.harassment) == 0
&& Double.compare(that.selfHarm, this.selfHarm) == 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,9 @@ public boolean equals(Object o) {
if (this == o) {
return true;
}
if (!(o instanceof Moderation)) {
if (!(o instanceof Moderation that)) {
return false;
}
Moderation that = (Moderation) o;
return Objects.equals(this.id, that.id) && Objects.equals(this.model, that.model)
&& Objects.equals(this.results, that.results);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,9 @@ public boolean equals(Object o) {
if (this == o) {
return true;
}
if (!(o instanceof ModerationMessage)) {
if (!(o instanceof ModerationMessage that)) {
return false;
}
ModerationMessage that = (ModerationMessage) o;
return Objects.equals(this.text, that.text);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,9 @@ public boolean equals(Object o) {
if (this == o) {
return true;
}
if (!(o instanceof ModerationPrompt)) {
if (!(o instanceof ModerationPrompt that)) {
return false;
}
ModerationPrompt that = (ModerationPrompt) o;
return Objects.equals(this.message, that.message)
&& Objects.equals(this.moderationModelOptions, that.moderationModelOptions);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,9 @@ public boolean equals(Object o) {
if (this == o) {
return true;
}
if (!(o instanceof ModerationResult)) {
if (!(o instanceof ModerationResult that)) {
return false;
}
ModerationResult that = (ModerationResult) o;
return this.flagged == that.flagged && Objects.equals(this.categories, that.categories)
&& Objects.equals(this.categoryScores, that.categoryScores);
}
Expand Down
Loading