Skip to content

Commit 87fce09

Browse files
committed
Extend Hugging Face configuration with doSample, top-p, top-k and repetition penalty
1 parent 9143bef commit 87fce09

File tree

6 files changed

+246
-11
lines changed

6 files changed

+246
-11
lines changed

hugging-face/runtime/src/main/java/io/quarkiverse/langchain4j/huggingface/QuarkusHuggingFaceChatModel.java

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
import java.net.URL;
88
import java.time.Duration;
99
import java.util.List;
10+
import java.util.Optional;
11+
import java.util.OptionalDouble;
12+
import java.util.OptionalInt;
1013

1114
import dev.langchain4j.agent.tool.ToolSpecification;
1215
import dev.langchain4j.data.message.AiMessage;
@@ -34,9 +37,13 @@ public class QuarkusHuggingFaceChatModel implements ChatLanguageModel {
3437
private final Integer maxNewTokens;
3538
private final Boolean returnFullText;
3639
private final Boolean waitForModel;
40+
private final Optional<Boolean> doSample;
41+
private final OptionalDouble topP;
42+
private final OptionalInt topK;
43+
private final OptionalDouble repetitionPenalty;
3744

3845
private QuarkusHuggingFaceChatModel(Builder builder) {
39-
this.client = CLIENT_FACTORY.create(new HuggingFaceClientFactory.Input() {
46+
this.client = CLIENT_FACTORY.create(builder, new HuggingFaceClientFactory.Input() {
4047
@Override
4148
public String apiKey() {
4249
return builder.accessToken;
@@ -56,6 +63,10 @@ public Duration timeout() {
5663
this.maxNewTokens = builder.maxNewTokens;
5764
this.returnFullText = builder.returnFullText;
5865
this.waitForModel = builder.waitForModel;
66+
this.doSample = builder.doSample;
67+
this.topP = builder.topP;
68+
this.topK = builder.topK;
69+
this.repetitionPenalty = builder.repetitionPenalty;
5970
}
6071

6172
public static Builder builder() {
@@ -65,15 +76,23 @@ public static Builder builder() {
6576
@Override
6677
public Response<AiMessage> generate(List<ChatMessage> messages) {
6778

79+
Parameters.Builder builder = Parameters.builder()
80+
.temperature(temperature)
81+
.maxNewTokens(maxNewTokens)
82+
.returnFullText(returnFullText);
83+
84+
doSample.ifPresent(builder::doSample);
85+
topK.ifPresent(builder::topK);
86+
topP.ifPresent(builder::topP);
87+
repetitionPenalty.ifPresent(builder::repetitionPenalty);
88+
89+
Parameters parameters = builder
90+
.build();
6891
TextGenerationRequest request = TextGenerationRequest.builder()
6992
.inputs(messages.stream()
7093
.map(ChatMessage::text)
7194
.collect(joining("\n")))
72-
.parameters(Parameters.builder()
73-
.temperature(temperature)
74-
.maxNewTokens(maxNewTokens)
75-
.returnFullText(returnFullText)
76-
.build())
95+
.parameters(parameters)
7796
.options(Options.builder()
7897
.waitForModel(waitForModel)
7998
.build())
@@ -103,6 +122,14 @@ public static final class Builder {
103122
private Boolean returnFullText;
104123
private Boolean waitForModel = true;
105124
private URI url;
125+
private Optional<Boolean> doSample;
126+
127+
private OptionalInt topK;
128+
private OptionalDouble topP;
129+
130+
private OptionalDouble repetitionPenalty;
131+
public boolean logResponses;
132+
public boolean logRequests;
106133

107134
public Builder accessToken(String accessToken) {
108135
this.accessToken = accessToken;
@@ -143,8 +170,38 @@ public Builder waitForModel(Boolean waitForModel) {
143170
return this;
144171
}
145172

173+
public Builder doSample(Optional<Boolean> doSample) {
174+
this.doSample = doSample;
175+
return this;
176+
}
177+
178+
public Builder topK(OptionalInt topK) {
179+
this.topK = topK;
180+
return this;
181+
}
182+
183+
public Builder topP(OptionalDouble topP) {
184+
this.topP = topP;
185+
return this;
186+
}
187+
188+
public Builder repetitionPenalty(OptionalDouble repetitionPenalty) {
189+
this.repetitionPenalty = repetitionPenalty;
190+
return this;
191+
}
192+
146193
public QuarkusHuggingFaceChatModel build() {
147194
return new QuarkusHuggingFaceChatModel(this);
148195
}
196+
197+
public Builder logRequests(boolean logRequests) {
198+
this.logRequests = logRequests;
199+
return this;
200+
}
201+
202+
public Builder logResponses(boolean logResponses) {
203+
this.logResponses = logResponses;
204+
return this;
205+
}
149206
}
150207
}

hugging-face/runtime/src/main/java/io/quarkiverse/langchain4j/huggingface/QuarkusHuggingFaceClientFactory.java

Lines changed: 138 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,29 @@
11
package io.quarkiverse.langchain4j.huggingface;
22

3+
import static java.util.stream.Collectors.joining;
4+
import static java.util.stream.StreamSupport.stream;
5+
36
import java.net.URI;
47
import java.util.List;
58
import java.util.concurrent.TimeUnit;
9+
import java.util.regex.Matcher;
10+
import java.util.regex.Pattern;
11+
12+
import org.jboss.logging.Logger;
13+
import org.jboss.resteasy.reactive.client.api.ClientLogger;
14+
import org.jboss.resteasy.reactive.client.api.LoggingScope;
615

716
import dev.langchain4j.model.huggingface.client.EmbeddingRequest;
817
import dev.langchain4j.model.huggingface.client.HuggingFaceClient;
918
import dev.langchain4j.model.huggingface.client.TextGenerationRequest;
1019
import dev.langchain4j.model.huggingface.client.TextGenerationResponse;
1120
import dev.langchain4j.model.huggingface.spi.HuggingFaceClientFactory;
1221
import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder;
22+
import io.vertx.core.Handler;
23+
import io.vertx.core.MultiMap;
24+
import io.vertx.core.buffer.Buffer;
25+
import io.vertx.core.http.HttpClientRequest;
26+
import io.vertx.core.http.HttpClientResponse;
1327

1428
public class QuarkusHuggingFaceClientFactory implements HuggingFaceClientFactory {
1529

@@ -18,12 +32,21 @@ public HuggingFaceClient create(Input input) {
1832
throw new UnsupportedOperationException("Should not be called");
1933
}
2034

21-
public HuggingFaceClient create(Input input, URI url) {
22-
HuggingFaceRestApi restApi = QuarkusRestClientBuilder.newBuilder()
35+
public HuggingFaceClient create(QuarkusHuggingFaceChatModel.Builder config, Input input, URI url) {
36+
QuarkusRestClientBuilder builder = QuarkusRestClientBuilder.newBuilder()
2337
.baseUri(url)
2438
.connectTimeout(input.timeout().toSeconds(), TimeUnit.SECONDS)
25-
.readTimeout(input.timeout().toSeconds(), TimeUnit.SECONDS)
39+
.readTimeout(input.timeout().toSeconds(), TimeUnit.SECONDS);
40+
41+
if (config != null && (config.logRequests || config.logResponses)) {
42+
builder.loggingScope(LoggingScope.REQUEST_RESPONSE);
43+
builder.clientLogger(new HuggingFaceClientLogger(config.logRequests,
44+
config.logResponses));
45+
}
46+
47+
HuggingFaceRestApi restApi = builder
2648
.build(HuggingFaceRestApi.class);
49+
2750
return new QuarkusHuggingFaceClient(restApi, input.apiKey());
2851
}
2952

@@ -61,4 +84,116 @@ public List<float[]> embed(EmbeddingRequest request) {
6184
return restApi.embed(request, token);
6285
}
6386
}
87+
88+
/**
89+
* Introduce a custom logger as the stock one logs at the DEBUG level by default...
90+
*/
91+
class HuggingFaceClientLogger implements ClientLogger {
92+
private static final Logger log = Logger.getLogger(HuggingFaceClientLogger.class);
93+
94+
private static final Pattern BEARER_PATTERN = Pattern.compile("(Bearer\\s*sk-)(\\w{2})(\\w+)(\\w{2})");
95+
96+
private final boolean logRequests;
97+
private final boolean logResponses;
98+
99+
public HuggingFaceClientLogger(boolean logRequests, boolean logResponses) {
100+
this.logRequests = logRequests;
101+
this.logResponses = logResponses;
102+
}
103+
104+
@Override
105+
public void setBodySize(int bodySize) {
106+
// ignore
107+
}
108+
109+
@Override
110+
public void logRequest(HttpClientRequest request, Buffer body, boolean omitBody) {
111+
if (!logRequests || !log.isInfoEnabled()) {
112+
return;
113+
}
114+
try {
115+
log.infof("Request:\n- method: %s\n- url: %s\n- headers: %s\n- body: %s",
116+
request.getMethod(),
117+
request.absoluteURI(),
118+
inOneLine(request.headers()),
119+
bodyToString(body));
120+
} catch (Exception e) {
121+
log.warn("Failed to log request", e);
122+
}
123+
}
124+
125+
@Override
126+
public void logResponse(HttpClientResponse response, boolean redirect) {
127+
if (!logResponses || !log.isInfoEnabled()) {
128+
return;
129+
}
130+
response.bodyHandler(new Handler<>() {
131+
@Override
132+
public void handle(Buffer body) {
133+
try {
134+
log.infof(
135+
"Response:\n- status code: %s\n- headers: %s\n- body: %s",
136+
response.statusCode(),
137+
inOneLine(response.headers()),
138+
bodyToString(body));
139+
} catch (Exception e) {
140+
log.warn("Failed to log response", e);
141+
}
142+
}
143+
});
144+
}
145+
146+
private String bodyToString(Buffer body) {
147+
if (body == null) {
148+
return "";
149+
}
150+
return body.toString();
151+
}
152+
153+
private String inOneLine(MultiMap headers) {
154+
155+
return stream(headers.spliterator(), false)
156+
.map(header -> {
157+
String headerKey = header.getKey();
158+
String headerValue = header.getValue();
159+
if (headerKey.equals("Authorization")) {
160+
headerValue = maskAuthorizationHeaderValue(headerValue);
161+
} else if (headerKey.equals("api-key")) {
162+
headerValue = maskApiKeyHeaderValue(headerValue);
163+
}
164+
return String.format("[%s: %s]", headerKey, headerValue);
165+
})
166+
.collect(joining(", "));
167+
}
168+
169+
private static String maskAuthorizationHeaderValue(String authorizationHeaderValue) {
170+
try {
171+
172+
Matcher matcher = BEARER_PATTERN.matcher(authorizationHeaderValue);
173+
174+
StringBuilder sb = new StringBuilder();
175+
while (matcher.find()) {
176+
matcher.appendReplacement(sb, matcher.group(1) + matcher.group(2) + "..." + matcher.group(4));
177+
}
178+
matcher.appendTail(sb);
179+
180+
return sb.toString();
181+
} catch (Exception e) {
182+
return "Failed to mask the API key.";
183+
}
184+
}
185+
186+
private static String maskApiKeyHeaderValue(String apiKeyHeaderValue) {
187+
try {
188+
if (apiKeyHeaderValue.length() <= 4) {
189+
return apiKeyHeaderValue;
190+
}
191+
return apiKeyHeaderValue.substring(0, 2)
192+
+ "..."
193+
+ apiKeyHeaderValue.substring(apiKeyHeaderValue.length() - 2);
194+
} catch (Exception e) {
195+
return "Failed to mask the API key.";
196+
}
197+
}
198+
}
64199
}

hugging-face/runtime/src/main/java/io/quarkiverse/langchain4j/huggingface/QuarkusHuggingFaceEmbeddingModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public class QuarkusHuggingFaceEmbeddingModel implements EmbeddingModel {
3030
private final boolean waitForModel;
3131

3232
private QuarkusHuggingFaceEmbeddingModel(Builder builder) {
33-
this.client = CLIENT_FACTORY.create(new HuggingFaceClientFactory.Input() {
33+
this.client = CLIENT_FACTORY.create(null, new HuggingFaceClientFactory.Input() {
3434
@Override
3535
public String apiKey() {
3636
return builder.accessToken;

hugging-face/runtime/src/main/java/io/quarkiverse/langchain4j/huggingface/runtime/HuggingFaceRecorder.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,13 @@ public Supplier<?> chatModel(Langchain4jHuggingFaceConfig runtimeConfig) {
2727
.url(url)
2828
.timeout(runtimeConfig.timeout())
2929
.temperature(chatModelConfig.temperature())
30-
.waitForModel(chatModelConfig.waitForModel());
30+
.waitForModel(chatModelConfig.waitForModel())
31+
.doSample(chatModelConfig.doSample())
32+
.topP(chatModelConfig.topP())
33+
.topK(chatModelConfig.topK())
34+
.repetitionPenalty(chatModelConfig.repetitionPenalty())
35+
.logRequests(runtimeConfig.logRequests())
36+
.logResponses(runtimeConfig.logResponses());
3137

3238
if (apiKeyOpt.isPresent()) {
3339
builder.accessToken(apiKeyOpt.get());

hugging-face/runtime/src/main/java/io/quarkiverse/langchain4j/huggingface/runtime/config/ChatModelConfig.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import java.net.URL;
44
import java.util.Optional;
5+
import java.util.OptionalDouble;
6+
import java.util.OptionalInt;
57

68
import io.quarkus.runtime.annotations.ConfigGroup;
79
import io.smallrye.config.WithDefault;
@@ -50,4 +52,27 @@ public interface ChatModelConfig {
5052
*/
5153
@WithDefault("true")
5254
Boolean waitForModel();
55+
56+
/**
57+
* Whether or not to use sampling; use greedy decoding otherwise.
58+
*/
59+
Optional<Boolean> doSample();
60+
61+
/**
62+
* The number of highest probability vocabulary tokens to keep for top-k-filtering.
63+
*/
64+
OptionalInt topK();
65+
66+
/**
67+
* If set to less than {@code 1}, only the most probable tokens with probabilities that add up to {@code top_p} or
68+
* higher are kept for generation.
69+
*/
70+
OptionalDouble topP();
71+
72+
/**
73+
* The parameter for repetition penalty. 1.0 means no penalty.
74+
* See <a href="https://arxiv.org/pdf/1909.05858.pdf">this paper</a> for more details.
75+
*/
76+
OptionalDouble repetitionPenalty();
77+
5378
}

hugging-face/runtime/src/main/java/io/quarkiverse/langchain4j/huggingface/runtime/config/Langchain4jHuggingFaceConfig.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,16 @@ public interface Langchain4jHuggingFaceConfig {
3333
* Embedding model related settings
3434
*/
3535
EmbeddingModelConfig embeddingModel();
36+
37+
/**
38+
* Whether the Hugging Face client should log requests
39+
*/
40+
@WithDefault("false")
41+
Boolean logRequests();
42+
43+
/**
44+
* Whether the OpenAI client should log responses
45+
*/
46+
@WithDefault("false")
47+
Boolean logResponses();
3648
}

0 commit comments

Comments
 (0)