Skip to content

Commit de1e3df

Browse files
committed
[aux] Common test
Add a submodule with reusable code for tests.
1 parent f5eebff commit de1e3df

File tree

4 files changed

+269
-0
lines changed

4 files changed

+269
-0
lines changed

CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,10 @@ if (LLAMA_BUILD_COMMON)
193193
add_subdirectory(common)
194194
endif()
195195

196+
if(LLAMA_BUILD_EXAMPLES OR LLAMA_BUILD_TESTS)
197+
add_subdirectory(common_test)
198+
endif()
199+
196200
if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
197201
include(CTest)
198202
add_subdirectory(tests)

common_test/CMakeLists.txt

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Header-only helper library exposing load_into_memory.h and
# uint8-buff-stream.h to examples and tests.

set(TARGET llama-common-test)

# Interface library: nothing to compile, consumers inherit usage requirements.
add_library(${TARGET} INTERFACE)

# Headers are picked up straight from this directory.
target_include_directories(${TARGET} INTERFACE
    ${CMAKE_CURRENT_SOURCE_DIR}
)

# Advertise to dependents that these test headers are available.
target_compile_definitions(${TARGET} INTERFACE LLAMA_COMMON_TEST_HEADERS)

# The headers require C++17.
target_compile_features(${TARGET} INTERFACE cxx_std_17)

# Pull in the common library for anything the headers rely on.
target_link_libraries(${TARGET} INTERFACE common)

common_test/load_into_memory.h

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
#pragma once
2+
3+
#include <chrono>
4+
#include <cstdint>
5+
#include <cstdio>
6+
#include <cstring>
7+
#include <ctime>
8+
#include <fstream>
9+
#include <memory>
10+
#include <random>
11+
#include <sstream>
12+
#include <streambuf>
13+
#include <string>
14+
#include <thread>
15+
#include <vector>
16+
17+
// header-only utilities to showcase how to directly load a model from memory
18+
#include "uint8-buff-stream-wrapper.h"
19+
20+
namespace {
21+
// Returns true when the given path looks like one shard of a split GGUF
// model (split shard filenames embed "-of-").
// Exits the process when no path was supplied.
bool is_split_file(const char * const model_path) {
    if (model_path == nullptr) {
        fprintf(stderr, "No model file provided\n");
        exit(EXIT_FAILURE);
    }

    const std::string path(model_path);
    return path.find("-of-") != std::string::npos;
}
30+
31+
// Reads the entire file at `model_path` into a byte vector.
// Terminates the process if the file cannot be opened or fully read.
std::vector<uint8_t> load_file_into_buffer(const char * const model_path) {
    std::ifstream in(model_path, std::ios::binary | std::ios::ate);
    if (!in) {
        fprintf(stderr, "Failed to open file %s for reading into streambuf\n", model_path);
        exit(EXIT_FAILURE);
    }

    // Opening at the end gives us the file size up front, so the
    // destination vector is allocated exactly once.
    const size_t n_bytes = in.tellg();
    in.seekg(0, std::ios::beg);

    static_assert(sizeof(std::uint8_t) == sizeof(char), "uint8_t must be same size as char");
    std::vector<std::uint8_t> bytes(n_bytes);
    if (!in.read(reinterpret_cast<char *>(bytes.data()), n_bytes)) {
        fprintf(stderr, "Failed to read entire file into buffer\n");
        exit(EXIT_FAILURE);
    }

    return bytes;
}
50+
51+
std::unique_ptr<std::basic_streambuf<uint8_t>> load_file_into_streambuf(const char * const model_path) {
52+
return std::make_unique<Uint8BufferStreamBuf>(load_file_into_buffer(model_path));
53+
}
54+
55+
// Pairs a model-file path with the streambuf that owns its in-memory contents.
struct file_entry {
    std::string path;  // filesystem path of the (possibly split) gguf file
    std::unique_ptr<std::basic_streambuf<uint8_t>> streambuf;  // owning stream over the file bytes
};
59+
60+
std::vector<file_entry> load_files_into_streambuf(const char * const model_path) {
61+
std::vector<file_entry> files;
62+
63+
// Extract pattern from first file path
64+
std::string path(model_path);
65+
66+
// Split by '-'
67+
std::vector<std::string> parts;
68+
std::stringstream ss(path);
69+
std::string item;
70+
while (std::getline(ss, item, '-')) {
71+
parts.push_back(item);
72+
}
73+
74+
// Split the last part by '.'
75+
std::string last_part = parts.back();
76+
parts.pop_back();
77+
size_t dot_pos = last_part.find('.');
78+
if (dot_pos != std::string::npos) {
79+
parts.push_back(last_part.substr(0, dot_pos));
80+
parts.push_back(last_part.substr(dot_pos + 1)); // extension
81+
} else {
82+
parts.push_back(last_part);
83+
}
84+
85+
// Check if we have enough parts
86+
if (parts.size() < 4) {
87+
fprintf(stderr, "Model path does not contain expected pattern\n");
88+
exit(EXIT_FAILURE);
89+
}
90+
91+
// Get total files from [-2] position (before the extension)
92+
int total_files = std::stoi(parts[parts.size() - 2]);
93+
94+
// Get base path by joining all parts except -start-of-end.gguf
95+
std::string base_path;
96+
for (size_t i = 0; i < parts.size() - 4; i++) {
97+
if (i > 0) {
98+
base_path += "-";
99+
}
100+
base_path += parts[i];
101+
}
102+
103+
for (int i = 1; i <= total_files; i++) {
104+
char numbered_path[1024];
105+
snprintf(numbered_path, sizeof(numbered_path), "%s-%05d-of-%05d.gguf", base_path.c_str(), i, total_files);
106+
107+
files.push_back({ numbered_path, load_file_into_streambuf(numbered_path) });
108+
}
109+
110+
return files;
111+
}
112+
113+
// Derives "<base>.tensors.txt" from a split-shard path and loads it into an
// owning streambuf. The parsing mirrors load_files_into_streambuf: tokenize
// on '-', peel off the extension, drop the "-NNNNN-of-MMMMM.<ext>" suffix.
// Exits the process when the path does not match the expected pattern.
file_entry load_tensor_list_file(const char * const model_path) {
    // Tokenize the path on '-'.
    std::vector<std::string> tokens;
    std::stringstream splitter{std::string(model_path)};
    for (std::string tok; std::getline(splitter, tok, '-');) {
        tokens.push_back(tok);
    }

    // Separate the extension from the final token.
    const std::string tail = tokens.back();
    tokens.pop_back();
    const size_t dot = tail.find('.');
    if (dot == std::string::npos) {
        tokens.push_back(tail);
    } else {
        tokens.push_back(tail.substr(0, dot));
        tokens.push_back(tail.substr(dot + 1));  // extension
    }

    // A split path yields at least "<base>", "NNNNN", "of", "MMMMM", "<ext>".
    if (tokens.size() < 4) {
        fprintf(stderr, "Model path does not contain expected pattern\n");
        exit(EXIT_FAILURE);
    }

    // Base path is everything before the trailing "-NNNNN-of-MMMMM.<ext>".
    std::string base_path;
    for (size_t i = 0; i + 4 < tokens.size(); i++) {
        if (i > 0) {
            base_path += "-";
        }
        base_path += tokens[i];
    }

    const std::string tensor_list_path = base_path + ".tensors.txt";

    printf("Loading tensor list file: %s\n", tensor_list_path.c_str());
    return { tensor_list_path, load_file_into_streambuf(tensor_list_path.c_str()) };
}
156+
157+
// Loads a model according to which environment variable is set:
//   LLAMA_EXAMPLE_MEMORY_BUFFER       - read the whole file into one in-memory buffer
//   LLAMA_EXAMPLE_MEMORY_BUFFER_SPLIT - stream split shards via futures fulfilled
//                                       from a background thread
//   LLAMA_EXAMPLE_FROM_FILE           - plain llama_model_load_from_file
// Returns nullptr when none of the variables is set; exits on load failure.
// Logs the wall-clock load duration on success.
llama_model * load_model_from_memory_configuration(const char * model_path, llama_model_params & model_params) {
    llama_model * model = nullptr;  // initialized so no path can read an indeterminate pointer
    std::chrono::steady_clock::time_point load_start_time;
    if (getenv("LLAMA_EXAMPLE_MEMORY_BUFFER")) {
        std::vector<uint8_t> buffer = load_file_into_buffer(model_path);
        fprintf(stdout, "%s: loading model from memory buffer\n", __func__);
        load_start_time = std::chrono::steady_clock::now();
        model = llama_model_load_from_buffer(std::move(buffer), model_params);
    } else if (getenv("LLAMA_EXAMPLE_MEMORY_BUFFER_SPLIT")) {
        file_entry tensor_list_file = load_tensor_list_file(model_path);
        std::vector<file_entry> files = load_files_into_streambuf(model_path);
        fprintf(stdout, "%s: loading model from %zu file streambufs\n", __func__, files.size());

        std::vector<const char *> file_paths;
        for (const auto & file : files) {
            printf("Found file %s with streambuf\n", file.path.c_str());
            file_paths.push_back(file.path.c_str());
        }

        load_start_time = std::chrono::steady_clock::now();
        const char * async_load_context = "test-model-load";
        // The fulfill thread hands each streambuf over to the loader while
        // llama_model_load_from_split_futures blocks on this thread below.
        std::thread fulfill_thread([&files, &tensor_list_file, &async_load_context]() {
            const bool tensor_list_success = llama_model_load_fulfill_split_future(
                tensor_list_file.path.c_str(), async_load_context, std::move(tensor_list_file.streambuf));
            printf("Fulfilling tensor list file %s: %s\n", tensor_list_file.path.c_str(),
                   tensor_list_success ? "success" : "failure");
            if (!tensor_list_success) {
                exit(EXIT_FAILURE);
            }

            for (auto & file : files) {
                // renamed from `success` so it does not shadow the outer result above
                const bool file_success = llama_model_load_fulfill_split_future(file.path.c_str(), async_load_context,
                                                                                std::move(file.streambuf));
                printf("Fulfilling file %s with streambuf: %s\n", file.path.c_str(), file_success ? "success" : "failure");
                if (!file_success) {
                    exit(EXIT_FAILURE);
                }
            }
        });
        fprintf(stderr, "Loading model from splits\n");
        model = llama_model_load_from_split_futures(file_paths.data(), file_paths.size(), async_load_context,
                                                    tensor_list_file.path.c_str(), model_params);
        fulfill_thread.join();
    } else if (getenv("LLAMA_EXAMPLE_FROM_FILE")) {
        load_start_time = std::chrono::steady_clock::now();
        model = llama_model_load_from_file(model_path, model_params);
    } else {
        // No memory-loading configuration requested; caller falls back to its default path.
        return nullptr;
    }

    if (model == nullptr) {  // nullptr + EXIT_FAILURE for consistency with the rest of this header
        fprintf(stderr, "%s: error: unable to load model\n", __func__);
        exit(EXIT_FAILURE);
    }
    std::chrono::steady_clock::time_point load_end_time = std::chrono::steady_clock::now();
    std::chrono::duration<double> load_duration = load_end_time - load_start_time;
    fprintf(stdout, "%s: loading model took %f seconds\n", __func__, load_duration.count());
    return model;
}
216+
217+
// True when any of the memory-loading example environment variables is set.
bool memory_configuration_env_is_set() {
    for (const char * var : { "LLAMA_EXAMPLE_MEMORY_BUFFER", "LLAMA_EXAMPLE_MEMORY_BUFFER_SPLIT",
                              "LLAMA_EXAMPLE_FROM_FILE" }) {
        if (getenv(var) != nullptr) {
            return true;
        }
    }
    return false;
}
221+
222+
// Splits `original_data` into consecutive chunks whose sizes are drawn from
// `size_dist`; the final chunk is clamped so it never reads past the end.
// Concatenating the returned blobs in order reproduces the input exactly.
// NOTE(review): if `size_dist` can produce 0 this loops forever — callers
// should use a distribution with a positive lower bound.
std::vector<std::vector<uint8_t>> split_into_random_blobs(const std::vector<uint8_t> & original_data,
                                                          std::mt19937 & gen,
                                                          std::uniform_int_distribution<> & size_dist) {
    std::vector<std::vector<uint8_t>> blobs;
    size_t current_pos = 0;

    while (current_pos < original_data.size()) {
        // Clamp the drawn size so the last blob stops at the end of the data.
        size_t blob_size = std::min(static_cast<size_t>(size_dist(gen)), original_data.size() - current_pos);

        std::vector<uint8_t> blob(original_data.begin() + current_pos, original_data.begin() + current_pos + blob_size);
        blobs.push_back(std::move(blob));
        current_pos += blob_size;
    }

    // Log via printf: this header includes <cstdio> but not <iostream>, and
    // every other function here uses printf-style logging.
    printf("Created %zu blobs\n", blobs.size());

    return blobs;
}
240+
} // namespace
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#pragma once

// Wrapper to include the specific header from src
// (keeps test code decoupled from the internal src/ directory layout).
#include "../src/uint8-buff-stream.h"
5+

0 commit comments

Comments
 (0)