Skip to content

Commit 14a922c

Browse files
committed
[aux] Common test
Add a submodule with reusable code for tests.
1 parent cac2b24 commit 14a922c

File tree

4 files changed

+249
-0
lines changed

4 files changed

+249
-0
lines changed

CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,10 @@ if (LLAMA_BUILD_COMMON)
193193
add_subdirectory(common)
194194
endif()
195195

196+
if(LLAMA_BUILD_EXAMPLES OR LLAMA_BUILD_TESTS)
197+
add_subdirectory(common_test)
198+
endif()
199+
196200
if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
197201
include(CTest)
198202
add_subdirectory(tests)

common_test/CMakeLists.txt

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Header-only helper library exposing load_into_memory.h and
# uint8-buff-stream.h to example and test targets.

set(TARGET llama-common-test)

# Interface library: nothing to compile, consumers just inherit usage requirements.
add_library(${TARGET} INTERFACE)

# Consumers include the headers straight out of this directory.
target_include_directories(${TARGET} INTERFACE ${CMAKE_CURRENT_SOURCE_DIR})

# Advertise that the test-specific headers are available.
target_compile_definitions(${TARGET} INTERFACE LLAMA_COMMON_TEST_HEADERS)

# The headers require C++17.
target_compile_features(${TARGET} INTERFACE cxx_std_17)

# Propagate the common library's usage requirements as well.
target_link_libraries(${TARGET} INTERFACE common)

common_test/load_into_memory.h

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
#pragma once
2+
3+
#include <chrono>
4+
#include <cstdint>
5+
#include <cstdio>
6+
#include <cstring>
7+
#include <ctime>
8+
#include <fstream>
9+
#include <memory>
10+
#include <sstream>
11+
#include <streambuf>
12+
#include <string>
13+
#include <thread>
14+
#include <vector>
15+
16+
// header-only utilities to showcase how to directly load a model from memory
17+
#include "uint8-buff-stream-wrapper.h"
18+
19+
namespace {
20+
// Heuristic check for a GGUF split file: split names carry an "-of-"
// marker (e.g. "model-00001-of-00002.gguf"). Aborts the process when no
// path was supplied at all.
bool is_split_file(const char * const model_path) {
    if (model_path == nullptr) {
        fprintf(stderr, "No model file provided\n");
        exit(EXIT_FAILURE);
    }

    const std::string path{model_path};
    return path.find("-of-") != std::string::npos;
}
29+
30+
// Read the entire file at model_path into a byte vector.
// Exits the process on any I/O failure (cannot open, short read).
std::vector<uint8_t> load_file_into_buffer(const char * const model_path) {
    std::ifstream in(model_path, std::ios::binary | std::ios::ate);
    if (!in) {
        fprintf(stderr, "Failed to open file %s for reading into streambuf\n", model_path);
        exit(EXIT_FAILURE);
    }

    // Opened with ios::ate, so the initial position is the file size.
    const size_t total_bytes = static_cast<size_t>(in.tellg());
    in.seekg(0, std::ios::beg);

    static_assert(sizeof(std::uint8_t) == sizeof(char), "uint8_t must be same size as char");
    std::vector<std::uint8_t> bytes(total_bytes);
    const bool read_ok = static_cast<bool>(in.read(reinterpret_cast<char *>(bytes.data()), total_bytes));
    if (!read_ok) {
        fprintf(stderr, "Failed to read entire file into buffer\n");
        exit(EXIT_FAILURE);
    }

    return bytes;
}
49+
50+
std::unique_ptr<std::basic_streambuf<uint8_t>> load_file_into_streambuf(const char * const model_path) {
51+
return std::make_unique<Uint8BufferStreamBuf>(load_file_into_buffer(model_path));
52+
}
53+
54+
// A file (model split or tensor list) paired with its in-memory stream buffer.
struct file_entry {
    // Filesystem path the data came from; also used as the identifier when
    // fulfilling split futures in load_model_from_memory_configuration.
    std::string path;
    // Owning stream over the file's bytes; moved out (consumed) when the
    // corresponding split future is fulfilled.
    std::unique_ptr<std::basic_streambuf<uint8_t>> streambuf;
};
58+
59+
std::vector<file_entry> load_files_into_streambuf(const char * const model_path) {
60+
std::vector<file_entry> files;
61+
62+
// Extract pattern from first file path
63+
std::string path(model_path);
64+
65+
// Split by '-'
66+
std::vector<std::string> parts;
67+
std::stringstream ss(path);
68+
std::string item;
69+
while (std::getline(ss, item, '-')) {
70+
parts.push_back(item);
71+
}
72+
73+
// Split the last part by '.'
74+
std::string last_part = parts.back();
75+
parts.pop_back();
76+
size_t dot_pos = last_part.find('.');
77+
if (dot_pos != std::string::npos) {
78+
parts.push_back(last_part.substr(0, dot_pos));
79+
parts.push_back(last_part.substr(dot_pos + 1)); // extension
80+
} else {
81+
parts.push_back(last_part);
82+
}
83+
84+
// Check if we have enough parts
85+
if (parts.size() < 4) {
86+
fprintf(stderr, "Model path does not contain expected pattern\n");
87+
exit(EXIT_FAILURE);
88+
}
89+
90+
// Get total files from [-2] position (before the extension)
91+
int total_files = std::stoi(parts[parts.size() - 2]);
92+
93+
// Get base path by joining all parts except -start-of-end.gguf
94+
std::string base_path;
95+
for (size_t i = 0; i < parts.size() - 4; i++) {
96+
if (i > 0) {
97+
base_path += "-";
98+
}
99+
base_path += parts[i];
100+
}
101+
102+
for (int i = 1; i <= total_files; i++) {
103+
char numbered_path[1024];
104+
snprintf(numbered_path, sizeof(numbered_path), "%s-%05d-of-%05d.gguf", base_path.c_str(), i, total_files);
105+
106+
files.push_back({ numbered_path, load_file_into_streambuf(numbered_path) });
107+
}
108+
109+
return files;
110+
}
111+
112+
// Derive "<base>.tensors.txt" from a split model path
// ("<base>-<idx>-of-<count>.gguf") and load it into a stream buffer.
// Exits when the path does not match the split pattern.
file_entry load_tensor_list_file(const char * const model_path) {
    const std::string path(model_path);

    // Tokenize on '-' (same semantics as std::getline with a '-' delimiter:
    // no trailing empty token after a final dash).
    std::vector<std::string> parts;
    for (size_t begin = 0; begin < path.size();) {
        const size_t dash = path.find('-', begin);
        if (dash == std::string::npos) {
            parts.push_back(path.substr(begin));
            break;
        }
        parts.push_back(path.substr(begin, dash - begin));
        begin = dash + 1;
    }

    // Peel the extension off the last token at its first '.'.
    std::string tail = parts.back();
    parts.pop_back();
    const size_t dot = tail.find('.');
    if (dot == std::string::npos) {
        parts.push_back(tail);
    } else {
        parts.push_back(tail.substr(0, dot));
        parts.push_back(tail.substr(dot + 1));  // extension
    }

    // Need at least "<idx>", "of", "<count>", "<ext>" after the base name.
    if (parts.size() < 4) {
        fprintf(stderr, "Model path does not contain expected pattern\n");
        exit(EXIT_FAILURE);
    }

    // Everything before the "-<idx>-of-<count>.<ext>" suffix is the base name.
    std::string base_path;
    for (size_t i = 0; i + 4 < parts.size(); ++i) {
        if (i != 0) {
            base_path += "-";
        }
        base_path += parts[i];
    }

    const std::string tensor_list_path = base_path + ".tensors.txt";

    printf("Loading tensor list file: %s\n", tensor_list_path.c_str());
    return { tensor_list_path, load_file_into_streambuf(tensor_list_path.c_str()) };
}
155+
156+
// Load a llama model according to environment-variable configuration:
//   LLAMA_EXAMPLE_MEMORY_BUFFER       - read the file, then load from a memory buffer
//   LLAMA_EXAMPLE_MEMORY_BUFFER_SPLIT - load split files through streambuf futures,
//                                       fulfilled from a background thread
//   LLAMA_EXAMPLE_FROM_FILE           - plain llama_model_load_from_file
// Returns nullptr when none of the variables is set (caller falls back to its
// default path); exits on load failure. Prints how long the load took.
llama_model * load_model_from_memory_configuration(const char * model_path, llama_model_params & model_params) {
    llama_model * model;
    std::chrono::steady_clock::time_point load_start_time;
    if (getenv("LLAMA_EXAMPLE_MEMORY_BUFFER")) {
        std::vector<uint8_t> buffer = load_file_into_buffer(model_path);
        fprintf(stdout, "%s: loading model from memory buffer\n", __func__);
        load_start_time = std::chrono::steady_clock::now();
        model = llama_model_load_from_buffer(std::move(buffer), model_params);
    } else if (getenv("LLAMA_EXAMPLE_MEMORY_BUFFER_SPLIT")) {
        file_entry tensor_list_file = load_tensor_list_file(model_path);
        std::vector<file_entry> files = load_files_into_streambuf(model_path);
        fprintf(stdout, "%s: loading model from %zu file streambufs\n", __func__, files.size());

        std::vector<const char *> file_paths;
        for (const auto & file : files) {
            printf("Found file %s with streambuf\n", file.path.c_str());
            file_paths.push_back(file.path.c_str());
        }

        load_start_time = std::chrono::steady_clock::now();
        const char * async_load_context = "test-model-load";
        // Fulfill the futures from a separate thread while this thread blocks
        // inside llama_model_load_from_split_futures below.
        std::thread fulfill_thread([&files, &tensor_list_file, &async_load_context]() {
            const bool success = llama_model_load_fulfill_split_future(
                tensor_list_file.path.c_str(), async_load_context, std::move(tensor_list_file.streambuf));
            printf("Fulfilling tensor list file %s: %s\n", tensor_list_file.path.c_str(),
                   success ? "success" : "failure");
            if (!success) {
                exit(EXIT_FAILURE);
            }

            for (auto & file : files) {
                // distinct name: the original shadowed `success` above
                const bool split_success = llama_model_load_fulfill_split_future(file.path.c_str(), async_load_context,
                                                                                 std::move(file.streambuf));
                printf("Fulfilling file %s with streambuf: %s\n", file.path.c_str(), split_success ? "success" : "failure");
                if (!split_success) {
                    exit(EXIT_FAILURE);
                }
            }
        });
        fprintf(stderr, "Loading model from splits\n");
        model = llama_model_load_from_split_futures(file_paths.data(), file_paths.size(), async_load_context,
                                                    tensor_list_file.path.c_str(), model_params);
        fulfill_thread.join();
    } else if (getenv("LLAMA_EXAMPLE_FROM_FILE")) {
        load_start_time = std::chrono::steady_clock::now();
        model = llama_model_load_from_file(model_path, model_params);
    } else {
        // No memory-configuration variable is set.
        return nullptr;
    }

    if (model == nullptr) {
        fprintf(stderr, "%s: error: unable to load model\n", __func__);
        exit(EXIT_FAILURE);  // was exit(1); use EXIT_FAILURE consistently with the rest of the file
    }
    std::chrono::steady_clock::time_point load_end_time = std::chrono::steady_clock::now();
    std::chrono::duration<double> load_duration = load_end_time - load_start_time;
    fprintf(stdout, "%s: loading model took %f seconds\n", __func__, load_duration.count());
    return model;
}
215+
216+
// True when any of the LLAMA_EXAMPLE_* loading variables is present in the
// environment (i.e. load_model_from_memory_configuration would do something).
bool memory_configuration_env_is_set() {
    const char * const vars[] = {
        "LLAMA_EXAMPLE_MEMORY_BUFFER",
        "LLAMA_EXAMPLE_MEMORY_BUFFER_SPLIT",
        "LLAMA_EXAMPLE_FROM_FILE",
    };
    for (const char * const name : vars) {
        if (getenv(name) != nullptr) {
            return true;
        }
    }
    return false;
}
220+
} // namespace
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#pragma once
2+
3+
// Wrapper to include the specific header from src
4+
#include "../src/uint8-buff-stream.h"
5+

0 commit comments

Comments
 (0)