|
| 1 | +#include "omni-impl.h" |
| 2 | +#include "omni.h" |
| 3 | + |
| 4 | +#include "arg.h" |
| 5 | +#include "log.h" |
| 6 | +#include "sampling.h" |
| 7 | +#include "llama.h" |
| 8 | +#include "ggml.h" |
| 9 | +#include "console.h" |
| 10 | +#include "chat.h" |
| 11 | + |
| 12 | +#include <iostream> |
| 13 | +#include <chrono> |
| 14 | +#include <vector> |
| 15 | +#include <limits.h> |
| 16 | +#include <cinttypes> |
| 17 | +#include <algorithm> |
| 18 | +#include <cstdio> |
| 19 | +#include <cstdlib> |
| 20 | +#include <cstring> |
| 21 | +#include <vector> |
| 22 | + |
| 23 | +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) |
| 24 | +#include <signal.h> |
| 25 | +#include <unistd.h> |
| 26 | +#elif defined (_WIN32) |
| 27 | +#define WIN32_LEAN_AND_MEAN |
| 28 | +#ifndef NOMINMAX |
| 29 | +#define NOMINMAX |
| 30 | +#endif |
| 31 | +#include <windows.h> |
| 32 | +#include <signal.h> |
| 33 | +#endif |
| 34 | + |
| 35 | +// volatile, because of signal being an interrupt |
| 36 | +static volatile bool g_is_generating = false; |
| 37 | +static volatile bool g_is_interrupted = false; |
| 38 | + |
| 39 | +/** |
| 40 | + * Please note that this is NOT a production-ready stuff. |
| 41 | + * It is a playground for trying multimodal support in llama.cpp. |
| 42 | + * For contributors: please keep this code simple and easy to understand. |
| 43 | + */ |
| 44 | + |
| 45 | +static void show_usage(const char * prog_name) { |
| 46 | + printf( |
| 47 | + "MiniCPM-o Omni CLI - Multimodal inference tool\n\n" |
| 48 | + "Usage: %s -m <llm_model_path> [options]\n\n" |
| 49 | + "Required:\n" |
| 50 | + " -m <path> Path to LLM GGUF model (e.g., MiniCPM-o-4_5-Q4_K_M.gguf)\n" |
| 51 | + " Other model paths will be auto-detected from directory structure:\n" |
| 52 | + " {dir}/vision/MiniCPM-o-4_5-vision-F16.gguf\n" |
| 53 | + " {dir}/audio/MiniCPM-o-4_5-audio-F16.gguf\n" |
| 54 | + " {dir}/tts/MiniCPM-o-4_5-tts-F16.gguf\n" |
| 55 | + " {dir}/tts/MiniCPM-o-4_5-projector-F16.gguf\n\n" |
| 56 | + "Options:\n" |
| 57 | + " --vision <path> Override vision model path\n" |
| 58 | + " --audio <path> Override audio model path\n" |
| 59 | + " --tts <path> Override TTS model path\n" |
| 60 | + " --projector <path> Override projector model path\n" |
| 61 | + " --ref-audio <path> Reference audio for voice cloning (default: tools/omni/assets/default_ref_audio.wav)\n" |
| 62 | + " -c, --ctx-size <n> Context size (default: 4096)\n" |
| 63 | + " -ngl <n> Number of GPU layers (default: 99)\n" |
| 64 | + " --no-tts Disable TTS output\n" |
| 65 | + " --test <prefix> <n> Run test case with audio prefix and count\n" |
| 66 | + " -h, --help Show this help message\n\n" |
| 67 | + "Example:\n" |
| 68 | + " %s -m ./models/MiniCPM-o-4_5-gguf/MiniCPM-o-4_5-Q4_K_M.gguf\n" |
| 69 | + " %s -m ./models/MiniCPM-o-4_5-gguf/MiniCPM-o-4_5-F16.gguf --no-tts\n", |
| 70 | + prog_name, prog_name, prog_name |
| 71 | + ); |
| 72 | +} |
| 73 | + |
| 74 | +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) |
| 75 | +static void sigint_handler(int signo) { |
| 76 | + if (signo == SIGINT) { |
| 77 | + if (g_is_generating) { |
| 78 | + g_is_generating = false; |
| 79 | + } else { |
| 80 | + console::cleanup(); |
| 81 | + if (g_is_interrupted) { |
| 82 | + _exit(1); |
| 83 | + } |
| 84 | + g_is_interrupted = true; |
| 85 | + } |
| 86 | + } |
| 87 | +} |
| 88 | +#endif |
| 89 | + |
| 90 | +// 从 LLM 模型路径推断其他模型路径 |
| 91 | +// 目录结构: |
| 92 | +// MiniCPM-o-4_5-gguf/ |
| 93 | +// ├── MiniCPM-o-4_5-{量化}.gguf (LLM) |
| 94 | +// ├── audio/ |
| 95 | +// │ └── MiniCPM-o-4_5-audio-F16.gguf |
| 96 | +// ├── tts/ |
| 97 | +// │ ├── MiniCPM-o-4_5-projector-F16.gguf |
| 98 | +// │ └── MiniCPM-o-4_5-tts-F16.gguf |
| 99 | +// └── vision/ |
| 100 | +// └── MiniCPM-o-4_5-vision-F16.gguf |
| 101 | +struct OmniModelPaths { |
| 102 | + std::string llm; // LLM 模型路径 |
| 103 | + std::string vision; // 视觉模型路径 |
| 104 | + std::string audio; // 音频模型路径 |
| 105 | + std::string tts; // TTS 模型路径 |
| 106 | + std::string projector; // Projector 模型路径 |
| 107 | + std::string base_dir; // 模型根目录 |
| 108 | +}; |
| 109 | + |
| 110 | +static std::string get_parent_dir(const std::string & path) { |
| 111 | + size_t last_slash = path.find_last_of("/\\"); |
| 112 | + if (last_slash != std::string::npos) { |
| 113 | + return path.substr(0, last_slash); |
| 114 | + } |
| 115 | + return "."; |
| 116 | +} |
| 117 | + |
| 118 | +static bool file_exists(const std::string & path) { |
| 119 | + FILE * f = fopen(path.c_str(), "rb"); |
| 120 | + if (f) { |
| 121 | + fclose(f); |
| 122 | + return true; |
| 123 | + } |
| 124 | + return false; |
| 125 | +} |
| 126 | + |
| 127 | +static OmniModelPaths resolve_model_paths(const std::string & llm_path) { |
| 128 | + OmniModelPaths paths; |
| 129 | + paths.llm = llm_path; |
| 130 | + paths.base_dir = get_parent_dir(llm_path); |
| 131 | + |
| 132 | + // 自动推断其他模型路径 |
| 133 | + paths.vision = paths.base_dir + "/vision/MiniCPM-o-4_5-vision-F16.gguf"; |
| 134 | + paths.audio = paths.base_dir + "/audio/MiniCPM-o-4_5-audio-F16.gguf"; |
| 135 | + paths.tts = paths.base_dir + "/tts/MiniCPM-o-4_5-tts-F16.gguf"; |
| 136 | + paths.projector = paths.base_dir + "/tts/MiniCPM-o-4_5-projector-F16.gguf"; |
| 137 | + |
| 138 | + return paths; |
| 139 | +} |
| 140 | + |
| 141 | +static void print_model_paths(const OmniModelPaths & paths) { |
| 142 | + printf("=== Model Paths ===\n"); |
| 143 | + printf(" Base dir: %s\n", paths.base_dir.c_str()); |
| 144 | + printf(" LLM: %s %s\n", paths.llm.c_str(), file_exists(paths.llm) ? "[OK]" : "[NOT FOUND]"); |
| 145 | + printf(" Vision: %s %s\n", paths.vision.c_str(), file_exists(paths.vision) ? "[OK]" : "[NOT FOUND]"); |
| 146 | + printf(" Audio: %s %s\n", paths.audio.c_str(), file_exists(paths.audio) ? "[OK]" : "[NOT FOUND]"); |
| 147 | + printf(" TTS: %s %s\n", paths.tts.c_str(), file_exists(paths.tts) ? "[OK]" : "[NOT FOUND]"); |
| 148 | + printf(" Projector: %s %s\n", paths.projector.c_str(), file_exists(paths.projector) ? "[OK]" : "[NOT FOUND]"); |
| 149 | + printf("===================\n"); |
| 150 | +} |
| 151 | + |
| 152 | +void test_case(struct omni_context *ctx_omni, common_params& params, std::string audio_path_prefix, int cnt){ |
| 153 | + for (int il = 0; il < cnt; ++il) { |
| 154 | + char idx_str[16]; |
| 155 | + snprintf(idx_str, sizeof(idx_str), "%04d", il); // 格式化为4位数字,如 0000, 0001 |
| 156 | + std::string aud_fname = audio_path_prefix + idx_str + ".wav"; |
| 157 | + |
| 158 | + auto t0 = std::chrono::high_resolution_clock::now(); |
| 159 | + stream_prefill(ctx_omni, aud_fname, "", il + 1); |
| 160 | + auto t1 = std::chrono::high_resolution_clock::now(); |
| 161 | + std::chrono::duration<double> elapsed_seconds = t1 - t0; |
| 162 | + double dt = elapsed_seconds.count(); |
| 163 | + std::cout << "prefill " << il << " : " << dt << " s"<< std::endl; |
| 164 | + } |
| 165 | + stream_decode(ctx_omni, "./"); |
| 166 | +} |
| 167 | + |
| 168 | +int main(int argc, char ** argv) { |
| 169 | + ggml_time_init(); |
| 170 | + |
| 171 | + // 命令行参数 |
| 172 | + std::string llm_path; |
| 173 | + std::string vision_path_override; |
| 174 | + std::string audio_path_override; |
| 175 | + std::string tts_path_override; |
| 176 | + std::string projector_path_override; |
| 177 | + std::string ref_audio_path = "tools/omni/assets/default_ref_audio.wav"; // 默认参考音频 |
| 178 | + int n_ctx = 4096; |
| 179 | + int n_gpu_layers = 99; // GPU 层数,默认 99 |
| 180 | + bool use_tts = true; |
| 181 | + bool run_test = false; |
| 182 | + std::string test_audio_prefix; |
| 183 | + int test_count = 0; |
| 184 | + |
| 185 | + // 解析命令行参数 |
| 186 | + for (int i = 1; i < argc; i++) { |
| 187 | + std::string arg = argv[i]; |
| 188 | + |
| 189 | + if (arg == "-h" || arg == "--help") { |
| 190 | + show_usage(argv[0]); |
| 191 | + return 0; |
| 192 | + } |
| 193 | + else if (arg == "-m" && i + 1 < argc) { |
| 194 | + llm_path = argv[++i]; |
| 195 | + } |
| 196 | + else if (arg == "--vision" && i + 1 < argc) { |
| 197 | + vision_path_override = argv[++i]; |
| 198 | + } |
| 199 | + else if (arg == "--audio" && i + 1 < argc) { |
| 200 | + audio_path_override = argv[++i]; |
| 201 | + } |
| 202 | + else if (arg == "--tts" && i + 1 < argc) { |
| 203 | + tts_path_override = argv[++i]; |
| 204 | + } |
| 205 | + else if (arg == "--projector" && i + 1 < argc) { |
| 206 | + projector_path_override = argv[++i]; |
| 207 | + } |
| 208 | + else if (arg == "--ref-audio" && i + 1 < argc) { |
| 209 | + ref_audio_path = argv[++i]; |
| 210 | + } |
| 211 | + else if ((arg == "-c" || arg == "--ctx-size") && i + 1 < argc) { |
| 212 | + n_ctx = std::atoi(argv[++i]); |
| 213 | + } |
| 214 | + else if (arg == "-ngl" && i + 1 < argc) { |
| 215 | + n_gpu_layers = std::atoi(argv[++i]); |
| 216 | + } |
| 217 | + else if (arg == "--no-tts") { |
| 218 | + use_tts = false; |
| 219 | + } |
| 220 | + else if (arg == "--test" && i + 2 < argc) { |
| 221 | + run_test = true; |
| 222 | + test_audio_prefix = argv[++i]; |
| 223 | + test_count = std::atoi(argv[++i]); |
| 224 | + } |
| 225 | + else { |
| 226 | + fprintf(stderr, "Unknown argument: %s\n", arg.c_str()); |
| 227 | + show_usage(argv[0]); |
| 228 | + return 1; |
| 229 | + } |
| 230 | + } |
| 231 | + |
| 232 | + // 检查必需参数 |
| 233 | + if (llm_path.empty()) { |
| 234 | + fprintf(stderr, "Error: -m <llm_model_path> is required\n\n"); |
| 235 | + show_usage(argv[0]); |
| 236 | + return 1; |
| 237 | + } |
| 238 | + |
| 239 | + // 解析模型路径 |
| 240 | + OmniModelPaths paths = resolve_model_paths(llm_path); |
| 241 | + |
| 242 | + // 应用覆盖路径 |
| 243 | + if (!vision_path_override.empty()) paths.vision = vision_path_override; |
| 244 | + if (!audio_path_override.empty()) paths.audio = audio_path_override; |
| 245 | + if (!tts_path_override.empty()) paths.tts = tts_path_override; |
| 246 | + if (!projector_path_override.empty()) paths.projector = projector_path_override; |
| 247 | + |
| 248 | + // 打印模型路径 |
| 249 | + print_model_paths(paths); |
| 250 | + |
| 251 | + // 检查必需文件 |
| 252 | + if (!file_exists(paths.llm)) { |
| 253 | + fprintf(stderr, "Error: LLM model not found: %s\n", paths.llm.c_str()); |
| 254 | + return 1; |
| 255 | + } |
| 256 | + if (!file_exists(paths.audio)) { |
| 257 | + fprintf(stderr, "Error: Audio model not found: %s\n", paths.audio.c_str()); |
| 258 | + return 1; |
| 259 | + } |
| 260 | + if (use_tts && !file_exists(paths.tts)) { |
| 261 | + fprintf(stderr, "Warning: TTS model not found: %s, disabling TTS\n", paths.tts.c_str()); |
| 262 | + use_tts = false; |
| 263 | + } |
| 264 | + |
| 265 | + // 设置参数 |
| 266 | + common_params params; |
| 267 | + params.model.path = paths.llm; |
| 268 | + params.vpm_model = paths.vision; |
| 269 | + params.apm_model = paths.audio; |
| 270 | + params.tts_model = paths.tts; |
| 271 | + params.n_ctx = n_ctx; |
| 272 | + params.n_gpu_layers = n_gpu_layers; |
| 273 | + |
| 274 | + // Projector 路径需要通过 tts_bin_dir 传递 |
| 275 | + // omni.cpp 中 projector 路径计算: gguf_root_dir + "/projector.gguf" |
| 276 | + // 其中 gguf_root_dir = tts_bin_dir 的父目录 |
| 277 | + // 但我们的结构是 projector 在 tts/ 目录下 |
| 278 | + // 所以需要修改 omni.cpp 或者创建符号链接 |
| 279 | + // 这里暂时使用 tts 目录作为 tts_bin_dir |
| 280 | + std::string tts_bin_dir = get_parent_dir(paths.tts); |
| 281 | + |
| 282 | + common_init(); |
| 283 | + |
| 284 | + printf("=== Initializing Omni Context ===\n"); |
| 285 | + printf(" TTS enabled: %s\n", use_tts ? "yes" : "no"); |
| 286 | + printf(" Context size: %d\n", n_ctx); |
| 287 | + printf(" GPU layers: %d\n", n_gpu_layers); |
| 288 | + printf(" TTS bin dir: %s\n", tts_bin_dir.c_str()); |
| 289 | + printf(" Ref audio: %s\n", ref_audio_path.c_str()); |
| 290 | + |
| 291 | + auto ctx_omni = omni_init(¶ms, 1, use_tts, tts_bin_dir, -1, "gpu:0"); |
| 292 | + if (ctx_omni == nullptr) { |
| 293 | + fprintf(stderr, "Error: Failed to initialize omni context\n"); |
| 294 | + return 1; |
| 295 | + } |
| 296 | + ctx_omni->async = true; |
| 297 | + ctx_omni->ref_audio_path = ref_audio_path; // 设置参考音频路径 |
| 298 | + |
| 299 | + if (run_test) { |
| 300 | + printf("=== Running test case ===\n"); |
| 301 | + printf(" Audio prefix: %s\n", test_audio_prefix.c_str()); |
| 302 | + printf(" Count: %d\n", test_count); |
| 303 | + test_case(ctx_omni, params, test_audio_prefix, test_count); |
| 304 | + } else { |
| 305 | + // 默认测试用例 |
| 306 | + test_case(ctx_omni, params, std::string("tools/omni/assets/test_case/audio_test_case/audio_test_case_"), 2); |
| 307 | + } |
| 308 | + |
| 309 | + if(ctx_omni->async && ctx_omni->use_tts){ |
| 310 | + if(ctx_omni->tts_thread.joinable()) { |
| 311 | + ctx_omni->tts_thread.join(); |
| 312 | + printf("tts end\n"); |
| 313 | + } |
| 314 | + } |
| 315 | + |
| 316 | + llama_perf_context_print(ctx_omni->ctx_llama); |
| 317 | + |
| 318 | + omni_free(ctx_omni); |
| 319 | + return 0; |
| 320 | +} |
0 commit comments