Skip to content

Commit 3c14fa6

Browse files
jitokim authored and ilayaperumalg committed
GH-1727 fix huggingface generate text
- Update GenerateResponse content schema type to array at openapi.json
- Use CompatGenerateRequest instead of GenerateRequest for the TextGenerationInference API Request

Signed-off-by: jitokim <[email protected]>
1 parent 7474852 commit 3c14fa6

File tree

3 files changed

+27
-17
lines changed

3 files changed

+27
-17
lines changed

models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,16 @@
3232
import org.springframework.ai.huggingface.api.TextGenerationInferenceApi;
3333
import org.springframework.ai.huggingface.invoker.ApiClient;
3434
import org.springframework.ai.huggingface.model.AllOfGenerateResponseDetails;
35+
import org.springframework.ai.huggingface.model.CompatGenerateRequest;
3536
import org.springframework.ai.huggingface.model.GenerateParameters;
36-
import org.springframework.ai.huggingface.model.GenerateRequest;
3737
import org.springframework.ai.huggingface.model.GenerateResponse;
3838

3939
/**
4040
* An implementation of {@link ChatModel} that interfaces with HuggingFace Inference
4141
* Endpoints for text generation.
4242
*
4343
* @author Mark Pollack
44+
* @author Jihoon Kim
4445
*/
4546
public class HuggingfaceChatModel implements ChatModel {
4647

@@ -89,22 +90,24 @@ public HuggingfaceChatModel(final String apiToken, String basePath) {
8990
*/
9091
@Override
9192
public ChatResponse call(Prompt prompt) {
92-
GenerateRequest generateRequest = new GenerateRequest();
93-
generateRequest.setInputs(prompt.getContents());
93+
CompatGenerateRequest compatGenerateRequest = new CompatGenerateRequest();
94+
compatGenerateRequest.setInputs(prompt.getContents());
9495
GenerateParameters generateParameters = new GenerateParameters();
9596
// TODO - need to expose API to set parameters per call.
9697
generateParameters.setMaxNewTokens(this.maxNewTokens);
97-
generateRequest.setParameters(generateParameters);
98-
GenerateResponse generateResponse = this.textGenApi.generate(generateRequest);
99-
String generatedText = generateResponse.getGeneratedText();
98+
compatGenerateRequest.setParameters(generateParameters);
99+
List<GenerateResponse> generateResponses = this.textGenApi.compatGenerate(compatGenerateRequest);
100100
List<Generation> generations = new ArrayList<>();
101-
AllOfGenerateResponseDetails allOfGenerateResponseDetails = generateResponse.getDetails();
102-
Map<String, Object> detailsMap = this.objectMapper.convertValue(allOfGenerateResponseDetails,
103-
new TypeReference<Map<String, Object>>() {
104-
105-
});
106-
Generation generation = new Generation(generatedText, detailsMap);
107-
generations.add(generation);
101+
for (GenerateResponse generateResponse : generateResponses) {
102+
String generatedText = generateResponse.getGeneratedText();
103+
AllOfGenerateResponseDetails allOfGenerateResponseDetails = generateResponse.getDetails();
104+
Map<String, Object> detailsMap = this.objectMapper.convertValue(allOfGenerateResponseDetails,
105+
new TypeReference<Map<String, Object>>() {
106+
107+
});
108+
Generation generation = new Generation(generatedText, detailsMap);
109+
generations.add(generation);
110+
}
108111
return new ChatResponse(generations);
109112
}
110113

models/spring-ai-huggingface/src/main/resources/openapi.json

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,10 @@
3737
"content": {
3838
"application/json": {
3939
"schema": {
40-
"$ref": "#/components/schemas/GenerateResponse"
40+
"type": "array",
41+
"items": {
42+
"$ref": "#/components/schemas/GenerateResponse"
43+
}
4144
}
4245
},
4346
"text/event-stream": {

models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/ClientIT.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,20 @@ void helloWorldCompletion() {
4343
lastname: Smith
4444
address: #1 Samuel St.
4545
Just generate the JSON object without explanations:
46+
Your response should be in JSON format.
47+
Do not include any explanations, only provide a RFC8259 compliant JSON response following this format without deviation.
48+
Do not include markdown code blocks in your response.
49+
Remove the ```json markdown from the output.
4650
[/INST]
4751
""";
4852
Prompt prompt = new Prompt(mistral7bInstruct);
4953
ChatResponse chatResponse = this.huggingfaceChatModel.call(prompt);
5054
assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty();
5155
String expectedResponse = """
5256
{
53-
"name": "John",
54-
"lastname": "Smith",
55-
"address": "#1 Samuel St."
57+
"name": "John",
58+
"lastname": "Smith",
59+
"address": "#1 Samuel St."
5660
}""";
5761
assertThat(chatResponse.getResult().getOutput().getContent()).isEqualTo(expectedResponse);
5862
assertThat(chatResponse.getResult().getOutput().getMetadata()).containsKey("generated_tokens");

0 commit comments

Comments
 (0)