Skip to content

Commit aa57737

Browse files
committed
Fix options merge for when no runtime options are given
1 parent c4ab49e commit aa57737

File tree

1 file changed

+20
-14
lines changed

1 file changed

+20
-14
lines changed

models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageModel.java

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,7 @@ public ImageResponse call(ImagePrompt imagePrompt) {
8888
// Merge the runtime options passed via the prompt with the default options
8989
// configured via the constructor.
9090
// Runtime options overwrite StabilityAiImageModel options
91-
StabilityAiImageOptions runtimeOptions = (StabilityAiImageOptions) imagePrompt.getOptions();
92-
StabilityAiImageOptions requestImageOptions = mergeOptions(runtimeOptions, this.defaultOptions);
91+
StabilityAiImageOptions requestImageOptions = mergeOptions(imagePrompt.getOptions(), this.defaultOptions);
9392

9493
// Copy the org.springframework.ai.model derived ImagePrompt and ImageOptions data
9594
// types to the data types used in StabilityAiApi
@@ -118,29 +117,36 @@ private ImageResponse convertResponse(StabilityAiApi.GenerateImageResponse gener
118117
* Merge runtime and default {@link ImageOptions} to compute the final options to use
119118
* in the request.
120119
*/
121-
private StabilityAiImageOptions mergeOptions(StabilityAiImageOptions runtimeOptions,
122-
StabilityAiImageOptions defaultOptions) {
120+
private StabilityAiImageOptions mergeOptions(ImageOptions runtimeOptions, StabilityAiImageOptions defaultOptions) {
123121
if (runtimeOptions == null) {
124122
return defaultOptions;
125123
}
126-
127-
return StabilityAiImageOptions.builder()
124+
StabilityAiImageOptions.Builder builder = StabilityAiImageOptions.builder()
128125
// Handle portable image options
129126
.withModel(ModelOptionsUtils.mergeOption(runtimeOptions.getModel(), defaultOptions.getModel()))
130127
.withN(ModelOptionsUtils.mergeOption(runtimeOptions.getN(), defaultOptions.getN()))
131128
.withResponseFormat(ModelOptionsUtils.mergeOption(runtimeOptions.getResponseFormat(),
132129
defaultOptions.getResponseFormat()))
133130
.withWidth(ModelOptionsUtils.mergeOption(runtimeOptions.getWidth(), defaultOptions.getWidth()))
134131
.withHeight(ModelOptionsUtils.mergeOption(runtimeOptions.getHeight(), defaultOptions.getHeight()))
135-
.withStylePreset(ModelOptionsUtils.mergeOption(runtimeOptions.getStyle(), defaultOptions.getStyle()))
132+
.withStylePreset(ModelOptionsUtils.mergeOption(runtimeOptions.getStyle(), defaultOptions.getStyle()));
133+
134+
if (runtimeOptions instanceof StabilityAiImageOptions) {
135+
StabilityAiImageOptions stabilityOptions = (StabilityAiImageOptions) runtimeOptions;
136136
// Handle Stability AI specific image options
137-
.withCfgScale(defaultOptions.getCfgScale())
138-
.withClipGuidancePreset(defaultOptions.getClipGuidancePreset())
139-
.withSampler(runtimeOptions.getSampler())
140-
.withSeed(defaultOptions.getSeed())
141-
.withSteps(runtimeOptions.getSteps())
142-
.withStylePreset(runtimeOptions.getStylePreset())
143-
.build();
137+
builder
138+
.withCfgScale(
139+
ModelOptionsUtils.mergeOption(stabilityOptions.getCfgScale(), defaultOptions.getCfgScale()))
140+
.withClipGuidancePreset(ModelOptionsUtils.mergeOption(stabilityOptions.getClipGuidancePreset(),
141+
defaultOptions.getClipGuidancePreset()))
142+
.withSampler(ModelOptionsUtils.mergeOption(stabilityOptions.getSampler(), defaultOptions.getSampler()))
143+
.withSeed(ModelOptionsUtils.mergeOption(stabilityOptions.getSeed(), defaultOptions.getSeed()))
144+
.withSteps(ModelOptionsUtils.mergeOption(stabilityOptions.getSteps(), defaultOptions.getSteps()))
145+
.withStylePreset(ModelOptionsUtils.mergeOption(stabilityOptions.getStylePreset(),
146+
defaultOptions.getStylePreset()));
147+
}
148+
149+
return builder.build();
144150
}
145151

146152
}

0 commit comments

Comments
 (0)