
Commit 1426b10

feat: add processor example (#194)

* feat: add processor example
* refactor: Erase redundant code

1 parent c5ddda3

3 files changed: +76 -35 lines

wasmedge-mlx/vlm/Cargo.toml

Lines changed: 2 additions & 1 deletion
@@ -1,8 +1,9 @@
 [package]
 name = "wasmedge-vlm"
 version = "0.1.0"
-edition = "2024"
+edition = "2021"
 
 [dependencies]
 serde_json = "1.0"
 wasmedge-wasi-nn = "0.8.0"
+rust_processor = { git = "https://github.com/second-state/wasi_processor", subdirectory = "processor", branch = "main" }
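
The new `rust_processor` dependency replaces the Python `encode.py`/`decode.py` helpers: it tokenizes the prompt, preprocesses the image, and detokenizes the output inside the WASM module itself. Below is a minimal, illustrative sketch of loading it; the type paths and the `from_pretrained` entry point are taken from the `src/main.rs` changes in this commit, the model directory name from the README, and everything else is an assumption.

```rust
use rust_processor::auto::processing_auto::{AutoProcessor, AutoProcessorType};

fn main() {
    // Load the Hugging Face-style processor files shipped in the model directory.
    // AutoProcessor::from_pretrained and the Gemma3 variant are used exactly like
    // this in src/main.rs below; the rest of this sketch is illustrative only.
    match AutoProcessor::from_pretrained("gemma-3-4b-it-4bit") {
        Ok(AutoProcessorType::Gemma3(_processor)) => println!("loaded a Gemma3 processor"),
        Ok(_) => eprintln!("model directory does not contain a Gemma3 processor"),
        Err(e) => eprintln!("failed to load processor: {}", e),
    }
}
```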

wasmedge-mlx/vlm/README.md

Lines changed: 10 additions & 22 deletions
@@ -26,29 +26,22 @@ cmake --install build
 
 Then you will have an executable `wasmedge` runtime under `/usr/local/bin` and the WASI-NN with MLX backend plug-in under `/usr/local/lib/wasmedge/libwasmedgePluginWasiNN.so` after installation.
 
-## Install dependencies
-
-Currently, we use the Python transformer library to embed the prompt and image to input the token. You can use any other library instead of this step.
-
-``` bash
-sudo apt install python3 python3-pip
-pip install transformers pillow mlx
-```
-
 ## Download the model and tokenizer
 
-In this example, we will use `gemma-3-4b-pt-bf16`.
+In this example, we will use `gemma-3-4b-it-4bit`.
 
 ``` bash
-git clone https://huggingface.co/mlx-community/gemma-3-4b-pt-bf16
+git clone https://huggingface.co/mlx-community/gemma-3-4b-it-4bit
 ```
 
 ## Build wasm
 
-Run the following command to build wasm, the output WASM file will be at `target/wasm32-wasip1/release/`
+Run the following command to build the WASM; the output file will be at `target/wasm32-wasip1/release/`.
+Then AOT-compile the WASM to improve performance.
 
 ```bash
 cargo build --target wasm32-wasip1 --release
+wasmedge compile ./target/wasm32-wasip1/release/wasmedge-vlm.wasm wasmedge-vlm_aot.wasm
 ```
 ## Execute
 
@@ -58,25 +51,20 @@ Execute the WASM with the `wasmedge` using nn-preload to load model.
 # Download sample image
 wget https://github.com/WasmEdge/WasmEdge/raw/master/docs/wasmedge-runtime-logo.png
 
-# python encode.py <model_path> <image_path> <prompt>
-python encode.py gemma-3-4b-it-bf16 wasmedge-runtime-logo.png "What is this icon?"
 
 wasmedge --dir .:. \
-  --nn-preload default:mlx:AUTO:model.safetensors \
-  ./target/wasm32-wasip1/release/wasmedge-vlm.wasm default
-
-# python encode.py <model_path> <Output mlx array path>
-python decode.py gemma-3-4b-it-bf16 Answer.npy
+  --nn-preload default:mlx:AUTO:gemma-3-4b-it-4bit/model.safetensors \
+  ./wasmedge-vlm_aot.wasm default gemma-3-4b-it-4bit
 
 ```
 
 If your model has multiple weight files, you need to provide all of them in the nn-preload.
 
 For example:
 ``` bash
-wasmedge --dir .:. \
-  --nn-preload default:mlx:AUTO:gemma-3-4b-it-bf16/model-00001-of-00002.safetensors:gemma-3-4b-it-bf16/model-00002-of-00002.safetensors \
-  ./target/wasm32-wasip1/release/wasmedge-vlm.wasm default
+wasmedge --dir .:. \
+  --nn-preload default:mlx:AUTO:gemma-3-4b-it-4bit/model-00001-of-00002.safetensors:gemma-3-4b-it-4bit/model-00002-of-00002.safetensors \
+  ./target/wasm32-wasip1/release/wasmedge-vlm.wasm default
 ```
 
 ## Other
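
Note how the new run command lines up with the argument handling in the `src/main.rs` changes that follow: the module now takes two positional arguments, the nn-preload graph alias and the model directory. A small stand-alone sketch of that mapping (the comment shows the README command; the variable names mirror the code below, the rest is illustrative):

```rust
use std::env;

fn main() {
    // wasmedge --dir .:. --nn-preload default:mlx:AUTO:gemma-3-4b-it-4bit/model.safetensors \
    //     ./wasmedge-vlm_aot.wasm default gemma-3-4b-it-4bit
    //                             ^^^^^^^ ^^^^^^^^^^^^^^^^^^
    //                             args[1]      args[2]
    let args: Vec<String> = env::args().collect();
    let model_name = &args[1]; // graph alias registered via --nn-preload ("default")
    let model_dir = &args[2];  // directory holding config.json and the processor files
    println!("graph alias: {}, model dir: {}", model_name, model_dir);
}
```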

wasmedge-mlx/vlm/src/main.rs

Lines changed: 64 additions & 12 deletions
@@ -1,10 +1,17 @@
-use serde_json::json;
+use rust_processor::auto::processing_auto::AutoProcessor;
+use rust_processor::gemma3::detokenizer::decode;
+use rust_processor::processor_utils::prepare_inputs;
+use rust_processor::NDTensorI32;
 use std::env;
 use wasmedge_wasi_nn::{
     self, ExecutionTarget, GraphBuilder, GraphEncoding, GraphExecutionContext, TensorType,
 };
 
-fn get_data_from_context(context: &GraphExecutionContext, index: usize) -> String {
+use serde_json::Value;
+use std::fs::File;
+use std::io::{self, BufReader};
+
+fn get_data_from_context(context: &GraphExecutionContext, index: usize) -> NDTensorI32 {
     // Preserve for 4096 tokens with average token length 6
     const MAX_OUTPUT_BUFFER_SIZE: usize = 4096 * 6;
     let mut output_buffer = vec![0u8; MAX_OUTPUT_BUFFER_SIZE];
@@ -13,44 +20,89 @@ fn get_data_from_context(context: &GraphExecutionContext, index: usize) -> Strin
         .expect("Failed to get output");
     output_size = std::cmp::min(MAX_OUTPUT_BUFFER_SIZE, output_size);
 
-    return String::from_utf8_lossy(&output_buffer[..output_size]).to_string();
+    return NDTensorI32::from_bytes(&output_buffer[..output_size]).unwrap();
 }
 
-fn get_output_from_context(context: &GraphExecutionContext) -> String {
+fn get_output_from_context(context: &GraphExecutionContext) -> NDTensorI32 {
     get_data_from_context(context, 0)
 }
 
+fn read_json(path: &str) -> io::Result<Value> {
+    let file = File::open(path)?;
+    let reader = BufReader::new(file);
+    let v = serde_json::from_reader(reader)
+        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
+    Ok(v)
+}
+
 fn main() {
     // prompt: "What is this icon?";
     // image: "wasmedge-runtime-logo.png";
     let args: Vec<String> = env::args().collect();
     let model_name: &str = &args[1];
+    let model_dir = &args[2];
+    let config = read_json(&format!("{}/config.json", model_dir)).unwrap();
+    let prompt = "<bos><start_of_turn>user\
+        What is this icon?<start_of_image><end_of_turn>\
+        <start_of_turn>model";
+    let image_path = "wasmedge-runtime-logo.png";
+    println!("create processor: {}", model_dir);
+    let mut processor = match AutoProcessor::from_pretrained(model_dir) {
+        Ok(processor) => match processor {
+            rust_processor::auto::processing_auto::AutoProcessorType::Gemma3(processor) => {
+                processor
+            }
+            _ => {
+                eprintln!("Error loading processor: not a Gemma3Processor");
+                return;
+            }
+        },
+        Err(e) => {
+            eprintln!("Error loading processor: {}", e);
+            return;
+        }
+    };
+    println!("processor created");
+    let image_token_index = config["image_token_index"].as_u64().unwrap_or(262144) as u32;
+    let model_inputs = prepare_inputs(
+        &mut processor,
+        &[image_path], // Use single image array
+        prompt,
+        image_token_index,
+        Some((896, 896)), // Use 896x896 as image size
+    );
     let graph = GraphBuilder::new(GraphEncoding::Mlx, ExecutionTarget::AUTO)
-        .config(
-            serde_json::to_string(&json!({"model_type": "gemma3", "max_token":250}))
-                .expect("Failed to serialize options"),
-        )
+        .config(config.to_string())
        .build_from_cache(model_name)
         .expect("Failed to build graph");
 
     let mut context = graph
         .init_execution_context()
         .expect("Failed to init context");
 
-    let tensor_data = "input_ids.npy".as_bytes().to_vec();
+    let tensor_data = model_inputs["input_ids"].to_bytes();
     context
         .set_input(0, TensorType::U8, &[1], &tensor_data)
         .expect("Failed to set input");
-    let tensor_data = "pixel_values.npy".as_bytes().to_vec();
+    let tensor_data = model_inputs["pixel_values"].to_bytes();
     context
         .set_input(1, TensorType::U8, &[1], &tensor_data)
         .expect("Failed to set input");
-    let tensor_data = "mask.npy".as_bytes().to_vec();
+    let tensor_data = model_inputs["mask"].to_bytes();
     context
         .set_input(2, TensorType::U8, &[1], &tensor_data)
         .expect("Failed to set input");
 
     context.compute().expect("Failed to compute");
-    let output = get_output_from_context(&context);
+    let tokens = get_output_from_context(&context);
+    let output = decode(
+        &tokens
+            .data
+            .into_iter()
+            .map(|x| x as usize)
+            .collect::<Vec<_>>(),
+        &processor,
+        true,
    );
     println!("{}", output.trim());
 }
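
One notable change above is how the graph is configured: instead of serializing a small hand-written `{"model_type": "gemma3", "max_token": 250}` object, the example now forwards the model's own `config.json` to the MLX backend and also reads `image_token_index` from it on the host side. A minimal sketch of that pattern in isolation follows; the file path and the 262144 fallback come from the diff, while which config keys the backend actually consumes is not restated here.

```rust
use serde_json::Value;
use std::fs::File;
use std::io::BufReader;

fn main() {
    // Load the model's config.json and keep it as an untyped JSON value.
    let file = File::open("gemma-3-4b-it-4bit/config.json").expect("missing config.json");
    let config: Value = serde_json::from_reader(BufReader::new(file)).expect("invalid JSON");

    // Individual fields can still be read on the host side; 262144 is the
    // fallback image token id used in the example above.
    let image_token_index = config["image_token_index"].as_u64().unwrap_or(262144) as u32;
    println!("image_token_index = {}", image_token_index);

    // The whole document is passed to the backend as a string, matching
    // `.config(config.to_string())` in the code above.
    println!("config passed to the backend:\n{}", config);
}
```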
