From 10f22b55b2649277e21e0ef8fb9214023c2d3a75 Mon Sep 17 00:00:00 2001 From: Craig Walls Date: Fri, 25 Oct 2024 21:32:22 -0600 Subject: [PATCH 1/3] Merge runtime options for Stability AI image gen --- .../ai/stabilityai/StabilityAiImageModel.java | 14 ++++++++------ .../ai/stabilityai/api/StabilityAiApi.java | 2 +- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageModel.java b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageModel.java index 35a98059379..b97ef99171e 100644 --- a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageModel.java +++ b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageModel.java @@ -88,8 +88,9 @@ public ImageResponse call(ImagePrompt imagePrompt) { // Merge the runtime options passed via the prompt with the default options // configured via the constructor. // Runtime options overwrite StabilityAiImageModel options - StabilityAiImageOptions requestImageOptions = mergeOptions(imagePrompt.getOptions(), this.defaultOptions); - + StabilityAiImageOptions runtimeOptions = (StabilityAiImageOptions) imagePrompt.getOptions(); + StabilityAiImageOptions requestImageOptions = mergeOptions(runtimeOptions, this.defaultOptions); + System.err.println("requestImageOptions: " + requestImageOptions); // Copy the org.springframework.ai.model derived ImagePrompt and ImageOptions data // types to the data types used in StabilityAiApi StabilityAiApi.GenerateImageRequest generateImageRequest = getGenerateImageRequest(imagePrompt, @@ -117,7 +118,8 @@ private ImageResponse convertResponse(StabilityAiApi.GenerateImageResponse gener * Merge runtime and default {@link ImageOptions} to compute the final options to use * in the request. */ - private StabilityAiImageOptions mergeOptions(ImageOptions runtimeOptions, StabilityAiImageOptions defaultOptions) { + private StabilityAiImageOptions mergeOptions(StabilityAiImageOptions runtimeOptions, + StabilityAiImageOptions defaultOptions) { if (runtimeOptions == null) { return defaultOptions; } @@ -134,10 +136,10 @@ private StabilityAiImageOptions mergeOptions(ImageOptions runtimeOptions, Stabil // Handle Stability AI specific image options .withCfgScale(defaultOptions.getCfgScale()) .withClipGuidancePreset(defaultOptions.getClipGuidancePreset()) - .withSampler(defaultOptions.getSampler()) + .withSampler(runtimeOptions.getSampler()) .withSeed(defaultOptions.getSeed()) - .withSteps(defaultOptions.getSteps()) - .withStylePreset(defaultOptions.getStylePreset()) + .withSteps(runtimeOptions.getSteps()) + .withStylePreset(runtimeOptions.getStylePreset()) .build(); } diff --git a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiApi.java b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiApi.java index 5b3b7f5460d..8c30e3034fe 100644 --- a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiApi.java +++ b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiApi.java @@ -100,7 +100,7 @@ public record GenerateImageRequest(@JsonProperty("text_prompts") List Date: Fri, 25 Oct 2024 21:38:59 -0600 Subject: [PATCH 2/3] Clean out debugging syserr --- .../springframework/ai/stabilityai/StabilityAiImageModel.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageModel.java b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageModel.java index b97ef99171e..e8f1f825f09 100644 --- a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageModel.java +++ b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageModel.java @@ -90,7 +90,7 @@ public ImageResponse call(ImagePrompt imagePrompt) { // Runtime options overwrite StabilityAiImageModel options StabilityAiImageOptions runtimeOptions = (StabilityAiImageOptions) imagePrompt.getOptions(); StabilityAiImageOptions requestImageOptions = mergeOptions(runtimeOptions, this.defaultOptions); - System.err.println("requestImageOptions: " + requestImageOptions); + // Copy the org.springframework.ai.model derived ImagePrompt and ImageOptions data // types to the data types used in StabilityAiApi StabilityAiApi.GenerateImageRequest generateImageRequest = getGenerateImageRequest(imagePrompt, From aa57737d5d96a7fe483b3f7e415b17d7d6e69799 Mon Sep 17 00:00:00 2001 From: Craig Walls Date: Fri, 25 Oct 2024 21:57:56 -0600 Subject: [PATCH 3/3] Fix options merge for when no runtime options are given --- .../ai/stabilityai/StabilityAiImageModel.java | 34 +++++++++++-------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageModel.java b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageModel.java index e8f1f825f09..aec8d2eb7e0 100644 --- a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageModel.java +++ b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageModel.java @@ -88,8 +88,7 @@ public ImageResponse call(ImagePrompt imagePrompt) { // Merge the runtime options passed via the prompt with the default options // configured via the constructor. // Runtime options overwrite StabilityAiImageModel options - StabilityAiImageOptions runtimeOptions = (StabilityAiImageOptions) imagePrompt.getOptions(); - StabilityAiImageOptions requestImageOptions = mergeOptions(runtimeOptions, this.defaultOptions); + StabilityAiImageOptions requestImageOptions = mergeOptions(imagePrompt.getOptions(), this.defaultOptions); // Copy the org.springframework.ai.model derived ImagePrompt and ImageOptions data // types to the data types used in StabilityAiApi @@ -118,13 +117,11 @@ private ImageResponse convertResponse(StabilityAiApi.GenerateImageResponse gener * Merge runtime and default {@link ImageOptions} to compute the final options to use * in the request. */ - private StabilityAiImageOptions mergeOptions(StabilityAiImageOptions runtimeOptions, - StabilityAiImageOptions defaultOptions) { + private StabilityAiImageOptions mergeOptions(ImageOptions runtimeOptions, StabilityAiImageOptions defaultOptions) { if (runtimeOptions == null) { return defaultOptions; } - - return StabilityAiImageOptions.builder() + StabilityAiImageOptions.Builder builder = StabilityAiImageOptions.builder() // Handle portable image options .withModel(ModelOptionsUtils.mergeOption(runtimeOptions.getModel(), defaultOptions.getModel())) .withN(ModelOptionsUtils.mergeOption(runtimeOptions.getN(), defaultOptions.getN())) @@ -132,15 +129,24 @@ private StabilityAiImageOptions mergeOptions(StabilityAiImageOptions runtimeOpti defaultOptions.getResponseFormat())) .withWidth(ModelOptionsUtils.mergeOption(runtimeOptions.getWidth(), defaultOptions.getWidth())) .withHeight(ModelOptionsUtils.mergeOption(runtimeOptions.getHeight(), defaultOptions.getHeight())) - .withStylePreset(ModelOptionsUtils.mergeOption(runtimeOptions.getStyle(), defaultOptions.getStyle())) + .withStylePreset(ModelOptionsUtils.mergeOption(runtimeOptions.getStyle(), defaultOptions.getStyle())); + + if (runtimeOptions instanceof StabilityAiImageOptions) { + StabilityAiImageOptions stabilityOptions = (StabilityAiImageOptions) runtimeOptions; // Handle Stability AI specific image options - .withCfgScale(defaultOptions.getCfgScale()) - .withClipGuidancePreset(defaultOptions.getClipGuidancePreset()) - .withSampler(runtimeOptions.getSampler()) - .withSeed(defaultOptions.getSeed()) - .withSteps(runtimeOptions.getSteps()) - .withStylePreset(runtimeOptions.getStylePreset()) - .build(); + builder + .withCfgScale( + ModelOptionsUtils.mergeOption(stabilityOptions.getCfgScale(), defaultOptions.getCfgScale())) + .withClipGuidancePreset(ModelOptionsUtils.mergeOption(stabilityOptions.getClipGuidancePreset(), + defaultOptions.getClipGuidancePreset())) + .withSampler(ModelOptionsUtils.mergeOption(stabilityOptions.getSampler(), defaultOptions.getSampler())) + .withSeed(ModelOptionsUtils.mergeOption(stabilityOptions.getSeed(), defaultOptions.getSeed())) + .withSteps(ModelOptionsUtils.mergeOption(stabilityOptions.getSteps(), defaultOptions.getSteps())) + .withStylePreset(ModelOptionsUtils.mergeOption(stabilityOptions.getStylePreset(), + defaultOptions.getStylePreset())); + } + + return builder.build(); } }