Skip to content

Commit a12293b

Browse files
committed
feat(spring-ai): Introduce FunctionCallback.Builder
- Introduce FunctionCallback.Builder interface for improved builder pattern. - Add support for generic type parameters via ResolvableType. Enhance function callbacks with generic type support. - Move CustomizedTypeReference to ModelOptionsUtils for broader reuse. - Deprecate old FunctionCallbackWrapper.Builder in favor of DefaultFunctionCallbackBuilder. - Add JSON schema generation support for ResolvableType. Resolves #1731
1 parent b4e0a45 commit a12293b

File tree

6 files changed

+383
-22
lines changed

6 files changed

+383
-22
lines changed

spring-ai-core/src/main/java/org/springframework/ai/converter/BeanOutputConverter.java

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
package org.springframework.ai.converter;
1818

19-
import java.lang.reflect.Type;
2019
import java.util.Objects;
2120

2221
import com.fasterxml.jackson.core.JsonProcessingException;
@@ -37,6 +36,7 @@
3736
import org.slf4j.Logger;
3837
import org.slf4j.LoggerFactory;
3938

39+
import org.springframework.ai.model.ModelOptionsUtils.CustomizedTypeReference;
4040
import org.springframework.ai.util.JacksonUtils;
4141
import org.springframework.core.ParameterizedTypeReference;
4242
import org.springframework.lang.NonNull;
@@ -94,7 +94,7 @@ public BeanOutputConverter(Class<T> clazz, ObjectMapper objectMapper) {
9494
* @param typeRef The target class type reference.
9595
*/
9696
public BeanOutputConverter(ParameterizedTypeReference<T> typeRef) {
97-
this(new CustomizedTypeReference<>(typeRef), null);
97+
this(CustomizedTypeReference.forType(typeRef), null);
9898
}
9999

100100
/**
@@ -105,7 +105,7 @@ public BeanOutputConverter(ParameterizedTypeReference<T> typeRef) {
105105
* @param objectMapper Custom object mapper for JSON operations. endings.
106106
*/
107107
public BeanOutputConverter(ParameterizedTypeReference<T> typeRef, ObjectMapper objectMapper) {
108-
this(new CustomizedTypeReference<>(typeRef), objectMapper);
108+
this(CustomizedTypeReference.forType(typeRef), objectMapper);
109109
}
110110

111111
/**
@@ -220,19 +220,4 @@ public String getJsonSchema() {
220220
return this.jsonSchema;
221221
}
222222

223-
private static class CustomizedTypeReference<T> extends TypeReference<T> {
224-
225-
private final Type type;
226-
227-
CustomizedTypeReference(ParameterizedTypeReference<T> typeRef) {
228-
this.type = typeRef.getType();
229-
}
230-
231-
@Override
232-
public Type getType() {
233-
return this.type;
234-
}
235-
236-
}
237-
238223
}

spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.beans.PropertyDescriptor;
2020
import java.lang.reflect.Field;
21+
import java.lang.reflect.Type;
2122
import java.util.ArrayList;
2223
import java.util.Arrays;
2324
import java.util.HashMap;
@@ -50,6 +51,8 @@
5051
import org.springframework.ai.util.JacksonUtils;
5152
import org.springframework.beans.BeanWrapper;
5253
import org.springframework.beans.BeanWrapperImpl;
54+
import org.springframework.core.ParameterizedTypeReference;
55+
import org.springframework.core.ResolvableType;
5356
import org.springframework.util.Assert;
5457
import org.springframework.util.CollectionUtils;
5558
import org.springframework.util.ObjectUtils;
@@ -366,6 +369,41 @@ public static String getJsonSchema(Class<?> clazz, boolean toUpperCaseTypeValues
366369
return node.toPrettyString();
367370
}
368371

372+
/**
373+
* Generates JSON Schema (version 2020_12) for the given class.
374+
* @param clazz the class to generate JSON Schema for.
375+
* @param toUpperCaseTypeValues if true, the type values are converted to upper case.
376+
* @return the generated JSON Schema as a String.
377+
*/
378+
public static String getJsonSchema(ResolvableType inputType, boolean toUpperCaseTypeValues) {
379+
380+
if (SCHEMA_GENERATOR_CACHE.get() == null) {
381+
382+
JacksonModule jacksonModule = new JacksonModule(JacksonOption.RESPECT_JSONPROPERTY_REQUIRED);
383+
Swagger2Module swaggerModule = new Swagger2Module();
384+
385+
SchemaGeneratorConfigBuilder configBuilder = new SchemaGeneratorConfigBuilder(SchemaVersion.DRAFT_2020_12,
386+
OptionPreset.PLAIN_JSON)
387+
.with(Option.EXTRA_OPEN_API_FORMAT_VALUES)
388+
.with(Option.PLAIN_DEFINITION_KEYS)
389+
.with(swaggerModule)
390+
.with(jacksonModule);
391+
392+
SchemaGeneratorConfig config = configBuilder.build();
393+
SchemaGenerator generator = new SchemaGenerator(config);
394+
SCHEMA_GENERATOR_CACHE.compareAndSet(null, generator);
395+
}
396+
397+
ObjectNode node = SCHEMA_GENERATOR_CACHE.get()
398+
.generateSchema(CustomizedTypeReference.forType(inputType).getType());
399+
if (toUpperCaseTypeValues) { // Required for OpenAPI 3.0 (at least Vertex AI
400+
// version of it).
401+
toUpperCaseTypeValues(node);
402+
}
403+
404+
return node.toPrettyString();
405+
}
406+
369407
public static void toUpperCaseTypeValues(ObjectNode node) {
370408
if (node == null) {
371409
return;
@@ -405,4 +443,27 @@ public static <T> T mergeOption(T runtimeValue, T defaultValue) {
405443
return ObjectUtils.isEmpty(runtimeValue) ? defaultValue : runtimeValue;
406444
}
407445

446+
public static class CustomizedTypeReference<T> extends TypeReference<T> {
447+
448+
private final Type type;
449+
450+
public CustomizedTypeReference(ParameterizedTypeReference<T> typeRef) {
451+
this.type = typeRef.getType();
452+
}
453+
454+
@Override
455+
public Type getType() {
456+
return this.type;
457+
}
458+
459+
public static <T> CustomizedTypeReference<T> forType(ParameterizedTypeReference<T> typeRef) {
460+
return new CustomizedTypeReference<>(typeRef);
461+
}
462+
463+
public static <T> CustomizedTypeReference<T> forType(ResolvableType resolvableType) {
464+
return new CustomizedTypeReference<>(ParameterizedTypeReference.forType(resolvableType.getType()));
465+
}
466+
467+
}
468+
408469
}

spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import com.fasterxml.jackson.databind.ObjectMapper;
2525

2626
import org.springframework.ai.chat.model.ToolContext;
27+
import org.springframework.ai.model.ModelOptionsUtils.CustomizedTypeReference;
28+
import org.springframework.core.ResolvableType;
2729
import org.springframework.util.Assert;
2830

2931
/**
@@ -47,7 +49,7 @@ abstract class AbstractFunctionCallback<I, O> implements BiFunction<I, ToolConte
4749

4850
private final String description;
4951

50-
private final Class<I> inputType;
52+
private final ResolvableType inputType;
5153

5254
private final String inputTypeSchema;
5355

@@ -72,6 +74,26 @@ abstract class AbstractFunctionCallback<I, O> implements BiFunction<I, ToolConte
7274
*/
7375
protected AbstractFunctionCallback(String name, String description, String inputTypeSchema, Class<I> inputType,
7476
Function<O, String> responseConverter, ObjectMapper objectMapper) {
77+
this(name, description, inputTypeSchema, ResolvableType.forClass(inputType), responseConverter, objectMapper);
78+
}
79+
80+
/**
81+
* Constructs a new {@link AbstractFunctionCallback} with the given name, description,
82+
* input type and default object mapper.
83+
* @param name Function name. Should be unique within the ChatModel's function
84+
* registry.
85+
* @param description Function description. Used as a "system prompt" by the model to
86+
* decide if the function should be called.
87+
* @param inputTypeSchema Used to compute, the argument's Schema (such as JSON Schema
88+
* or OpenAPI Schema)required by the Model's function calling protocol.
89+
* @param inputType Used to compute, the argument's JSON schema required by the
90+
* Model's function calling protocol.
91+
* @param responseConverter Used to convert the function's output type to a string.
92+
* @param objectMapper Used to convert the function's input and output types to and
93+
* from JSON.
94+
*/
95+
protected AbstractFunctionCallback(String name, String description, String inputTypeSchema,
96+
ResolvableType inputType, Function<O, String> responseConverter, ObjectMapper objectMapper) {
7597
Assert.notNull(name, "Name must not be null");
7698
Assert.notNull(description, "Description must not be null");
7799
Assert.notNull(inputType, "InputType must not be null");
@@ -116,9 +138,9 @@ public String call(String functionArguments) {
116138
return this.andThen(this.responseConverter).apply(request, null);
117139
}
118140

119-
private <T> T fromJson(String json, Class<T> targetClass) {
141+
private <T> T fromJson(String json, ResolvableType targetType) {
120142
try {
121-
return this.objectMapper.readValue(json, targetClass);
143+
return this.objectMapper.readValue(json, CustomizedTypeReference.forType(targetType));
122144
}
123145
catch (JsonProcessingException e) {
124146
throw new RuntimeException(e);

0 commit comments

Comments
 (0)