Skip to content

Commit c960d80

Browse files
committed
fix huggingface generate text
update openapi.json Signed-off-by: jitokim <[email protected]>
1 parent b4e0a45 commit c960d80

File tree

2 files changed

+20
-14
lines changed

2 files changed

+20
-14
lines changed

models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,16 @@
3232
import org.springframework.ai.huggingface.api.TextGenerationInferenceApi;
3333
import org.springframework.ai.huggingface.invoker.ApiClient;
3434
import org.springframework.ai.huggingface.model.AllOfGenerateResponseDetails;
35+
import org.springframework.ai.huggingface.model.CompatGenerateRequest;
3536
import org.springframework.ai.huggingface.model.GenerateParameters;
36-
import org.springframework.ai.huggingface.model.GenerateRequest;
3737
import org.springframework.ai.huggingface.model.GenerateResponse;
3838

3939
/**
4040
* An implementation of {@link ChatModel} that interfaces with HuggingFace Inference
4141
* Endpoints for text generation.
4242
*
4343
* @author Mark Pollack
44+
* @author Jihoon Kim
4445
*/
4546
public class HuggingfaceChatModel implements ChatModel {
4647

@@ -89,22 +90,24 @@ public HuggingfaceChatModel(final String apiToken, String basePath) {
8990
*/
9091
@Override
9192
public ChatResponse call(Prompt prompt) {
92-
GenerateRequest generateRequest = new GenerateRequest();
93-
generateRequest.setInputs(prompt.getContents());
93+
CompatGenerateRequest compatGenerateRequest = new CompatGenerateRequest();
94+
compatGenerateRequest.setInputs(prompt.getContents());
9495
GenerateParameters generateParameters = new GenerateParameters();
9596
// TODO - need to expose API to set parameters per call.
9697
generateParameters.setMaxNewTokens(this.maxNewTokens);
97-
generateRequest.setParameters(generateParameters);
98-
GenerateResponse generateResponse = this.textGenApi.generate(generateRequest);
99-
String generatedText = generateResponse.getGeneratedText();
98+
compatGenerateRequest.setParameters(generateParameters);
99+
List<GenerateResponse> generateResponses = this.textGenApi.compatGenerate(compatGenerateRequest);
100100
List<Generation> generations = new ArrayList<>();
101-
AllOfGenerateResponseDetails allOfGenerateResponseDetails = generateResponse.getDetails();
102-
Map<String, Object> detailsMap = this.objectMapper.convertValue(allOfGenerateResponseDetails,
103-
new TypeReference<Map<String, Object>>() {
104-
105-
});
106-
Generation generation = new Generation(generatedText, detailsMap);
107-
generations.add(generation);
101+
for (GenerateResponse generateResponse : generateResponses) {
102+
String generatedText = generateResponse.getGeneratedText();
103+
AllOfGenerateResponseDetails allOfGenerateResponseDetails = generateResponse.getDetails();
104+
Map<String, Object> detailsMap = this.objectMapper.convertValue(allOfGenerateResponseDetails,
105+
new TypeReference<Map<String, Object>>() {
106+
107+
});
108+
Generation generation = new Generation(generatedText, detailsMap);
109+
generations.add(generation);
110+
}
108111
return new ChatResponse(generations);
109112
}
110113

models/spring-ai-huggingface/src/main/resources/openapi.json

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,10 @@
3737
"content": {
3838
"application/json": {
3939
"schema": {
40-
"$ref": "#/components/schemas/GenerateResponse"
40+
"type": "array",
41+
"items": {
42+
"$ref": "#/components/schemas/GenerateResponse"
43+
}
4144
}
4245
},
4346
"text/event-stream": {

0 commit comments

Comments
 (0)