Skip to content

Commit d21b911

Browse files
authored
Merge pull request #998 from quarkiverse/jlama-intree
Use custom integration with Jlama
2 parents 9cd3448 + 68a55b1 commit d21b911

File tree

9 files changed

+733
-5
lines changed

9 files changed

+733
-5
lines changed

model-providers/jlama/runtime/pom.xml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@
2020
</dependency>
2121

2222
<dependency>
23-
<groupId>dev.langchain4j</groupId>
24-
<artifactId>langchain4j-jlama</artifactId>
23+
<groupId>com.github.tjake</groupId>
24+
<artifactId>jlama-core</artifactId>
25+
<version>${jlama.version}</version>
2526
</dependency>
2627

2728
<dependency>
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
package io.quarkiverse.langchain4j.jlama;
2+
3+
import static io.quarkiverse.langchain4j.jlama.JlamaModel.toFinishReason;
4+
5+
import java.nio.file.Path;
6+
import java.util.LinkedHashMap;
7+
import java.util.List;
8+
import java.util.Optional;
9+
import java.util.UUID;
10+
11+
import com.github.tjake.jlama.model.AbstractModel;
12+
import com.github.tjake.jlama.model.functions.Generator;
13+
import com.github.tjake.jlama.safetensors.DType;
14+
import com.github.tjake.jlama.safetensors.prompt.PromptContext;
15+
import com.github.tjake.jlama.safetensors.prompt.PromptSupport;
16+
import com.github.tjake.jlama.safetensors.prompt.Tool;
17+
import com.github.tjake.jlama.safetensors.prompt.ToolCall;
18+
import com.github.tjake.jlama.safetensors.prompt.ToolResult;
19+
import com.github.tjake.jlama.util.JsonSupport;
20+
21+
import dev.langchain4j.agent.tool.ToolExecutionRequest;
22+
import dev.langchain4j.agent.tool.ToolSpecification;
23+
import dev.langchain4j.data.message.AiMessage;
24+
import dev.langchain4j.data.message.ChatMessage;
25+
import dev.langchain4j.data.message.Content;
26+
import dev.langchain4j.data.message.ContentType;
27+
import dev.langchain4j.data.message.SystemMessage;
28+
import dev.langchain4j.data.message.TextContent;
29+
import dev.langchain4j.data.message.ToolExecutionResultMessage;
30+
import dev.langchain4j.data.message.UserMessage;
31+
import dev.langchain4j.internal.Json;
32+
import dev.langchain4j.internal.RetryUtils;
33+
import dev.langchain4j.model.chat.ChatLanguageModel;
34+
import dev.langchain4j.model.output.Response;
35+
import dev.langchain4j.model.output.TokenUsage;
36+
37+
public class JlamaChatModel implements ChatLanguageModel {
38+
private final AbstractModel model;
39+
private final Float temperature;
40+
private final Integer maxTokens;
41+
42+
public JlamaChatModel(JlamaChatModelBuilder builder) {
43+
JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(builder.modelCachePath);
44+
JlamaModel jlamaModel = RetryUtils
45+
.withRetry(() -> registry.downloadModel(builder.modelName, Optional.ofNullable(builder.authToken)), 3);
46+
47+
JlamaModel.Loader loader = jlamaModel.loader();
48+
if (builder.quantizeModelAtRuntime != null && builder.quantizeModelAtRuntime)
49+
loader = loader.quantized();
50+
51+
if (builder.workingQuantizedType != null)
52+
loader = loader.workingQuantizationType(builder.workingQuantizedType);
53+
54+
if (builder.threadCount != null)
55+
loader = loader.threadCount(builder.threadCount);
56+
57+
if (builder.workingDirectory != null)
58+
loader = loader.workingDirectory(builder.workingDirectory);
59+
60+
this.model = loader.load();
61+
this.temperature = builder.temperature == null ? 0.3f : builder.temperature;
62+
this.maxTokens = builder.maxTokens == null ? model.getConfig().contextLength : builder.maxTokens;
63+
}
64+
65+
public static JlamaChatModelBuilder builder() {
66+
return new JlamaChatModelBuilder();
67+
}
68+
69+
@Override
70+
public Response<AiMessage> generate(List<ChatMessage> messages) {
71+
return generate(messages, List.of());
72+
}
73+
74+
@Override
75+
public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
76+
if (model.promptSupport().isEmpty())
77+
throw new UnsupportedOperationException("This model does not support chat generation");
78+
79+
PromptSupport.Builder promptBuilder = model.promptSupport().get().builder();
80+
81+
for (ChatMessage message : messages) {
82+
switch (message.type()) {
83+
case SYSTEM -> promptBuilder.addSystemMessage(((SystemMessage) message).text());
84+
case USER -> {
85+
StringBuilder finalMessage = new StringBuilder();
86+
UserMessage userMessage = (UserMessage) message;
87+
for (Content content : userMessage.contents()) {
88+
if (content.type() != ContentType.TEXT)
89+
throw new UnsupportedOperationException("Unsupported content type: " + content.type());
90+
91+
finalMessage.append(((TextContent) content).text());
92+
}
93+
promptBuilder.addUserMessage(finalMessage.toString());
94+
}
95+
case AI -> {
96+
AiMessage aiMessage = (AiMessage) message;
97+
if (aiMessage.text() != null)
98+
promptBuilder.addAssistantMessage(aiMessage.text());
99+
100+
if (aiMessage.hasToolExecutionRequests())
101+
for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
102+
ToolCall toolCall = new ToolCall(toolExecutionRequest.name(), toolExecutionRequest.id(),
103+
Json.fromJson(toolExecutionRequest.arguments(), LinkedHashMap.class));
104+
promptBuilder.addToolCall(toolCall);
105+
}
106+
}
107+
case TOOL_EXECUTION_RESULT -> {
108+
ToolExecutionResultMessage toolMessage = (ToolExecutionResultMessage) message;
109+
ToolResult result = ToolResult.from(toolMessage.toolName(), toolMessage.id(), toolMessage.text());
110+
promptBuilder.addToolResult(result);
111+
}
112+
default -> throw new IllegalArgumentException("Unsupported message type: " + message.type());
113+
}
114+
}
115+
116+
List<Tool> tools = toolSpecifications.stream().map(JlamaModel::toTool).toList();
117+
118+
PromptContext promptContext = tools.isEmpty() ? promptBuilder.build() : promptBuilder.build(tools);
119+
Generator.Response r = model.generate(UUID.randomUUID(), promptContext, temperature, maxTokens, (token, time) -> {
120+
});
121+
122+
if (r.finishReason == Generator.FinishReason.TOOL_CALL) {
123+
List<ToolExecutionRequest> toolCalls = r.toolCalls.stream().map(f -> ToolExecutionRequest.builder()
124+
.name(f.getName())
125+
.id(f.getId())
126+
.arguments(JsonSupport.toJson(f.getParameters()))
127+
.build()).toList();
128+
129+
return Response.from(AiMessage.from(toolCalls), new TokenUsage(r.promptTokens, r.generatedTokens),
130+
toFinishReason(r.finishReason));
131+
}
132+
133+
return Response.from(AiMessage.from(r.responseText), new TokenUsage(r.promptTokens, r.generatedTokens),
134+
toFinishReason(r.finishReason));
135+
}
136+
137+
@Override
138+
public Response<AiMessage> generate(List<ChatMessage> messages, ToolSpecification toolSpecification) {
139+
return generate(messages, List.of(toolSpecification));
140+
}
141+
142+
public static class JlamaChatModelBuilder {
143+
144+
private Path modelCachePath;
145+
private String modelName;
146+
private String authToken;
147+
private Integer threadCount;
148+
private Path workingDirectory;
149+
private Boolean quantizeModelAtRuntime;
150+
private DType workingQuantizedType;
151+
private Float temperature;
152+
private Integer maxTokens;
153+
154+
public JlamaChatModelBuilder modelCachePath(Path modelCachePath) {
155+
this.modelCachePath = modelCachePath;
156+
return this;
157+
}
158+
159+
public JlamaChatModelBuilder modelName(String modelName) {
160+
this.modelName = modelName;
161+
return this;
162+
}
163+
164+
public JlamaChatModelBuilder authToken(String authToken) {
165+
this.authToken = authToken;
166+
return this;
167+
}
168+
169+
public JlamaChatModelBuilder threadCount(Integer threadCount) {
170+
this.threadCount = threadCount;
171+
return this;
172+
}
173+
174+
public JlamaChatModelBuilder workingDirectory(Path workingDirectory) {
175+
this.workingDirectory = workingDirectory;
176+
return this;
177+
}
178+
179+
public JlamaChatModelBuilder quantizeModelAtRuntime(Boolean quantizeModelAtRuntime) {
180+
this.quantizeModelAtRuntime = quantizeModelAtRuntime;
181+
return this;
182+
}
183+
184+
public JlamaChatModelBuilder workingQuantizedType(DType workingQuantizedType) {
185+
this.workingQuantizedType = workingQuantizedType;
186+
return this;
187+
}
188+
189+
public JlamaChatModelBuilder temperature(Float temperature) {
190+
this.temperature = temperature;
191+
return this;
192+
}
193+
194+
public JlamaChatModelBuilder maxTokens(Integer maxTokens) {
195+
this.maxTokens = maxTokens;
196+
return this;
197+
}
198+
199+
public JlamaChatModel build() {
200+
return new JlamaChatModel(this);
201+
}
202+
}
203+
}
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
package io.quarkiverse.langchain4j.jlama;
2+
3+
import java.nio.file.Path;
4+
import java.util.ArrayList;
5+
import java.util.List;
6+
import java.util.Optional;
7+
8+
import com.github.tjake.jlama.model.AbstractModel;
9+
import com.github.tjake.jlama.model.ModelSupport;
10+
import com.github.tjake.jlama.model.bert.BertModel;
11+
import com.github.tjake.jlama.model.functions.Generator;
12+
13+
import dev.langchain4j.data.embedding.Embedding;
14+
import dev.langchain4j.data.segment.TextSegment;
15+
import dev.langchain4j.internal.RetryUtils;
16+
import dev.langchain4j.model.embedding.DimensionAwareEmbeddingModel;
17+
import dev.langchain4j.model.output.Response;
18+
19+
public class JlamaEmbeddingModel extends DimensionAwareEmbeddingModel {
20+
private final BertModel model;
21+
private final Generator.PoolingType poolingType;
22+
23+
public JlamaEmbeddingModel(JlamaEmbeddingModelBuilder builder) {
24+
25+
JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(builder.modelCachePath);
26+
JlamaModel jlamaModel = RetryUtils
27+
.withRetry(() -> registry.downloadModel(builder.modelName, Optional.ofNullable(builder.authToken)), 3);
28+
29+
if (jlamaModel.getModelType() != ModelSupport.ModelType.BERT) {
30+
throw new IllegalArgumentException("Model type must be BERT");
31+
}
32+
33+
JlamaModel.Loader loader = jlamaModel.loader();
34+
if (builder.quantizeModelAtRuntime != null && builder.quantizeModelAtRuntime)
35+
loader = loader.quantized();
36+
37+
if (builder.threadCount != null)
38+
loader = loader.threadCount(builder.threadCount);
39+
40+
if (builder.workingDirectory != null)
41+
loader = loader.workingDirectory(builder.workingDirectory);
42+
43+
loader = loader.inferenceType(AbstractModel.InferenceType.FULL_EMBEDDING);
44+
45+
this.model = (BertModel) loader.load();
46+
this.dimension = model.getConfig().embeddingLength;
47+
48+
this.poolingType = builder.poolingType == null ? Generator.PoolingType.MODEL : builder.poolingType;
49+
}
50+
51+
public static JlamaEmbeddingModelBuilder builder() {
52+
return new JlamaEmbeddingModelBuilder();
53+
}
54+
55+
@Override
56+
public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
57+
List<Embedding> embeddings = new ArrayList<>();
58+
59+
textSegments.forEach(textSegment -> {
60+
embeddings.add(Embedding.from(model.embed(textSegment.text(), poolingType)));
61+
});
62+
63+
return Response.from(embeddings);
64+
}
65+
66+
public static class JlamaEmbeddingModelBuilder {
67+
68+
private Path modelCachePath;
69+
private String modelName;
70+
private String authToken;
71+
private Integer threadCount;
72+
private Path workingDirectory;
73+
private Boolean quantizeModelAtRuntime;
74+
private Generator.PoolingType poolingType;
75+
76+
public JlamaEmbeddingModelBuilder modelCachePath(Path modelCachePath) {
77+
this.modelCachePath = modelCachePath;
78+
return this;
79+
}
80+
81+
public JlamaEmbeddingModelBuilder modelName(String modelName) {
82+
this.modelName = modelName;
83+
return this;
84+
}
85+
86+
public JlamaEmbeddingModelBuilder authToken(String authToken) {
87+
this.authToken = authToken;
88+
return this;
89+
}
90+
91+
public JlamaEmbeddingModelBuilder threadCount(Integer threadCount) {
92+
this.threadCount = threadCount;
93+
return this;
94+
}
95+
96+
public JlamaEmbeddingModelBuilder workingDirectory(Path workingDirectory) {
97+
this.workingDirectory = workingDirectory;
98+
return this;
99+
}
100+
101+
public JlamaEmbeddingModelBuilder quantizeModelAtRuntime(Boolean quantizeModelAtRuntime) {
102+
this.quantizeModelAtRuntime = quantizeModelAtRuntime;
103+
return this;
104+
}
105+
106+
public JlamaEmbeddingModel build() {
107+
return new JlamaEmbeddingModel(this);
108+
}
109+
}
110+
}

0 commit comments

Comments (0)