|
32 | 32 | import org.springframework.ai.huggingface.api.TextGenerationInferenceApi; |
33 | 33 | import org.springframework.ai.huggingface.invoker.ApiClient; |
34 | 34 | import org.springframework.ai.huggingface.model.AllOfGenerateResponseDetails; |
| 35 | +import org.springframework.ai.huggingface.model.CompatGenerateRequest; |
35 | 36 | import org.springframework.ai.huggingface.model.GenerateParameters; |
36 | | -import org.springframework.ai.huggingface.model.GenerateRequest; |
37 | 37 | import org.springframework.ai.huggingface.model.GenerateResponse; |
38 | 38 |
|
39 | 39 | /** |
40 | 40 | * An implementation of {@link ChatModel} that interfaces with HuggingFace Inference |
41 | 41 | * Endpoints for text generation. |
42 | 42 | * |
43 | 43 | * @author Mark Pollack |
| 44 | + * @author Jihoon Kim |
44 | 45 | */ |
45 | 46 | public class HuggingfaceChatModel implements ChatModel { |
46 | 47 |
|
@@ -89,22 +90,24 @@ public HuggingfaceChatModel(final String apiToken, String basePath) { |
89 | 90 | */ |
90 | 91 | @Override |
91 | 92 | public ChatResponse call(Prompt prompt) { |
92 | | - GenerateRequest generateRequest = new GenerateRequest(); |
93 | | - generateRequest.setInputs(prompt.getContents()); |
| 93 | + CompatGenerateRequest compatGenerateRequest = new CompatGenerateRequest(); |
| 94 | + compatGenerateRequest.setInputs(prompt.getContents()); |
94 | 95 | GenerateParameters generateParameters = new GenerateParameters(); |
95 | 96 | // TODO - need to expose API to set parameters per call. |
96 | 97 | generateParameters.setMaxNewTokens(this.maxNewTokens); |
97 | | - generateRequest.setParameters(generateParameters); |
98 | | - GenerateResponse generateResponse = this.textGenApi.generate(generateRequest); |
99 | | - String generatedText = generateResponse.getGeneratedText(); |
| 98 | + compatGenerateRequest.setParameters(generateParameters); |
| 99 | + List<GenerateResponse> generateResponses = this.textGenApi.compatGenerate(compatGenerateRequest); |
100 | 100 | List<Generation> generations = new ArrayList<>(); |
101 | | - AllOfGenerateResponseDetails allOfGenerateResponseDetails = generateResponse.getDetails(); |
102 | | - Map<String, Object> detailsMap = this.objectMapper.convertValue(allOfGenerateResponseDetails, |
103 | | - new TypeReference<Map<String, Object>>() { |
104 | | - |
105 | | - }); |
106 | | - Generation generation = new Generation(generatedText, detailsMap); |
107 | | - generations.add(generation); |
| 101 | + for (GenerateResponse generateResponse : generateResponses) { |
| 102 | + String generatedText = generateResponse.getGeneratedText(); |
| 103 | + AllOfGenerateResponseDetails allOfGenerateResponseDetails = generateResponse.getDetails(); |
| 104 | + Map<String, Object> detailsMap = this.objectMapper.convertValue(allOfGenerateResponseDetails, |
| 105 | + new TypeReference<Map<String, Object>>() { |
| 106 | + |
| 107 | + }); |
| 108 | + Generation generation = new Generation(generatedText, detailsMap); |
| 109 | + generations.add(generation); |
| 110 | + } |
108 | 111 | return new ChatResponse(generations); |
109 | 112 | } |
110 | 113 |
|
|
0 commit comments