Skip to content

Commit 8cfc581

Browse files
committed
add cli
Signed-off-by: caitianchi <caitianchi@modelbest.cn>
1 parent 60934fe commit 8cfc581

File tree

5 files changed

+379
-19
lines changed

5 files changed

+379
-19
lines changed
Binary file not shown.
Binary file not shown.

tools/omni/omni-cli.cpp

Lines changed: 320 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,320 @@
1+
#include "omni-impl.h"
2+
#include "omni.h"
3+
4+
#include "arg.h"
5+
#include "log.h"
6+
#include "sampling.h"
7+
#include "llama.h"
8+
#include "ggml.h"
9+
#include "console.h"
10+
#include "chat.h"
11+
12+
#include <iostream>
13+
#include <chrono>
14+
#include <vector>
15+
#include <limits.h>
16+
#include <cinttypes>
17+
#include <algorithm>
18+
#include <cstdio>
19+
#include <cstdlib>
20+
#include <cstring>
21+
#include <vector>
22+
23+
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
24+
#include <signal.h>
25+
#include <unistd.h>
26+
#elif defined (_WIN32)
27+
#define WIN32_LEAN_AND_MEAN
28+
#ifndef NOMINMAX
29+
#define NOMINMAX
30+
#endif
31+
#include <windows.h>
32+
#include <signal.h>
33+
#endif
34+
35+
// volatile, because of signal being an interrupt
36+
static volatile bool g_is_generating = false;
37+
static volatile bool g_is_interrupted = false;
38+
39+
/**
40+
* Please note that this is NOT a production-ready stuff.
41+
* It is a playground for trying multimodal support in llama.cpp.
42+
* For contributors: please keep this code simple and easy to understand.
43+
*/
44+
45+
// Print the CLI help text; the program name is substituted into the usage
// line and the example invocations.
static void show_usage(const char * prog_name) {
    printf(
        "MiniCPM-o Omni CLI - Multimodal inference tool\n\n"
        "Usage: %s -m <llm_model_path> [options]\n\n",
        prog_name
    );
    fputs(
        "Required:\n"
        " -m <path> Path to LLM GGUF model (e.g., MiniCPM-o-4_5-Q4_K_M.gguf)\n"
        " Other model paths will be auto-detected from directory structure:\n"
        " {dir}/vision/MiniCPM-o-4_5-vision-F16.gguf\n"
        " {dir}/audio/MiniCPM-o-4_5-audio-F16.gguf\n"
        " {dir}/tts/MiniCPM-o-4_5-tts-F16.gguf\n"
        " {dir}/tts/MiniCPM-o-4_5-projector-F16.gguf\n\n"
        "Options:\n"
        " --vision <path> Override vision model path\n"
        " --audio <path> Override audio model path\n"
        " --tts <path> Override TTS model path\n"
        " --projector <path> Override projector model path\n"
        " --ref-audio <path> Reference audio for voice cloning (default: tools/omni/assets/default_ref_audio.wav)\n"
        " -c, --ctx-size <n> Context size (default: 4096)\n"
        " -ngl <n> Number of GPU layers (default: 99)\n"
        " --no-tts Disable TTS output\n"
        " --test <prefix> <n> Run test case with audio prefix and count\n"
        " -h, --help Show this help message\n\n",
        stdout
    );
    printf(
        "Example:\n"
        " %s -m ./models/MiniCPM-o-4_5-gguf/MiniCPM-o-4_5-Q4_K_M.gguf\n"
        " %s -m ./models/MiniCPM-o-4_5-gguf/MiniCPM-o-4_5-F16.gguf --no-tts\n",
        prog_name, prog_name
    );
}
73+
74+
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
75+
static void sigint_handler(int signo) {
76+
if (signo == SIGINT) {
77+
if (g_is_generating) {
78+
g_is_generating = false;
79+
} else {
80+
console::cleanup();
81+
if (g_is_interrupted) {
82+
_exit(1);
83+
}
84+
g_is_interrupted = true;
85+
}
86+
}
87+
}
88+
#endif
89+
90+
// Holds the set of model file paths derived from the LLM model path.
// Expected directory layout:
// MiniCPM-o-4_5-gguf/
// ├── MiniCPM-o-4_5-{quant}.gguf   (LLM)
// ├── audio/
// │   └── MiniCPM-o-4_5-audio-F16.gguf
// ├── tts/
// │   ├── MiniCPM-o-4_5-projector-F16.gguf
// │   └── MiniCPM-o-4_5-tts-F16.gguf
// └── vision/
//     └── MiniCPM-o-4_5-vision-F16.gguf
struct OmniModelPaths {
    std::string llm;        // path to the LLM model
    std::string vision;     // path to the vision model
    std::string audio;      // path to the audio model
    std::string tts;        // path to the TTS model
    std::string projector;  // path to the projector model
    std::string base_dir;   // root directory the other paths are derived from
};
109+
110+
// Return everything before the final path separator ('/' or '\\');
// "." when the path contains no separator at all.
static std::string get_parent_dir(const std::string & path) {
    const size_t sep = path.find_last_of("/\\");
    return sep == std::string::npos ? "." : path.substr(0, sep);
}
117+
118+
// Probe for file existence by opening it for binary read; this also treats
// unreadable files as missing.
static bool file_exists(const std::string & path) {
    FILE * probe = fopen(path.c_str(), "rb");
    const bool exists = (probe != nullptr);
    if (probe) {
        fclose(probe);
    }
    return exists;
}
126+
127+
// Derive the companion model locations from the conventional directory
// layout rooted at the LLM model's directory.
static OmniModelPaths resolve_model_paths(const std::string & llm_path) {
    OmniModelPaths paths;
    paths.llm      = llm_path;
    paths.base_dir = get_parent_dir(llm_path);

    const std::string & root = paths.base_dir;
    paths.vision    = root + "/vision/MiniCPM-o-4_5-vision-F16.gguf";
    paths.audio     = root + "/audio/MiniCPM-o-4_5-audio-F16.gguf";
    paths.tts       = root + "/tts/MiniCPM-o-4_5-tts-F16.gguf";
    paths.projector = root + "/tts/MiniCPM-o-4_5-projector-F16.gguf";

    return paths;
}
140+
141+
// Report each resolved model path together with an on-disk existence check.
static void print_model_paths(const OmniModelPaths & paths) {
    const auto status = [](const std::string & p) {
        return file_exists(p) ? "[OK]" : "[NOT FOUND]";
    };
    printf("=== Model Paths ===\n");
    printf(" Base dir: %s\n", paths.base_dir.c_str());
    printf(" LLM: %s %s\n", paths.llm.c_str(), status(paths.llm));
    printf(" Vision: %s %s\n", paths.vision.c_str(), status(paths.vision));
    printf(" Audio: %s %s\n", paths.audio.c_str(), status(paths.audio));
    printf(" TTS: %s %s\n", paths.tts.c_str(), status(paths.tts));
    printf(" Projector: %s %s\n", paths.projector.c_str(), status(paths.projector));
    printf("===================\n");
}
151+
152+
void test_case(struct omni_context *ctx_omni, common_params& params, std::string audio_path_prefix, int cnt){
153+
for (int il = 0; il < cnt; ++il) {
154+
char idx_str[16];
155+
snprintf(idx_str, sizeof(idx_str), "%04d", il); // 格式化为4位数字,如 0000, 0001
156+
std::string aud_fname = audio_path_prefix + idx_str + ".wav";
157+
158+
auto t0 = std::chrono::high_resolution_clock::now();
159+
stream_prefill(ctx_omni, aud_fname, "", il + 1);
160+
auto t1 = std::chrono::high_resolution_clock::now();
161+
std::chrono::duration<double> elapsed_seconds = t1 - t0;
162+
double dt = elapsed_seconds.count();
163+
std::cout << "prefill " << il << " : " << dt << " s"<< std::endl;
164+
}
165+
stream_decode(ctx_omni, "./");
166+
}
167+
168+
int main(int argc, char ** argv) {
169+
ggml_time_init();
170+
171+
// 命令行参数
172+
std::string llm_path;
173+
std::string vision_path_override;
174+
std::string audio_path_override;
175+
std::string tts_path_override;
176+
std::string projector_path_override;
177+
std::string ref_audio_path = "tools/omni/assets/default_ref_audio.wav"; // 默认参考音频
178+
int n_ctx = 4096;
179+
int n_gpu_layers = 99; // GPU 层数,默认 99
180+
bool use_tts = true;
181+
bool run_test = false;
182+
std::string test_audio_prefix;
183+
int test_count = 0;
184+
185+
// 解析命令行参数
186+
for (int i = 1; i < argc; i++) {
187+
std::string arg = argv[i];
188+
189+
if (arg == "-h" || arg == "--help") {
190+
show_usage(argv[0]);
191+
return 0;
192+
}
193+
else if (arg == "-m" && i + 1 < argc) {
194+
llm_path = argv[++i];
195+
}
196+
else if (arg == "--vision" && i + 1 < argc) {
197+
vision_path_override = argv[++i];
198+
}
199+
else if (arg == "--audio" && i + 1 < argc) {
200+
audio_path_override = argv[++i];
201+
}
202+
else if (arg == "--tts" && i + 1 < argc) {
203+
tts_path_override = argv[++i];
204+
}
205+
else if (arg == "--projector" && i + 1 < argc) {
206+
projector_path_override = argv[++i];
207+
}
208+
else if (arg == "--ref-audio" && i + 1 < argc) {
209+
ref_audio_path = argv[++i];
210+
}
211+
else if ((arg == "-c" || arg == "--ctx-size") && i + 1 < argc) {
212+
n_ctx = std::atoi(argv[++i]);
213+
}
214+
else if (arg == "-ngl" && i + 1 < argc) {
215+
n_gpu_layers = std::atoi(argv[++i]);
216+
}
217+
else if (arg == "--no-tts") {
218+
use_tts = false;
219+
}
220+
else if (arg == "--test" && i + 2 < argc) {
221+
run_test = true;
222+
test_audio_prefix = argv[++i];
223+
test_count = std::atoi(argv[++i]);
224+
}
225+
else {
226+
fprintf(stderr, "Unknown argument: %s\n", arg.c_str());
227+
show_usage(argv[0]);
228+
return 1;
229+
}
230+
}
231+
232+
// 检查必需参数
233+
if (llm_path.empty()) {
234+
fprintf(stderr, "Error: -m <llm_model_path> is required\n\n");
235+
show_usage(argv[0]);
236+
return 1;
237+
}
238+
239+
// 解析模型路径
240+
OmniModelPaths paths = resolve_model_paths(llm_path);
241+
242+
// 应用覆盖路径
243+
if (!vision_path_override.empty()) paths.vision = vision_path_override;
244+
if (!audio_path_override.empty()) paths.audio = audio_path_override;
245+
if (!tts_path_override.empty()) paths.tts = tts_path_override;
246+
if (!projector_path_override.empty()) paths.projector = projector_path_override;
247+
248+
// 打印模型路径
249+
print_model_paths(paths);
250+
251+
// 检查必需文件
252+
if (!file_exists(paths.llm)) {
253+
fprintf(stderr, "Error: LLM model not found: %s\n", paths.llm.c_str());
254+
return 1;
255+
}
256+
if (!file_exists(paths.audio)) {
257+
fprintf(stderr, "Error: Audio model not found: %s\n", paths.audio.c_str());
258+
return 1;
259+
}
260+
if (use_tts && !file_exists(paths.tts)) {
261+
fprintf(stderr, "Warning: TTS model not found: %s, disabling TTS\n", paths.tts.c_str());
262+
use_tts = false;
263+
}
264+
265+
// 设置参数
266+
common_params params;
267+
params.model.path = paths.llm;
268+
params.vpm_model = paths.vision;
269+
params.apm_model = paths.audio;
270+
params.tts_model = paths.tts;
271+
params.n_ctx = n_ctx;
272+
params.n_gpu_layers = n_gpu_layers;
273+
274+
// Projector 路径需要通过 tts_bin_dir 传递
275+
// omni.cpp 中 projector 路径计算: gguf_root_dir + "/projector.gguf"
276+
// 其中 gguf_root_dir = tts_bin_dir 的父目录
277+
// 但我们的结构是 projector 在 tts/ 目录下
278+
// 所以需要修改 omni.cpp 或者创建符号链接
279+
// 这里暂时使用 tts 目录作为 tts_bin_dir
280+
std::string tts_bin_dir = get_parent_dir(paths.tts);
281+
282+
common_init();
283+
284+
printf("=== Initializing Omni Context ===\n");
285+
printf(" TTS enabled: %s\n", use_tts ? "yes" : "no");
286+
printf(" Context size: %d\n", n_ctx);
287+
printf(" GPU layers: %d\n", n_gpu_layers);
288+
printf(" TTS bin dir: %s\n", tts_bin_dir.c_str());
289+
printf(" Ref audio: %s\n", ref_audio_path.c_str());
290+
291+
auto ctx_omni = omni_init(&params, 1, use_tts, tts_bin_dir, -1, "gpu:0");
292+
if (ctx_omni == nullptr) {
293+
fprintf(stderr, "Error: Failed to initialize omni context\n");
294+
return 1;
295+
}
296+
ctx_omni->async = true;
297+
ctx_omni->ref_audio_path = ref_audio_path; // 设置参考音频路径
298+
299+
if (run_test) {
300+
printf("=== Running test case ===\n");
301+
printf(" Audio prefix: %s\n", test_audio_prefix.c_str());
302+
printf(" Count: %d\n", test_count);
303+
test_case(ctx_omni, params, test_audio_prefix, test_count);
304+
} else {
305+
// 默认测试用例
306+
test_case(ctx_omni, params, std::string("tools/omni/assets/test_case/audio_test_case/audio_test_case_"), 2);
307+
}
308+
309+
if(ctx_omni->async && ctx_omni->use_tts){
310+
if(ctx_omni->tts_thread.joinable()) {
311+
ctx_omni->tts_thread.join();
312+
printf("tts end\n");
313+
}
314+
}
315+
316+
llama_perf_context_print(ctx_omni->ctx_llama);
317+
318+
omni_free(ctx_omni);
319+
return 0;
320+
}

0 commit comments

Comments
 (0)