Skip to content

Commit 697a04b

Browse files
committed
Moonshot: No Function Calling support verificaiton
Update the docs and tests to verify that currently Moonshot ChatModel does not support Function Calling. Related to #1058
1 parent 18cdeee commit 697a04b

File tree

5 files changed

+125
-15
lines changed

5 files changed

+125
-15
lines changed

models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,17 @@
1515
*/
1616
package org.springframework.ai.moonshot;
1717

18+
import java.util.HashMap;
19+
import java.util.List;
20+
import java.util.Map;
21+
import java.util.concurrent.ConcurrentHashMap;
22+
1823
import org.slf4j.Logger;
1924
import org.slf4j.LoggerFactory;
2025
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
2126
import org.springframework.ai.chat.model.ChatModel;
2227
import org.springframework.ai.chat.model.ChatResponse;
2328
import org.springframework.ai.chat.model.Generation;
24-
import org.springframework.ai.chat.model.StreamingChatModel;
2529
import org.springframework.ai.chat.prompt.ChatOptions;
2630
import org.springframework.ai.chat.prompt.Prompt;
2731
import org.springframework.ai.model.ModelOptionsUtils;
@@ -34,17 +38,13 @@
3438
import org.springframework.http.ResponseEntity;
3539
import org.springframework.retry.support.RetryTemplate;
3640
import org.springframework.util.Assert;
37-
import reactor.core.publisher.Flux;
3841

39-
import java.util.HashMap;
40-
import java.util.List;
41-
import java.util.Map;
42-
import java.util.concurrent.ConcurrentHashMap;
42+
import reactor.core.publisher.Flux;
4343

4444
/**
4545
* @author Geng Rong
4646
*/
47-
public class MoonshotChatModel implements ChatModel, StreamingChatModel {
47+
public class MoonshotChatModel implements ChatModel {
4848

4949
private static final Logger logger = LoggerFactory.getLogger(MoonshotChatModel.class);
5050

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
/*
2+
* Copyright 2023 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.moonshot.chat;
17+
18+
import static org.assertj.core.api.Assertions.assertThat;
19+
20+
import java.util.ArrayList;
21+
import java.util.List;
22+
import java.util.stream.Collectors;
23+
24+
import org.junit.jupiter.api.Disabled;
25+
import org.junit.jupiter.api.Test;
26+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
27+
import org.slf4j.Logger;
28+
import org.slf4j.LoggerFactory;
29+
import org.springframework.ai.chat.messages.AssistantMessage;
30+
import org.springframework.ai.chat.messages.Message;
31+
import org.springframework.ai.chat.messages.UserMessage;
32+
import org.springframework.ai.chat.model.ChatModel;
33+
import org.springframework.ai.chat.model.ChatResponse;
34+
import org.springframework.ai.chat.model.Generation;
35+
import org.springframework.ai.chat.prompt.Prompt;
36+
import org.springframework.ai.model.function.FunctionCallbackWrapper;
37+
import org.springframework.ai.moonshot.MoonshotChatOptions;
38+
import org.springframework.ai.moonshot.MoonshotTestConfiguration;
39+
import org.springframework.ai.moonshot.api.MockWeatherService;
40+
import org.springframework.ai.moonshot.api.MoonshotApi;
41+
import org.springframework.beans.factory.annotation.Autowired;
42+
import org.springframework.boot.test.context.SpringBootTest;
43+
44+
import reactor.core.publisher.Flux;
45+
46+
@Disabled("Currently, the Moonshot Chat Model doesn't support function calling.")
47+
@SpringBootTest(classes = MoonshotTestConfiguration.class)
48+
@EnabledIfEnvironmentVariable(named = "MOONSHOT_API_KEY", matches = ".+")
49+
class MoonshotChatModelFunctionCallingIT {
50+
51+
private static final Logger logger = LoggerFactory.getLogger(MoonshotChatModelFunctionCallingIT.class);
52+
53+
@Autowired
54+
ChatModel chatModel;
55+
56+
@Test
57+
void functionCallTest() {
58+
59+
UserMessage userMessage = new UserMessage(
60+
"What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.");
61+
62+
List<Message> messages = new ArrayList<>(List.of(userMessage));
63+
64+
var promptOptions = MoonshotChatOptions.builder()
65+
.withModel(MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue())
66+
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
67+
.withName("getCurrentWeather")
68+
.withDescription("Get the weather in location")
69+
.withResponseConverter((response) -> "" + response.temp() + response.unit())
70+
.build()))
71+
.build();
72+
73+
ChatResponse response = chatModel.call(new Prompt(messages, promptOptions));
74+
75+
logger.info("Response: {}", response);
76+
77+
assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15");
78+
}
79+
80+
@Test
81+
void streamFunctionCallTest() {
82+
83+
UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?");
84+
85+
List<Message> messages = new ArrayList<>(List.of(userMessage));
86+
87+
var promptOptions = MoonshotChatOptions.builder()
88+
// .withModel(OpenAiApi.ChatModel.GPT_4_TURBO_PREVIEW.getValue())
89+
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
90+
.withName("getCurrentWeather")
91+
.withDescription("Get the weather in location")
92+
.withResponseConverter((response) -> "" + response.temp() + response.unit())
93+
.build()))
94+
.build();
95+
96+
Flux<ChatResponse> response = chatModel.stream(new Prompt(messages, promptOptions));
97+
98+
String content = response.collectList()
99+
.block()
100+
.stream()
101+
.map(ChatResponse::getResults)
102+
.flatMap(List::stream)
103+
.map(Generation::getOutput)
104+
.map(AssistantMessage::getContent)
105+
.collect(Collectors.joining());
106+
logger.info("Response: {}", content);
107+
108+
assertThat(content).contains("30", "10", "15");
109+
}
110+
111+
}

models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelIT.java

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ void roleTest() {
7575
ChatResponse response = chatModel.call(prompt);
7676
assertThat(response.getResults()).hasSize(1);
7777
assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard");
78-
// needs fine tuning... evaluateQuestionAndAnswer(request, response, false);
7978
}
8079

8180
@Test
@@ -93,7 +92,7 @@ void listOutputConverter() {
9392
Prompt prompt = new Prompt(promptTemplate.createMessage());
9493
Generation generation = this.chatModel.call(prompt).getResult();
9594

96-
List<String> list = outputConverter.parse(generation.getOutput().getContent());
95+
List<String> list = outputConverter.convert(generation.getOutput().getContent());
9796
assertThat(list).hasSize(5);
9897

9998
}
@@ -118,7 +117,7 @@ void mapOutputConverter() {
118117
Prompt prompt = new Prompt(promptTemplate.createMessage());
119118
Generation generation = chatModel.call(prompt).getResult();
120119

121-
Map<String, Object> result = outputConverter.parse(generation.getOutput().getContent());
120+
Map<String, Object> result = outputConverter.convert(generation.getOutput().getContent());
122121
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
123122

124123
}
@@ -137,7 +136,7 @@ void beanOutputParser() {
137136
Prompt prompt = new Prompt(promptTemplate.createMessage());
138137
Generation generation = chatModel.call(prompt).getResult();
139138

140-
ActorsFilms actorsFilms = outputConverter.parse(generation.getOutput().getContent());
139+
ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getContent());
141140

142141
}
143142

@@ -165,7 +164,7 @@ void beanOutputParserRecords() {
165164
Prompt prompt = new Prompt(promptTemplate.createMessage());
166165
Generation generation = chatModel.call(prompt).getResult();
167166

168-
ActorsFilmsRecord actorsFilms = outputConverter.parse(generation.getOutput().getContent());
167+
ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent());
169168
logger.info("" + actorsFilms);
170169
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
171170
assertThat(actorsFilms.movies()).hasSize(5);
@@ -196,7 +195,7 @@ void beanStreamOutputParserRecords() {
196195
.map(AssistantMessage::getContent)
197196
.collect(Collectors.joining());
198197

199-
ActorsFilmsRecord actorsFilms = outputConverter.parse(generationTextFromStream);
198+
ActorsFilmsRecord actorsFilms = outputConverter.convert(generationTextFromStream);
200199
logger.info("" + actorsFilms);
201200
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
202201
assertThat(actorsFilms.movies()).hasSize(5);

spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
*** xref:api/chat/minimax-chat.adoc[MiniMax]
2727
**** xref:api/chat/functions/minimax-chat-functions.adoc[Function Calling]
2828
*** xref:api/chat/moonshot-chat.adoc[Moonshot AI]
29-
**** xref:api/chat/functions/moonshot-chat-functions.adoc[Function Calling]
29+
//// **** xref:api/chat/functions/moonshot-chat-functions.adoc[Function Calling]
3030
*** xref:api/chat/ollama-chat.adoc[Ollama]
3131
*** xref:api/chat/openai-chat.adoc[OpenAI]
3232
**** xref:api/chat/functions/openai-chat-functions.adoc[Function Calling]

spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/moonshot-chat.adoc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,4 +247,4 @@ Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring-
247247
==== MoonshotApi Samples
248248
* The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MoonshotApiIT.java[MoonshotApiIT.java] test provides some general examples how to use the lightweight library.
249249

250-
* The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MoonshotApiToolFunctionCallIT.java.java[MoonshotApiToolFunctionCallIT.java] test shows how to use the low-level API to call tool functions.
250+
* The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/api/MoonshotApiToolFunctionCallIT.java[MoonshotApiToolFunctionCallIT.java] test shows how to use the low-level API to call tool functions.

0 commit comments

Comments
 (0)