Commit c45afb5

feat: add Gemma 3 Vision-Language Model (VLM) support
- Add SigLIP Vision Encoder (27 transformer layers)
- Add Multi-Modal Projector (AvgPool + Linear projection)
- Add ImageProcessor for loading/resizing/normalizing images
- Add Gemma3VLM model class combining vision + text
- Extend CLI with --image flag for VLM generation
- Add /image command for interactive VLM usage
- Add isVLM() and generateWithImage() to Node.js API
- Auto-detect VLM models via vision_config in config.json

Supports Gemma 3 4B, 12B, 27B vision variants with:
- 896x896 image input
- 256 visual tokens per image
- Streaming output for both text and VLM generation
1 parent 740dc9d commit c45afb5
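
The new Node.js API can be exercised end to end as sketched below. This is a minimal sketch, not code from the commit: the package import name and the model id are illustrative, and any Gemma 3 vision variant detected via vision_config should report isVLM() === true.

// Minimal sketch of the new VLM API; package name and model id are illustrative.
import { loadModel } from "node-mlx"

const model = loadModel("mlx-community/gemma-3-4b-it-4bit") // hypothetical model id

if (model.isVLM()) {
  // Tokens stream directly to stdout; the return value carries only stats.
  const stats = model.generateWithImage("What's in this image?", "./photo.jpg", {
    maxTokens: 256,
    temperature: 0.7,
    topP: 0.9
  })
  console.log(`\n${stats.tokenCount} tokens @ ${stats.tokensPerSecond.toFixed(1)} tok/s`)
} else {
  // Text-only models keep the existing streaming path.
  model.generateStreaming("Describe the image pipeline in one sentence.", { maxTokens: 128 })
}

model.unload()

As with generateStreaming, generateWithImage writes the generated text straight to stdout and returns only token statistics.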

10 files changed: +1506 −13 lines changed


packages/node-mlx/native/src/binding.cc

Lines changed: 83 additions & 0 deletions
@@ -14,11 +14,15 @@ typedef bool (*IsAvailableFn)(void);
 typedef char* (*GetVersionFn)(void);
 typedef bool (*SetMetallibPathFn)(const char*);
 typedef char* (*GenerateStreamingFn)(int32_t, const char*, int32_t, float, float);
+typedef char* (*GenerateWithImageFn)(int32_t, const char*, const char*, int32_t, float, float);
+typedef bool (*IsVLMFn)(int32_t);
 
 static LoadModelFn fn_load_model = nullptr;
 static UnloadModelFn fn_unload_model = nullptr;
 static GenerateFn fn_generate = nullptr;
 static GenerateStreamingFn fn_generate_streaming = nullptr;
+static GenerateWithImageFn fn_generate_with_image = nullptr;
+static IsVLMFn fn_is_vlm = nullptr;
 static FreeStringFn fn_free_string = nullptr;
 static IsAvailableFn fn_is_available = nullptr;
 static GetVersionFn fn_get_version = nullptr;
@@ -58,6 +62,8 @@ Napi::Value Initialize(const Napi::CallbackInfo& info) {
   fn_get_version = (GetVersionFn)dlsym(dylib_handle, "node_mlx_version");
   fn_set_metallib_path = (SetMetallibPathFn)dlsym(dylib_handle, "node_mlx_set_metallib_path");
   fn_generate_streaming = (GenerateStreamingFn)dlsym(dylib_handle, "node_mlx_generate_streaming");
+  fn_generate_with_image = (GenerateWithImageFn)dlsym(dylib_handle, "node_mlx_generate_with_image");
+  fn_is_vlm = (IsVLMFn)dlsym(dylib_handle, "node_mlx_is_vlm");
 
   if (!fn_load_model || !fn_generate || !fn_free_string) {
     std::string missing;
@@ -236,6 +242,81 @@ Napi::Value GenerateStreaming(const Napi::CallbackInfo& info) {
   return Napi::String::New(env, jsonStr);
 }
 
+// Generate text with image (VLM) - tokens are written directly to stdout
+Napi::Value GenerateWithImage(const Napi::CallbackInfo& info) {
+  Napi::Env env = info.Env();
+
+  if (!fn_generate_with_image) {
+    Napi::Error::New(env, "VLM generation not available").ThrowAsJavaScriptException();
+    return env.Null();
+  }
+
+  if (info.Length() < 3 || !info[0].IsNumber() || !info[1].IsString() || !info[2].IsString()) {
+    Napi::TypeError::New(env, "Usage: generateWithImage(handle, prompt, imagePath, options?)").ThrowAsJavaScriptException();
+    return env.Null();
+  }
+
+  int32_t handle = info[0].As<Napi::Number>().Int32Value();
+  std::string prompt = info[1].As<Napi::String>().Utf8Value();
+  std::string imagePath = info[2].As<Napi::String>().Utf8Value();
+
+  // Default options
+  int32_t maxTokens = 256;
+  float temperature = 0.7f;
+  float topP = 0.9f;
+
+  // Parse options object if provided
+  if (info.Length() > 3 && info[3].IsObject()) {
+    Napi::Object options = info[3].As<Napi::Object>();
+
+    if (options.Has("maxTokens")) {
+      maxTokens = options.Get("maxTokens").As<Napi::Number>().Int32Value();
+    }
+    if (options.Has("temperature")) {
+      temperature = options.Get("temperature").As<Napi::Number>().FloatValue();
+    }
+    if (options.Has("topP")) {
+      topP = options.Get("topP").As<Napi::Number>().FloatValue();
+    }
+  }
+
+  // Flush stdout before calling streaming generate
+  fflush(stdout);
+
+  char* jsonResult = fn_generate_with_image(handle, prompt.c_str(), imagePath.c_str(), maxTokens, temperature, topP);
+
+  // Flush again after generation
+  fflush(stdout);
+
+  if (!jsonResult) {
+    Napi::Error::New(env, "Generate with image returned null").ThrowAsJavaScriptException();
+    return env.Null();
+  }
+
+  std::string jsonStr(jsonResult);
+  fn_free_string(jsonResult);
+
+  // Return the JSON string with stats
+  return Napi::String::New(env, jsonStr);
+}
+
+// Check if model is a VLM (Vision-Language Model)
+Napi::Value IsVLM(const Napi::CallbackInfo& info) {
+  Napi::Env env = info.Env();
+
+  if (!fn_is_vlm) {
+    return Napi::Boolean::New(env, false);
+  }
+
+  if (info.Length() < 1 || !info[0].IsNumber()) {
+    Napi::TypeError::New(env, "Model handle number required").ThrowAsJavaScriptException();
+    return Napi::Boolean::New(env, false);
+  }
+
+  int32_t handle = info[0].As<Napi::Number>().Int32Value();
+  return Napi::Boolean::New(env, fn_is_vlm(handle));
+}
+
 // Check if MLX is available
 Napi::Value IsAvailable(const Napi::CallbackInfo& info) {
   Napi::Env env = info.Env();
@@ -282,6 +363,8 @@ Napi::Object Init(Napi::Env env, Napi::Object exports) {
   exports.Set("unloadModel", Napi::Function::New(env, UnloadModel));
   exports.Set("generate", Napi::Function::New(env, Generate));
   exports.Set("generateStreaming", Napi::Function::New(env, GenerateStreaming));
+  exports.Set("generateWithImage", Napi::Function::New(env, GenerateWithImage));
+  exports.Set("isVLM", Napi::Function::New(env, IsVLM));
   exports.Set("isAvailable", Napi::Function::New(env, IsAvailable));
   exports.Set("getVersion", Napi::Function::New(env, GetVersion));
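
For context on the exports registered above: the addon is consumed handle-first from JavaScript, and like generateStreaming, the new generateWithImage streams tokens to stdout and returns a JSON stats string. A rough sketch of driving the addon directly follows; the .node path and the describeImage helper are hypothetical, while the call signatures mirror the NativeBinding typings in packages/node-mlx/src/index.ts.

// Sketch only: drives the N-API addon directly, bypassing the index.ts wrapper.
// The addon path is an assumption; handle comes from the existing loadModel binding.
import { createRequire } from "node:module"

const require = createRequire(import.meta.url)
const binding = require("./build/Release/node_mlx.node")

export function describeImage(handle: number, prompt: string, imagePath: string): void {
  if (!binding.isVLM(handle)) {
    throw new Error("Loaded model has no vision tower")
  }
  // Tokens stream straight to stdout; the return value is a JSON stats string,
  // e.g. {"success":true,"tokenCount":42,"tokensPerSecond":31.5}.
  const raw: string = binding.generateWithImage(handle, prompt, imagePath, {
    maxTokens: 256,
    temperature: 0.7,
    topP: 0.9
  })
  const stats = JSON.parse(raw)
  if (!stats.success) {
    throw new Error(stats.error ?? "Generation failed")
  }
  console.log(`\n${stats.tokenCount} tokens @ ${stats.tokensPerSecond} tok/s`)
}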

packages/node-mlx/src/cli.ts

Lines changed: 68 additions & 9 deletions
@@ -54,11 +54,16 @@ function printHelp() {
   log(`  mlx                      Interactive chat`)
   log(`  mlx "prompt"             One-shot generation`)
   log(`  mlx --model <name>       Use specific model`)
+  log(`  mlx --image <path>       Include image (VLM only)`)
   log(`  mlx --list               List available models`)
   log(`  mlx --help               Show this help`)
   log("")
+  log(`${colors.bold}Vision models (VLM):${colors.reset}`)
+  log(`  mlx --model gemma-3-4b --image photo.jpg "What's in this image?"`)
+  log("")
   log(`${colors.bold}Interactive commands:${colors.reset}`)
   log(`  /model <name>            Switch model`)
+  log(`  /image <path>            Set image for next prompt`)
   log(`  /temp <0-2>              Set temperature`)
   log(`  /tokens <n>              Set max tokens`)
   log(`  /clear                   Clear conversation`)
@@ -167,6 +172,7 @@ interface ChatState {
   modelName: string
   options: GenerationOptions
   history: Array<{ role: "user" | "assistant"; content: string }>
+  imagePath: string | null // For VLM image input
 }
 
 async function runInteractive(initialModel: string) {
@@ -178,7 +184,8 @@ async function runInteractive(initialModel: string) {
       temperature: 0.7,
       topP: 0.9
     },
-    history: []
+    history: [],
+    imagePath: null
   }
 
   // Load initial model
@@ -235,8 +242,16 @@ async function runInteractive(initialModel: string) {
     process.stdout.write(`${colors.magenta}AI:${colors.reset} `)
 
     try {
-      // Use streaming - tokens are written directly to stdout
-      const result = state.model.generateStreaming(fullPrompt, state.options)
+      let result
+
+      // Check if we have an image to send
+      if (state.imagePath && state.model.isVLM()) {
+        result = state.model.generateWithImage(fullPrompt, state.imagePath, state.options)
+        state.imagePath = null // Clear after use
+      } else {
+        // Use streaming - tokens are written directly to stdout
+        result = state.model.generateStreaming(fullPrompt, state.options)
+      }
 
       // Note: text already streamed, we only have stats
       log("")
@@ -374,21 +389,61 @@ async function handleCommand(input: string, state: ChatState, rl: readline.Interface) {
       printModels()
      break
 
+    case "image":
+    case "i":
+      if (!arg) {
+        if (state.imagePath) {
+          log(`${colors.dim}Current image: ${state.imagePath}${colors.reset}`)
+        } else {
+          log(`${colors.dim}No image set. Use /image <path> to set one.${colors.reset}`)
+        }
+      } else {
+        // Check if file exists
+        const fs = await import("node:fs")
+        if (!fs.existsSync(arg)) {
+          error(`Image not found: ${arg}`)
+        } else if (!state.model?.isVLM()) {
+          error(`Current model doesn't support images. Use a VLM like gemma-3-4b.`)
+        } else {
+          state.imagePath = arg
+          log(`${colors.green}✓${colors.reset} Image set: ${arg}`)
+          log(`${colors.dim}The next prompt will include this image.${colors.reset}`)
+        }
+      }
+      break
+
     default:
       error(`Unknown command: /${cmd}. Type /help for commands.`)
   }
 }
 
-async function runOneShot(modelName: string, prompt: string, options: GenerationOptions) {
+async function runOneShot(
+  modelName: string,
+  prompt: string,
+  imagePath: string | null,
+  options: GenerationOptions
+) {
   log(`${colors.dim}Loading ${modelName}...${colors.reset}`)
 
   const modelId = resolveModel(modelName)
 
   try {
     const model = loadModel(modelId)
 
-    // Use streaming - tokens are written directly to stdout
-    const result = model.generateStreaming(prompt, options)
+    let result
+
+    // Check if we have an image to process
+    if (imagePath) {
+      if (!model.isVLM()) {
+        error(`Model ${modelName} doesn't support images. Use a VLM like gemma-3-4b.`)
+        model.unload()
+        process.exit(1)
+      }
+      result = model.generateWithImage(prompt, imagePath, options)
+    } else {
+      // Use streaming - tokens are written directly to stdout
+      result = model.generateStreaming(prompt, options)
    }
 
     // Add newline after streamed output
     log("")
@@ -407,12 +462,14 @@ async function runOneShot(modelName: string, prompt: string, options: GenerationOptions) {
 function parseArgs(): {
   model: string
   prompt: string | null
+  imagePath: string | null
   options: GenerationOptions
   command: "chat" | "oneshot" | "list" | "help" | "version"
 } {
   const args = process.argv.slice(2)
   let model = "qwen" // Default to Qwen (no auth required)
   let prompt: string | null = null
+  let imagePath: string | null = null
   const options: GenerationOptions = {
     maxTokens: 512,
     temperature: 0.7,
@@ -431,6 +488,8 @@ function parseArgs(): {
       command = "list"
     } else if (arg === "--model" || arg === "-m") {
      model = args[++i] || model
+    } else if (arg === "--image" || arg === "-i") {
+      imagePath = args[++i] || null
     } else if (arg === "--temp" || arg === "-t") {
       options.temperature = parseFloat(args[++i] || "0.7")
     } else if (arg === "--tokens" || arg === "-n") {
@@ -446,12 +505,12 @@
     }
   }
 
-  return { model, prompt, options, command }
+  return { model, prompt, imagePath, options, command }
 }
 
 // Main
 async function main() {
-  const { model, prompt, options, command } = parseArgs()
+  const { model, prompt, imagePath, options, command } = parseArgs()
 
   // Commands that don't need Apple Silicon
   switch (command) {
@@ -486,7 +545,7 @@ async function main() {
 
   switch (command) {
     case "oneshot":
-      await runOneShot(model, prompt!, options)
+      await runOneShot(model, prompt!, imagePath, options)
      break
 
     case "chat":

packages/node-mlx/src/index.ts

Lines changed: 41 additions & 0 deletions
@@ -24,6 +24,13 @@ interface NativeBinding {
     prompt: string,
     options?: { maxTokens?: number; temperature?: number; topP?: number }
   ): string // Streams to stdout, returns JSON stats
+  generateWithImage(
+    handle: number,
+    prompt: string,
+    imagePath: string,
+    options?: { maxTokens?: number; temperature?: number; topP?: number }
+  ): string // VLM: Streams to stdout, returns JSON stats
+  isVLM(handle: number): boolean
   isAvailable(): boolean
   getVersion(): string
 }
@@ -155,6 +162,12 @@ export interface Model {
   /** Generate text with streaming - tokens are written directly to stdout */
   generateStreaming(prompt: string, options?: GenerationOptions): StreamingResult
 
+  /** Generate text from a prompt with an image (VLM only) */
+  generateWithImage(prompt: string, imagePath: string, options?: GenerationOptions): StreamingResult
+
+  /** Check if this model supports images (is a Vision-Language Model) */
+  isVLM(): boolean
+
   /** Unload the model from memory */
   unload(): void
 
@@ -302,6 +315,34 @@ export function loadModel(modelId: string): Model {
       }
     },
 
+    generateWithImage(
+      prompt: string,
+      imagePath: string,
+      options?: GenerationOptions
+    ): StreamingResult {
+      // VLM generation with image - tokens are written directly to stdout by Swift
+      const jsonStr = b.generateWithImage(handle, prompt, imagePath, {
+        maxTokens: options?.maxTokens ?? 256,
+        temperature: options?.temperature ?? 0.7,
+        topP: options?.topP ?? 0.9
+      })
+
+      const result = JSON.parse(jsonStr) as JSONGenerationResult
+
+      if (!result.success) {
+        throw new Error(result.error ?? "Generation failed")
+      }
+
+      return {
+        tokenCount: result.tokenCount ?? 0,
+        tokensPerSecond: result.tokensPerSecond ?? 0
+      }
+    },
+
+    isVLM(): boolean {
+      return b.isVLM(handle)
+    },
+
     unload(): void {
       b.unloadModel(handle)
     }
