Skip to content

Commit 1c1a2d1

Browse files
committed
Use latest version of the Llama3.java code
1 parent da4251b commit 1c1a2d1

File tree

6 files changed

+269
-155
lines changed

6 files changed

+269
-155
lines changed

model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/Llama3ChatModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ public Response<AiMessage> generate(List<ChatMessage> messages) {
9696

9797
private InferenceResponse runInference(Llama model, Sampler sampler, Llama3.Options options,
9898
List<ChatFormat.Message> messages) {
99-
Llama.State state = model.createNewState();
99+
Llama.State state = model.createNewState(Llama3.BATCH_SIZE);
100100
ChatFormat chatFormat = new ChatFormat(model.tokenizer());
101101

102102
List<Integer> promptTokens = new ArrayList<>(chatFormat.encodeDialogPrompt(true, messages));

model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/Llama3StreamingChatModel.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import static dev.langchain4j.data.message.AiMessage.aiMessage;
44
import static io.quarkiverse.langchain4j.llama3.MessageMapper.toLlama3Message;
5+
import static io.quarkiverse.langchain4j.llama3.copy.Llama3.BATCH_SIZE;
56
import static io.quarkiverse.langchain4j.llama3.copy.Llama3.selectSampler;
67

78
import java.io.IOException;
@@ -84,7 +85,7 @@ public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMess
8485
private void runInference(Llama model, Sampler sampler, Llama3.Options options,
8586
List<ChatFormat.Message> messages,
8687
StreamingResponseHandler<AiMessage> handler) {
87-
Llama.State state = model.createNewState();
88+
Llama.State state = model.createNewState(BATCH_SIZE);
8889
ChatFormat chatFormat = new ChatFormat(model.tokenizer());
8990

9091
List<Integer> promptTokens = new ArrayList<>(chatFormat.encodeDialogPrompt(true, messages));

model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/AOT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public static PartialModel preLoadGGUF(String modelPath) {
5151
* No checksum/hash is checked for performance reasons.
5252
*/
5353
public static Llama tryUsePreLoaded(Path modelPath, int contextLength) throws IOException {
54-
PartialModel preLoaded = AOT.PRELOADED_GGUF;
54+
AOT.PartialModel preLoaded = AOT.PRELOADED_GGUF;
5555
if (preLoaded == null) {
5656
return null; // no pre-loaded model stored
5757
}

model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/copy/GGUF.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ private void loadModelImpl(FileChannel fileChannel) throws IOException {
108108
// gguf_tensor_info_t tensor_infos[header.tensor_count];
109109
this.tensorInfos = HashMap.newHashMap(tensorCount);
110110
for (int i = 0; i < tensorCount; ++i) {
111-
GGUFTensorInfo ti = readTensorInfo(fileChannel);
111+
GGUF.GGUFTensorInfo ti = readTensorInfo(fileChannel);
112112
assert !tensorInfos.containsKey(ti.name);
113113
tensorInfos.put(ti.name, ti);
114114
}
@@ -156,7 +156,7 @@ private GGMLType readGGMLType(FileChannel fileChannel) throws IOException {
156156
return GGMLType.fromId(ggmlTypeId);
157157
}
158158

159-
private GGUFTensorInfo readTensorInfo(FileChannel fileChannel) throws IOException {
159+
private GGUF.GGUFTensorInfo readTensorInfo(FileChannel fileChannel) throws IOException {
160160
// The name of the tensor. It is a standard GGUF string, with the caveat that
161161
// it must be at most 64 bytes long.
162162
String name = readString(fileChannel); // gguf_string_t name;
@@ -180,7 +180,7 @@ private GGUFTensorInfo readTensorInfo(FileChannel fileChannel) throws IOExceptio
180180
// Must be a multiple of `ALIGNMENT`.
181181
long offset = readLong(fileChannel); // uint64_t offset;
182182
assert offset % getAlignment() == 0;
183-
return new GGUFTensorInfo(name, dimensions, ggmlType, offset);
183+
return new GGUF.GGUFTensorInfo(name, dimensions, ggmlType, offset);
184184
}
185185

186186
private String readString(FileChannel fileChannel) throws IOException {

0 commit comments

Comments
 (0)