Commit 4bbcc72

Merge pull request #999 from mariofusco/jlama-log
Add optional logging to jlama requests and responses
2 parents db01040 + 629909e

3 files changed: +90 -41 lines

model-providers/jlama/runtime/src/main/java/io/quarkiverse/langchain4j/jlama/JlamaChatModel.java

Lines changed: 63 additions & 18 deletions
@@ -8,12 +8,13 @@
 import java.util.Optional;
 import java.util.UUID;

+import org.jboss.logging.Logger;
+
 import com.github.tjake.jlama.model.AbstractModel;
 import com.github.tjake.jlama.model.functions.Generator;
 import com.github.tjake.jlama.safetensors.DType;
 import com.github.tjake.jlama.safetensors.prompt.PromptContext;
 import com.github.tjake.jlama.safetensors.prompt.PromptSupport;
-import com.github.tjake.jlama.safetensors.prompt.Tool;
 import com.github.tjake.jlama.safetensors.prompt.ToolCall;
 import com.github.tjake.jlama.safetensors.prompt.ToolResult;
 import com.github.tjake.jlama.util.JsonSupport;
@@ -35,9 +36,14 @@
 import dev.langchain4j.model.output.TokenUsage;

 public class JlamaChatModel implements ChatLanguageModel {
+
+    private static final Logger log = Logger.getLogger(JlamaChatModel.class);
+
     private final AbstractModel model;
     private final Float temperature;
     private final Integer maxTokens;
+    private final Boolean logRequests;
+    private final Boolean logResponses;

     public JlamaChatModel(JlamaChatModelBuilder builder) {

@@ -46,21 +52,27 @@ public JlamaChatModel(JlamaChatModelBuilder builder) {
                 .withRetry(() -> registry.downloadModel(builder.modelName, Optional.ofNullable(builder.authToken)), 3);

         JlamaModel.Loader loader = jlamaModel.loader();
-        if (builder.quantizeModelAtRuntime != null && builder.quantizeModelAtRuntime)
+        if (builder.quantizeModelAtRuntime != null && builder.quantizeModelAtRuntime) {
             loader = loader.quantized();
+        }

-        if (builder.workingQuantizedType != null)
+        if (builder.workingQuantizedType != null) {
             loader = loader.workingQuantizationType(builder.workingQuantizedType);
+        }

-        if (builder.threadCount != null)
+        if (builder.threadCount != null) {
             loader = loader.threadCount(builder.threadCount);
+        }

-        if (builder.workingDirectory != null)
+        if (builder.workingDirectory != null) {
             loader = loader.workingDirectory(builder.workingDirectory);
+        }

         this.model = loader.load();
         this.temperature = builder.temperature == null ? 0.3f : builder.temperature;
         this.maxTokens = builder.maxTokens == null ? model.getConfig().contextLength : builder.maxTokens;
+        this.logRequests = builder.logRequests != null && builder.logRequests;
+        this.logResponses = builder.logResponses != null && builder.logResponses;
     }

     public static JlamaChatModelBuilder builder() {
@@ -74,9 +86,29 @@ public Response<AiMessage> generate(List<ChatMessage> messages) {

     @Override
     public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
-        if (model.promptSupport().isEmpty())
+        if (model.promptSupport().isEmpty()) {
             throw new UnsupportedOperationException("This model does not support chat generation");
+        }
+
+        if (logRequests) {
+            log.info("Request: " + messages);
+        }
+
+        PromptSupport.Builder promptBuilder = promptBuilder(messages);
+        Generator.Response r = model.generate(UUID.randomUUID(), promptContext(promptBuilder, toolSpecifications), temperature,
+                maxTokens, (token, time) -> {
+                });
+        Response<AiMessage> aiResponse = Response.from(aiMessageForResponse(r),
+                new TokenUsage(r.promptTokens, r.generatedTokens), toFinishReason(r.finishReason));

+        if (logResponses) {
+            log.info("Response: " + aiResponse);
+        }
+
+        return aiResponse;
+    }
+
+    private PromptSupport.Builder promptBuilder(List<ChatMessage> messages) {
         PromptSupport.Builder promptBuilder = model.promptSupport().get().builder();

         for (ChatMessage message : messages) {
@@ -86,17 +118,18 @@ public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecifi
                     StringBuilder finalMessage = new StringBuilder();
                     UserMessage userMessage = (UserMessage) message;
                     for (Content content : userMessage.contents()) {
-                        if (content.type() != ContentType.TEXT)
+                        if (content.type() != ContentType.TEXT) {
                             throw new UnsupportedOperationException("Unsupported content type: " + content.type());
-
+                        }
                         finalMessage.append(((TextContent) content).text());
                     }
                     promptBuilder.addUserMessage(finalMessage.toString());
                 }
                 case AI -> {
                     AiMessage aiMessage = (AiMessage) message;
-                    if (aiMessage.text() != null)
+                    if (aiMessage.text() != null) {
                         promptBuilder.addAssistantMessage(aiMessage.text());
+                    }

                     if (aiMessage.hasToolExecutionRequests())
                         for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
@@ -113,26 +146,26 @@ public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecifi
                 default -> throw new IllegalArgumentException("Unsupported message type: " + message.type());
             }
         }
+        return promptBuilder;
+    }

-        List<Tool> tools = toolSpecifications.stream().map(JlamaModel::toTool).toList();
-
-        PromptContext promptContext = tools.isEmpty() ? promptBuilder.build() : promptBuilder.build(tools);
-        Generator.Response r = model.generate(UUID.randomUUID(), promptContext, temperature, maxTokens, (token, time) -> {
-        });
+    private PromptContext promptContext(PromptSupport.Builder promptBuilder, List<ToolSpecification> toolSpecifications) {
+        return toolSpecifications.isEmpty() ? promptBuilder.build()
+                : promptBuilder.build(toolSpecifications.stream().map(JlamaModel::toTool).toList());
+    }

+    private AiMessage aiMessageForResponse(Generator.Response r) {
         if (r.finishReason == Generator.FinishReason.TOOL_CALL) {
             List<ToolExecutionRequest> toolCalls = r.toolCalls.stream().map(f -> ToolExecutionRequest.builder()
                     .name(f.getName())
                     .id(f.getId())
                     .arguments(JsonSupport.toJson(f.getParameters()))
                     .build()).toList();

-            return Response.from(AiMessage.from(toolCalls), new TokenUsage(r.promptTokens, r.generatedTokens),
-                    toFinishReason(r.finishReason));
+            return AiMessage.from(toolCalls);
         }

-        return Response.from(AiMessage.from(r.responseText), new TokenUsage(r.promptTokens, r.generatedTokens),
-                toFinishReason(r.finishReason));
+        return AiMessage.from(r.responseText);
     }

     @Override
@@ -152,6 +185,8 @@ public static class JlamaChatModelBuilder {
         private DType workingQuantizedType;
         private Float temperature;
         private Integer maxTokens;
+        private Boolean logRequests;
+        private Boolean logResponses;

         public JlamaChatModelBuilder modelCachePath(Optional<Path> modelCachePath) {
             this.modelCachePath = modelCachePath;
@@ -198,6 +233,16 @@ public JlamaChatModelBuilder maxTokens(Integer maxTokens) {
             return this;
         }

+        public JlamaChatModelBuilder logRequests(Boolean logRequests) {
+            this.logRequests = logRequests;
+            return this;
+        }
+
+        public JlamaChatModelBuilder logResponses(Boolean logResponses) {
+            this.logResponses = logResponses;
+            return this;
+        }
+
         public JlamaChatModel build() {
             return new JlamaChatModel(this);
         }

model-providers/jlama/runtime/src/main/java/io/quarkiverse/langchain4j/jlama/runtime/JlamaAiRecorder.java

Lines changed: 11 additions & 23 deletions
@@ -35,12 +35,11 @@ public Supplier<ChatLanguageModel> chatModel(LangChain4jJlamaConfig runtimeConfi
                 .modelName(modelName)
                 .modelCachePath(fixedRuntimeConfig.modelsPath());

-        if (chatModelConfig.temperature().isPresent()) {
-            builder.temperature((float) chatModelConfig.temperature().getAsDouble());
-        }
-        if (chatModelConfig.maxTokens().isPresent()) {
-            builder.maxTokens(chatModelConfig.maxTokens().getAsInt());
-        }
+        jlamaConfig.logRequests().ifPresent(builder::logRequests);
+        jlamaConfig.logResponses().ifPresent(builder::logResponses);
+
+        chatModelConfig.temperature().ifPresent(temp -> builder.temperature((float) temp));
+        chatModelConfig.maxTokens().ifPresent(builder::maxTokens);

         return new Supplier<>() {
             @Override
@@ -72,9 +71,8 @@ public Supplier<StreamingChatLanguageModel> streamingChatModel(LangChain4jJlamaC
                 .modelName(jlamaFixedRuntimeConfig.chatModel().modelName())
                 .modelCachePath(fixedRuntimeConfig.modelsPath());

-        if (chatModelConfig.temperature().isPresent()) {
-            builder.temperature((float) chatModelConfig.temperature().getAsDouble());
-        }
+        chatModelConfig.temperature().ifPresent(temp -> builder.temperature((float) temp));
+
         return new Supplier<>() {
             @Override
             public StreamingChatLanguageModel get() {
@@ -121,25 +119,15 @@ public EmbeddingModel get() {

     private LangChain4jJlamaConfig.JlamaConfig correspondingJlamaConfig(LangChain4jJlamaConfig runtimeConfig,
             String configName) {
-        LangChain4jJlamaConfig.JlamaConfig jlamaConfig;
-        if (NamedConfigUtil.isDefault(configName)) {
-            jlamaConfig = runtimeConfig.defaultConfig();
-        } else {
-            jlamaConfig = runtimeConfig.namedConfig().get(configName);
-        }
-        return jlamaConfig;
+        return NamedConfigUtil.isDefault(configName) ? runtimeConfig.defaultConfig()
+                : runtimeConfig.namedConfig().get(configName);
     }

     private LangChain4jJlamaFixedRuntimeConfig.JlamaConfig correspondingJlamaFixedRuntimeConfig(
             LangChain4jJlamaFixedRuntimeConfig runtimeConfig,
             String configName) {
-        LangChain4jJlamaFixedRuntimeConfig.JlamaConfig jlamaConfig;
-        if (NamedConfigUtil.isDefault(configName)) {
-            jlamaConfig = runtimeConfig.defaultConfig();
-        } else {
-            jlamaConfig = runtimeConfig.namedConfig().get(configName);
-        }
-        return jlamaConfig;
+        return NamedConfigUtil.isDefault(configName) ? runtimeConfig.defaultConfig()
+                : runtimeConfig.namedConfig().get(configName);
     }

 }
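
A note on the recorder wiring above: logRequests() and logResponses() are Optional values, so the recorder only touches the builder when a value was actually configured (ifPresent with a method reference). When nothing is set, the builder fields remain null and the JlamaChatModel constructor resolves them to false via its "builder.logRequests != null && builder.logRequests" checks, so logging stays off by default.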

model-providers/jlama/runtime/src/main/java/io/quarkiverse/langchain4j/jlama/runtime/config/LangChain4jJlamaConfig.java

Lines changed: 16 additions & 0 deletions
@@ -3,7 +3,9 @@
 import static io.quarkus.runtime.annotations.ConfigPhase.RUN_TIME;

 import java.util.Map;
+import java.util.Optional;

+import io.quarkus.runtime.annotations.ConfigDocDefault;
 import io.quarkus.runtime.annotations.ConfigDocMapKey;
 import io.quarkus.runtime.annotations.ConfigDocSection;
 import io.quarkus.runtime.annotations.ConfigGroup;
@@ -46,5 +48,19 @@ interface JlamaConfig {
          */
         @WithDefault("true")
         Boolean enableIntegration();
+
+        /**
+         * Whether Jlama should log requests
+         */
+        @ConfigDocDefault("false")
+        @WithDefault("${quarkus.langchain4j.log-requests}")
+        Optional<Boolean> logRequests();
+
+        /**
+         * Whether Jlama client should log responses
+         */
+        @ConfigDocDefault("false")
+        @WithDefault("${quarkus.langchain4j.log-responses}")
+        Optional<Boolean> logResponses();
     }
 }
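
Because both new entries are Optional and default to the "${quarkus.langchain4j.log-requests}" / "${quarkus.langchain4j.log-responses}" expressions, a per-provider setting falls back to the global LangChain4j logging switches, and ultimately to false (the builder-level default documented via @ConfigDocDefault) when neither is set. Assuming the usual Quarkus property naming for this config root (the root prefix itself is not visible in this diff), enabling the flags would look like:

    quarkus.langchain4j.jlama.log-requests=true
    quarkus.langchain4j.jlama.log-responses=true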
