Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionMessage.Role;
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionMessage.ToolCall;
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionRequest;
import org.springframework.ai.minimax.api.MiniMaxApi.FunctionTool;
import org.springframework.ai.minimax.api.MiniMaxApiConstants;
import org.springframework.ai.minimax.metadata.MiniMaxUsage;
import org.springframework.ai.model.ModelOptionsUtils;
Expand Down Expand Up @@ -508,11 +507,11 @@ else if (message.getMessageType() == MessageType.TOOL) {
return request;
}

private List<FunctionTool> getFunctionTools(Set<String> functionNames) {
private List<MiniMaxApi.FunctionTool> getFunctionTools(Set<String> functionNames) {
return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> {
var function = new FunctionTool.Function(functionCallback.getDescription(), functionCallback.getName(),
functionCallback.getInputTypeSchema());
return new FunctionTool(function);
var function = new MiniMaxApi.FunctionTool.Function(functionCallback.getDescription(),
functionCallback.getName(), functionCallback.getInputTypeSchema());
return new MiniMaxApi.FunctionTool(function);
}).toList();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -331,14 +331,35 @@ public String getValue() {

/**
* Represents a tool the model may call. Currently, only functions are supported as a tool.
*
* @param type The type of the tool. Currently, only 'function' is supported.
* @param function The function definition.
*/
@JsonInclude(Include.NON_NULL)
public record FunctionTool(
@JsonProperty("type") Type type,
@JsonProperty("function") Function function) {
@JsonInclude(JsonInclude.Include.NON_NULL)
public static class FunctionTool {

/**
* The type of the tool. Currently, only 'function' is supported.
*/
private Type type = Type.FUNCTION;

/**
* The function definition.
*/
private Function function;

public FunctionTool() {

}

/**
* Create a tool of type 'function' and the given function definition.
* @param type the tool type
* @param function function definition
*/
public FunctionTool(
@JsonProperty("type") Type type,
@JsonProperty("function") Function function) {
this.type = type;
this.function = function;
}

/**
* Create a tool of type 'function' and the given function definition.
Expand All @@ -348,8 +369,22 @@ public FunctionTool(Function function) {
this(Type.FUNCTION, function);
}

public static FunctionTool webSearchFunctionTool() {
return new FunctionTool(Type.WEB_SEARCH, null);
@JsonProperty("type")
public Type getType() {
return this.type;
}

@JsonProperty("function")
public Function getFunction() {
return this.function;
}

public void setType(Type type) {
this.type = type;
}

public void setFunction(Function function) {
this.function = function;
}

/**
Expand All @@ -361,35 +396,104 @@ public enum Type {
*/
@JsonProperty("function")
FUNCTION,

@JsonProperty("web_search")
WEB_SEARCH
}

public static FunctionTool webSearchFunctionTool() {
return new FunctionTool(FunctionTool.Type.WEB_SEARCH, null);
}


/**
* Function definition.
*
* @param description A description of what the function does, used by the model to choose when and how to call
* the function.
* @param name The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes,
* with a maximum length of 64.
* @param parameters The parameters the functions accepts, described as a JSON Schema object. To describe a
* function that accepts no parameters, provide the value {"type": "object", "properties": {}}.
*/
public record Function(
@JsonProperty("description") String description,
@JsonProperty("name") String name,
@JsonProperty("parameters") String parameters) {
*/
public static class Function {

@JsonProperty("description")
private String description;

@JsonProperty("name")
private String name;

@JsonProperty("parameters")
private Map<String, Object> parameters;

private String jsonSchema;

private Function() {

}

/**
* Create tool function definition.
*
* @param description A description of what the function does, used by the model to choose when and how to call
* the function.
* @param name The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes,
* with a maximum length of 64.
* @param parameters The parameters the functions accepts, described as a JSON Schema object. To describe a
* function that accepts no parameters, provide the value {"type": "object", "properties": {}}.
*/
public Function(
String description,
String name,
Map<String, Object> parameters) {
this.description = description;
this.name = name;
this.parameters = parameters;
}

/**
* Create tool function definition.
*
* @param description tool function description.
* @param name tool function name.
* @param parameters tool function schema.
* @param jsonSchema tool function schema as json.
*/
public Function(String description, String name, Map<String, Object> parameters) {
this(description, name, ModelOptionsUtils.toJsonString(parameters));
public Function(String description, String name, String jsonSchema) {
this(description, name, ModelOptionsUtils.jsonToMap(jsonSchema));
}

@JsonProperty("description")
public String getDescription() {
return this.description;
}

@JsonProperty("name")
public String getName() {
return this.name;
}

@JsonProperty("parameters")
public Map<String, Object> getParameters() {
return this.parameters;
}

public void setDescription(String description) {
this.description = description;
}

public void setName(String name) {
this.name = name;
}

public void setParameters(Map<String, Object> parameters) {
this.parameters = parameters;
}

public String getJsonSchema() {
return this.jsonSchema;
}

public void setJsonSchema(String jsonSchema) {
this.jsonSchema = jsonSchema;
if (jsonSchema != null) {
this.parameters = ModelOptionsUtils.jsonToMap(jsonSchema);
}
}

}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public void promptOptionsTools() {
assertThat(request.model()).isEqualTo("PROMPT_MODEL");

assertThat(request.tools()).hasSize(1);
assertThat(request.tools().get(0).function().name()).isEqualTo(TOOL_FUNCTION_NAME);
assertThat(request.tools().get(0).getFunction().getName()).isEqualTo(TOOL_FUNCTION_NAME);
}

@Test
Expand Down Expand Up @@ -120,7 +120,7 @@ public void defaultOptionsTools() {
MiniMaxChatOptions.builder().withFunction(TOOL_FUNCTION_NAME).build()), false);

assertThat(request.tools()).hasSize(1);
assertThat(request.tools().get(0).function().name()).as("Explicitly enabled function")
assertThat(request.tools().get(0).getFunction().getName()).as("Explicitly enabled function")
.isEqualTo(TOOL_FUNCTION_NAME);

// Override the default options function with one from the prompt
Expand All @@ -134,7 +134,7 @@ public void defaultOptionsTools() {
false);

assertThat(request.tools()).hasSize(1);
assertThat(request.tools().get(0).function().name()).as("Explicitly enabled function")
assertThat(request.tools().get(0).getFunction().getName()).as("Explicitly enabled function")
.isEqualTo(TOOL_FUNCTION_NAME);

assertThat(client.getFunctionCallbackRegister()).hasSize(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionMessage.ToolCall;
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionRequest;
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionRequest.ToolChoiceBuilder;
import org.springframework.ai.minimax.api.MiniMaxApi.FunctionTool.Type;
import org.springframework.http.ResponseEntity;

import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -67,31 +66,33 @@ public void toolFunctionCall() {
var message = new ChatCompletionMessage(
"What's the weather like in San Francisco? Return the temperature in Celsius.", Role.USER);

var functionTool = new MiniMaxApi.FunctionTool(Type.FUNCTION, new MiniMaxApi.FunctionTool.Function(
"Get the weather in location. Return temperature in 30°F or 30°C format.", "getCurrentWeather", """
{
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state e.g. San Francisco, CA"
},
"lat": {
"type": "number",
"description": "The city latitude"
},
"lon": {
"type": "number",
"description": "The city longitude"
},
"unit": {
"type": "string",
"enum": ["C", "F"]
var functionTool = new MiniMaxApi.FunctionTool(MiniMaxApi.FunctionTool.Type.FUNCTION,
new MiniMaxApi.FunctionTool.Function(
"Get the weather in location. Return temperature in 30°F or 30°C format.", "getCurrentWeather",
"""
{
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state e.g. San Francisco, CA"
},
"lat": {
"type": "number",
"description": "The city latitude"
},
"lon": {
"type": "number",
"description": "The city longitude"
},
"unit": {
"type": "string",
"enum": ["C", "F"]
}
},
"required": ["location", "lat", "lon", "unit"]
}
},
"required": ["location", "lat", "lon", "unit"]
}
"""));
"""));

List<ChatCompletionMessage> messages = new ArrayList<>(List.of(message));

Expand Down
Loading