Skip to content

Commit f4b60ad

Browse files
authored
Merge pull request #1076 from quarkiverse/llama3-updated
Use latest version of the Llama3.java code
2 parents c10b478 + 1c1a2d1 commit f4b60ad

File tree

7 files changed

+280
-155
lines changed

7 files changed

+280
-155
lines changed

docs/modules/ROOT/pages/llama3.adoc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,17 @@ WARNING: Models are huge, so make sure you have enough disk space.
3939
4040
NOTE: Due to the models' large size, pulling them can take time
4141
42+
=== Native mode
43+
44+
Currently, Llama3.java only works in native mode with Early Access versions of Oracle GraalVM 24 (which can be easily downloaded with https://sdkman.io[SDKMan]).
45+
46+
To achieve the best performance in native mode, it is suggested to configure the application with the following:
47+
48+
[source,properties,subs=attributes+]
49+
----
50+
quarkus.native.additional-build-args=-O3,-march=native
51+
----
52+
4253
== Using Llama3.java
4354
4455
To let Llama3.java run inference on your models, add the following dependency into your project:

model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/Llama3ChatModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ public Response<AiMessage> generate(List<ChatMessage> messages) {
9696

9797
private InferenceResponse runInference(Llama model, Sampler sampler, Llama3.Options options,
9898
List<ChatFormat.Message> messages) {
99-
Llama.State state = model.createNewState();
99+
Llama.State state = model.createNewState(Llama3.BATCH_SIZE);
100100
ChatFormat chatFormat = new ChatFormat(model.tokenizer());
101101

102102
List<Integer> promptTokens = new ArrayList<>(chatFormat.encodeDialogPrompt(true, messages));

model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/Llama3StreamingChatModel.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import static dev.langchain4j.data.message.AiMessage.aiMessage;
44
import static io.quarkiverse.langchain4j.llama3.MessageMapper.toLlama3Message;
5+
import static io.quarkiverse.langchain4j.llama3.copy.Llama3.BATCH_SIZE;
56
import static io.quarkiverse.langchain4j.llama3.copy.Llama3.selectSampler;
67

78
import java.io.IOException;
@@ -84,7 +85,7 @@ public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMess
8485
private void runInference(Llama model, Sampler sampler, Llama3.Options options,
8586
List<ChatFormat.Message> messages,
8687
StreamingResponseHandler<AiMessage> handler) {
87-
Llama.State state = model.createNewState();
88+
Llama.State state = model.createNewState(BATCH_SIZE);
8889
ChatFormat chatFormat = new ChatFormat(model.tokenizer());
8990

9091
List<Integer> promptTokens = new ArrayList<>(chatFormat.encodeDialogPrompt(true, messages));

model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/AOT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public static PartialModel preLoadGGUF(String modelPath) {
5151
* No checksum/hash is checked for performance reasons.
5252
*/
5353
public static Llama tryUsePreLoaded(Path modelPath, int contextLength) throws IOException {
54-
PartialModel preLoaded = AOT.PRELOADED_GGUF;
54+
AOT.PartialModel preLoaded = AOT.PRELOADED_GGUF;
5555
if (preLoaded == null) {
5656
return null; // no pre-loaded model stored
5757
}

model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/GGUF.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ private void loadModelImpl(FileChannel fileChannel) throws IOException {
108108
// gguf_tensor_info_t tensor_infos[header.tensor_count];
109109
this.tensorInfos = HashMap.newHashMap(tensorCount);
110110
for (int i = 0; i < tensorCount; ++i) {
111-
GGUFTensorInfo ti = readTensorInfo(fileChannel);
111+
GGUF.GGUFTensorInfo ti = readTensorInfo(fileChannel);
112112
assert !tensorInfos.containsKey(ti.name);
113113
tensorInfos.put(ti.name, ti);
114114
}
@@ -156,7 +156,7 @@ private GGMLType readGGMLType(FileChannel fileChannel) throws IOException {
156156
return GGMLType.fromId(ggmlTypeId);
157157
}
158158

159-
private GGUFTensorInfo readTensorInfo(FileChannel fileChannel) throws IOException {
159+
private GGUF.GGUFTensorInfo readTensorInfo(FileChannel fileChannel) throws IOException {
160160
// The name of the tensor. It is a standard GGUF string, with the caveat that
161161
// it must be at most 64 bytes long.
162162
String name = readString(fileChannel); // gguf_string_t name;
@@ -180,7 +180,7 @@ private GGUFTensorInfo readTensorInfo(FileChannel fileChannel) throws IOExceptio
180180
// Must be a multiple of `ALIGNMENT`.
181181
long offset = readLong(fileChannel); // uint64_t offset;
182182
assert offset % getAlignment() == 0;
183-
return new GGUFTensorInfo(name, dimensions, ggmlType, offset);
183+
return new GGUF.GGUFTensorInfo(name, dimensions, ggmlType, offset);
184184
}
185185

186186
private String readString(FileChannel fileChannel) throws IOException {

0 commit comments

Comments
 (0)