Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package org.springframework.ai.qianfan;

import io.micrometer.observation.ObservationRegistry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.image.Image;
Expand All @@ -23,10 +24,16 @@
import org.springframework.ai.image.ImageOptions;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.image.observation.DefaultImageModelObservationConvention;
import org.springframework.ai.image.observation.ImageModelObservationContext;
import org.springframework.ai.image.observation.ImageModelObservationConvention;
import org.springframework.ai.image.observation.ImageModelObservationDocumentation;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.qianfan.api.QianFanConstants;
import org.springframework.ai.qianfan.api.QianFanImageApi;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.lang.Nullable;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;

Expand All @@ -43,6 +50,8 @@ public class QianFanImageModel implements ImageModel {

private final static Logger logger = LoggerFactory.getLogger(QianFanImageModel.class);

private static final ImageModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultImageModelObservationConvention();

/**
* The default options used for the image completion requests.
*/
Expand All @@ -58,6 +67,16 @@ public class QianFanImageModel implements ImageModel {
*/
private final QianFanImageApi qianFanImageApi;

/**
* Observation registry used for instrumentation.
*/
private final ObservationRegistry observationRegistry;

/**
* Conventions to use for generating observations.
*/
private ImageModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

/**
* Creates an instance of the QianFanImageModel.
* @param qianFanImageApi The QianFanImageApi instance to be used for interacting with
Expand All @@ -69,48 +88,85 @@ public QianFanImageModel(QianFanImageApi qianFanImageApi) {
}

/**
* Initializes a new instance of the QianFanImageModel.
* Creates an instance of the QianFanImageModel.
* @param qianFanImageApi The QianFanImageApi instance to be used for interacting with
* the QianFan Image API.
* @param options The QianFanImageOptions to configure the image model.
* @throws IllegalArgumentException if qianFanImageApi is null
*/
public QianFanImageModel(QianFanImageApi qianFanImageApi, QianFanImageOptions options) {
this(qianFanImageApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}

/**
* Creates an instance of the QianFanImageModel.
* @param qianFanImageApi The QianFanImageApi instance to be used for interacting with
* the QianFan Image API.
* @param options The QianFanImageOptions to configure the image model.
* @param retryTemplate The retry template.
* @throws IllegalArgumentException if qianFanImageApi is null
*/
public QianFanImageModel(QianFanImageApi qianFanImageApi, QianFanImageOptions options,
RetryTemplate retryTemplate) {
this(qianFanImageApi, options, retryTemplate, ObservationRegistry.NOOP);
}

/**
* Initializes a new instance of the QianFanImageModel.
* @param qianFanImageApi The QianFanImageApi instance to be used for interacting with
* the QianFan Image API.
* @param options The QianFanImageOptions to configure the image model.
* @param retryTemplate The retry template.
* @param observationRegistry The ObservationRegistry used for instrumentation.
*/
public QianFanImageModel(QianFanImageApi qianFanImageApi, QianFanImageOptions options, RetryTemplate retryTemplate,
ObservationRegistry observationRegistry) {
Assert.notNull(qianFanImageApi, "QianFanImageApi must not be null");
Assert.notNull(options, "options must not be null");
Assert.notNull(retryTemplate, "retryTemplate must not be null");
Assert.notNull(observationRegistry, "observationRegistry must not be null");
this.qianFanImageApi = qianFanImageApi;
this.defaultOptions = options;
this.retryTemplate = retryTemplate;
this.observationRegistry = observationRegistry;
}

@Override
public ImageResponse call(ImagePrompt imagePrompt) {
return this.retryTemplate.execute(ctx -> {
QianFanImageOptions requestImageOptions = mergeOptions(imagePrompt.getOptions(), this.defaultOptions);

String instructions = imagePrompt.getInstructions().get(0).getText();
QianFanImageApi.QianFanImageRequest imageRequest = createRequest(imagePrompt, requestImageOptions);

QianFanImageApi.QianFanImageRequest imageRequest = new QianFanImageApi.QianFanImageRequest(instructions,
QianFanImageApi.DEFAULT_IMAGE_MODEL);
var observationContext = ImageModelObservationContext.builder()
.imagePrompt(imagePrompt)
.provider(QianFanConstants.PROVIDER_NAME)
.requestOptions(requestImageOptions)
.build();

if (this.defaultOptions != null) {
imageRequest = ModelOptionsUtils.merge(this.defaultOptions, imageRequest,
QianFanImageApi.QianFanImageRequest.class);
}
return ImageModelObservationDocumentation.IMAGE_MODEL_OPERATION
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationRegistry)
.observe(() -> {

if (imagePrompt.getOptions() != null) {
imageRequest = ModelOptionsUtils.merge(toQianFanImageOptions(imagePrompt.getOptions()), imageRequest,
QianFanImageApi.QianFanImageRequest.class);
}
ResponseEntity<QianFanImageApi.QianFanImageResponse> imageResponseEntity = this.retryTemplate
.execute(ctx -> this.qianFanImageApi.createImage(imageRequest));

// Make the request
ResponseEntity<QianFanImageApi.QianFanImageResponse> imageResponseEntity = this.qianFanImageApi
.createImage(imageRequest);
ImageResponse imageResponse = convertResponse(imageResponseEntity, imageRequest);

// Convert to org.springframework.ai.model derived ImageResponse data type
return convertResponse(imageResponseEntity, imageRequest);
});
observationContext.setResponse(imageResponse);

return imageResponse;
});
}

private QianFanImageApi.QianFanImageRequest createRequest(ImagePrompt imagePrompt,
QianFanImageOptions requestImageOptions) {
String instructions = imagePrompt.getInstructions().get(0).getText();

QianFanImageApi.QianFanImageRequest imageRequest = new QianFanImageApi.QianFanImageRequest(instructions,
QianFanImageApi.DEFAULT_IMAGE_MODEL);

return ModelOptionsUtils.merge(requestImageOptions, imageRequest, QianFanImageApi.QianFanImageRequest.class);
}

private ImageResponse convertResponse(ResponseEntity<QianFanImageApi.QianFanImageResponse> imageResponseEntity,
Expand All @@ -132,33 +188,32 @@ private ImageResponse convertResponse(ResponseEntity<QianFanImageApi.QianFanImag
/**
* Convert the {@link ImageOptions} into {@link QianFanImageOptions}.
* @param runtimeImageOptions the image options to use.
* @param defaultOptions the default options.
* @return the converted {@link QianFanImageOptions}.
*/
private QianFanImageOptions toQianFanImageOptions(ImageOptions runtimeImageOptions) {
QianFanImageOptions.Builder qianFanImageOptionsBuilder = QianFanImageOptions.builder();
if (runtimeImageOptions != null) {
if (runtimeImageOptions.getN() != null) {
qianFanImageOptionsBuilder.withN(runtimeImageOptions.getN());
}
if (runtimeImageOptions.getModel() != null) {
qianFanImageOptionsBuilder.withModel(runtimeImageOptions.getModel());
}
if (runtimeImageOptions.getWidth() != null) {
qianFanImageOptionsBuilder.withWidth(runtimeImageOptions.getWidth());
}
if (runtimeImageOptions.getHeight() != null) {
qianFanImageOptionsBuilder.withHeight(runtimeImageOptions.getHeight());
}
if (runtimeImageOptions instanceof QianFanImageOptions runtimeQianFanImageOptions) {
if (runtimeQianFanImageOptions.getStyle() != null) {
qianFanImageOptionsBuilder.withStyle(runtimeQianFanImageOptions.getStyle());
}
if (runtimeQianFanImageOptions.getUser() != null) {
qianFanImageOptionsBuilder.withUser(runtimeQianFanImageOptions.getUser());
}
}
private QianFanImageOptions mergeOptions(@Nullable ImageOptions runtimeImageOptions,
QianFanImageOptions defaultOptions) {
var runtimeOptionsForProvider = ModelOptionsUtils.copyToTarget(runtimeImageOptions, ImageOptions.class,
QianFanImageOptions.class);

if (runtimeOptionsForProvider == null) {
return defaultOptions;
}
return qianFanImageOptionsBuilder.build();

return QianFanImageOptions.builder()
.withModel(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getModel(), defaultOptions.getModel()))
.withN(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getN(), defaultOptions.getN()))
.withModel(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getModel(), defaultOptions.getModel()))
.withWidth(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getWidth(), defaultOptions.getWidth()))
.withHeight(
ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getHeight(), defaultOptions.getHeight()))
.withStyle(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getStyle(), defaultOptions.getStyle()))
.withUser(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getUser(), defaultOptions.getUser()))
.build();
}

public void setObservationConvention(ImageModelObservationConvention observationConvention) {
this.observationConvention = observationConvention;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.qianfan.image;

import io.micrometer.observation.tck.TestObservationRegistry;
import io.micrometer.observation.tck.TestObservationRegistryAssert;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.image.observation.DefaultImageModelObservationConvention;
import org.springframework.ai.observation.conventions.AiOperationType;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.ai.qianfan.QianFanImageModel;
import org.springframework.ai.qianfan.QianFanImageOptions;
import org.springframework.ai.qianfan.api.QianFanImageApi;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.Bean;
import org.springframework.retry.support.RetryTemplate;

import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.ai.image.observation.ImageModelObservationDocumentation.HighCardinalityKeyNames;
import static org.springframework.ai.image.observation.ImageModelObservationDocumentation.LowCardinalityKeyNames;

/**
* Integration tests for observation instrumentation in {@link QianFanImageModel}.
*
* @author Geng Rong
*/
@SpringBootTest(classes = QianFanImageModelObservationIT.Config.class)
@EnabledIfEnvironmentVariables(value = { @EnabledIfEnvironmentVariable(named = "QIANFAN_API_KEY", matches = ".+"),
@EnabledIfEnvironmentVariable(named = "QIANFAN_SECRET_KEY", matches = ".+") })
public class QianFanImageModelObservationIT {

@Autowired
TestObservationRegistry observationRegistry;

@Autowired
QianFanImageModel imageModel;

@Test
void observationForImageOperation() {
var options = QianFanImageOptions.builder()
.withModel(QianFanImageApi.ImageModel.Stable_Diffusion_XL.getValue())
.withHeight(1024)
.withWidth(1024)
.withStyle("Base")
.build();

var instructions = "Here comes the sun";

ImagePrompt imagePrompt = new ImagePrompt(instructions, options);

ImageResponse imageResponse = imageModel.call(imagePrompt);
assertThat(imageResponse.getResults()).hasSize(1);

TestObservationRegistryAssert.assertThat(observationRegistry)
.doesNotHaveAnyRemainingCurrentObservation()
.hasObservationWithNameEqualTo(DefaultImageModelObservationConvention.DEFAULT_NAME)
.that()
.hasContextualNameEqualTo("image " + QianFanImageApi.ImageModel.Stable_Diffusion_XL.getValue())
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(),
AiOperationType.IMAGE.value())
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.QIANFAN.value())
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.REQUEST_MODEL.asString(),
QianFanImageApi.ImageModel.Stable_Diffusion_XL.getValue())
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_IMAGE_SIZE.asString(), "1024x1024")
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_IMAGE_STYLE.asString(), "Base")
.hasBeenStarted()
.hasBeenStopped();
}

@SpringBootConfiguration
static class Config {

@Bean
public TestObservationRegistry observationRegistry() {
return TestObservationRegistry.create();
}

@Bean
public QianFanImageApi qianFanImageApi() {
return new QianFanImageApi(System.getenv("QIANFAN_API_KEY"), System.getenv("QIANFAN_SECRET_KEY"));
}

@Bean
public QianFanImageModel qianFanImageModel(QianFanImageApi qianFanImageApi,
TestObservationRegistry observationRegistry) {
return new QianFanImageModel(qianFanImageApi, QianFanImageOptions.builder().build(),
RetryTemplate.defaultInstance(), observationRegistry);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
import org.springframework.ai.image.observation.ImageModelObservationConvention;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.qianfan.QianFanChatModel;
import org.springframework.ai.qianfan.QianFanEmbeddingModel;
Expand Down Expand Up @@ -99,7 +100,8 @@ public QianFanEmbeddingModel qianFanEmbeddingModel(QianFanConnectionProperties c
matchIfMissing = true)
public QianFanImageModel qianFanImageModel(QianFanConnectionProperties commonProperties,
QianFanImageProperties imageProperties, RestClient.Builder restClientBuilder, RetryTemplate retryTemplate,
ResponseErrorHandler responseErrorHandler) {
ResponseErrorHandler responseErrorHandler, ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<ImageModelObservationConvention> observationConvention) {

String apiKey = StringUtils.hasText(imageProperties.getApiKey()) ? imageProperties.getApiKey()
: commonProperties.getApiKey();
Expand All @@ -116,7 +118,12 @@ public QianFanImageModel qianFanImageModel(QianFanConnectionProperties commonPro

var qianFanImageApi = new QianFanImageApi(baseUrl, apiKey, secretKey, restClientBuilder, responseErrorHandler);

return new QianFanImageModel(qianFanImageApi, imageProperties.getOptions(), retryTemplate);
var imageModel = new QianFanImageModel(qianFanImageApi, imageProperties.getOptions(), retryTemplate,
observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP));

observationConvention.ifAvailable(imageModel::setObservationConvention);

return imageModel;
}

private QianFanApi qianFanApi(String baseUrl, String commonBaseUrl, String apiKey, String commonApiKey,
Expand Down