
Commit 7ee0774

grorge123 authored and hydai committed
[Example] MLX: add documentation
1 parent 7a323a7 commit 7ee0774

File tree

2 files changed (+107, -27 lines)


wasmedge-mlx/README.md

Lines changed: 81 additions & 0 deletions
# MLX example with WasmEdge WASI-NN MLX plugin

This example demonstrates how to use the WasmEdge WASI-NN MLX plugin to perform an inference task with an LLM model.

## Supported Models

| Family | Models |
|--------|--------|
| LLaMA 2 | llama_2_7b_chat_hf |
| LLaMA 3 | llama_3_8b |
| TinyLLaMA | tiny_llama_1.1B_chat_v1.0 |
## Install WasmEdge with WASI-NN MLX plugin

The MLX backend relies on [MLX](https://github.com/ml-explore/mlx), which is downloaded automatically when you build WasmEdge, so you do not need to install it yourself. If you want to use a custom MLX build, install it yourself and point the `CMAKE_PREFIX_PATH` variable at it when configuring CMake.

Build and install WasmEdge from source:

``` bash
cd <path/to/your/wasmedge/source/folder>

cmake -GNinja -Bbuild -DCMAKE_BUILD_TYPE=Release -DWASMEDGE_PLUGIN_WASI_NN_BACKEND="mlx"
cmake --build build

# For the WASI-NN plugin, you should install this project.
cmake --install build
```

After installation, you will have the `wasmedge` runtime executable under `/usr/local/bin` and the WASI-NN plug-in with the MLX backend at `/usr/local/lib/wasmedge/libwasmedgePluginWasiNN.so`.
## Download the model and tokenizer

In this example, we use `tiny_llama_1.1B_chat_v1.0`; you can change it to `llama_2_7b_chat_hf` or `llama_3_8b`.

``` bash
# Download the model weights
wget https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0/resolve/main/model.safetensors
# Download the tokenizer
wget https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0/resolve/main/tokenizer.json
```
## Build wasm

Run the following command to build the wasm. The output WASM file will be at `target/wasm32-wasi/release/`:

```bash
cargo build --target wasm32-wasi --release
```
## Execute

Execute the WASM with `wasmedge`, using `--nn-preload` to load the model:

``` bash
wasmedge --dir .:. \
  --nn-preload default:mlx:AUTO:model.safetensors \
  ./target/wasm32-wasi/release/wasmedge-mlx.wasm default
```

If your model has multiple weight files, you need to provide all of them in `--nn-preload`.

For example:

``` bash
wasmedge --dir .:. \
  --nn-preload default:mlx:AUTO:llama2-7b/model-00001-of-00002.safetensors:llama2-7b/model-00002-of-00002.safetensors \
  ./target/wasm32-wasi/release/wasmedge-mlx.wasm default
```
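The trailing `default` argument is the model name the guest program receives; it must match the alias registered with `--nn-preload`. Below is a minimal sketch of how the guest side picks it up, mirroring the `main.rs` in this example; the config values shown are the same ones used elsewhere in this README and are only illustrative:

``` rust
use serde_json::json;
use wasmedge_wasi_nn::{ExecutionTarget, GraphBuilder, GraphEncoding};

fn main() {
    // The model alias (for example "default") is the first program argument and
    // must match the alias registered with --nn-preload on the host command line.
    let args: Vec<String> = std::env::args().collect();
    let model_name: &str = &args[1];

    // Look up the preloaded model under that alias; the required metadata
    // (model_type, tokenizer) is passed through .config(), see the Other section below.
    let _graph = GraphBuilder::new(GraphEncoding::Mlx, ExecutionTarget::AUTO)
        .config(
            serde_json::to_string(&json!({
                "model_type": "tiny_llama_1.1B_chat_v1.0",
                "tokenizer": "./tokenizer.json"
            }))
            .expect("Failed to serialize options"),
        )
        .build_from_cache(model_name)
        .expect("Failed to build graph");
}
```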
## Other

There are some metadata options you can set for the MLX plugin:

- model_type (required): the LLM model type.
- tokenizer (required): the path to `tokenizer.json`.
- max_token (optional): the maximum number of tokens to generate; the default is 1024.
- enable_debug_log (optional): whether to print debug logs; the default is false.

``` rust
let graph = GraphBuilder::new(GraphEncoding::Mlx, ExecutionTarget::AUTO)
    .config(serde_json::to_string(&json!({"model_type": "tiny_llama_1.1B_chat_v1.0", "tokenizer": tokenizer_path, "max_token": 100})).expect("Failed to serialize options"))
    .build_from_cache(model_name)
    .expect("Failed to build graph");
```
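As a reference, here is a sketch, not taken from this commit, of a config that sets all four options at once; the values are placeholders chosen for illustration:

``` rust
let options = json!({
    "model_type": "tiny_llama_1.1B_chat_v1.0", // required: LLM model type
    "tokenizer": "./tokenizer.json",           // required: path to tokenizer.json
    "max_token": 100,                          // optional: defaults to 1024
    "enable_debug_log": true                   // optional: defaults to false
});
let graph = GraphBuilder::new(GraphEncoding::Mlx, ExecutionTarget::AUTO)
    .config(serde_json::to_string(&options).expect("Failed to serialize options"))
    .build_from_cache(model_name)
    .expect("Failed to build graph");
```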

wasmedge-mlx/src/main.rs

Lines changed: 26 additions & 27 deletions
```diff
@@ -1,42 +1,41 @@
-use tokenizers::tokenizer::Tokenizer;
 use serde_json::json;
+use std::env;
 use wasmedge_wasi_nn::{
-    self, ExecutionTarget, GraphBuilder, GraphEncoding, GraphExecutionContext,
-    TensorType,
+    self, ExecutionTarget, GraphBuilder, GraphEncoding, GraphExecutionContext, TensorType,
 };
-use std::env;
-fn get_data_from_context(context: &GraphExecutionContext, index: usize) -> Vec<u8> {
-    // Preserve for 4096 tokens with average token length 8
-    const MAX_OUTPUT_BUFFER_SIZE: usize = 4096 * 8;
+fn get_data_from_context(context: &GraphExecutionContext, index: usize) -> String {
+    // Preserve for 4096 tokens with average token length 6
+    const MAX_OUTPUT_BUFFER_SIZE: usize = 4096 * 6;
     let mut output_buffer = vec![0u8; MAX_OUTPUT_BUFFER_SIZE];
-    let _ = context
+    let mut output_size = context
         .get_output(index, &mut output_buffer)
         .expect("Failed to get output");
+    output_size = std::cmp::min(MAX_OUTPUT_BUFFER_SIZE, output_size);
 
-    return output_buffer;
+    return String::from_utf8_lossy(&output_buffer[..output_size]).to_string();
 }
 
-fn get_output_from_context(context: &GraphExecutionContext) -> Vec<u8> {
+fn get_output_from_context(context: &GraphExecutionContext) -> String {
     get_data_from_context(context, 0)
 }
 fn main() {
-    let tokenizer_path = "tokenizer.json";
-    let prompt = "Once upon a time, there existed a little girl,";
-
-    let graph = GraphBuilder::new(GraphEncoding::Mlx, ExecutionTarget::AUTO)
-        .config(serde_json::to_string(&json!({"tokenizer":tokenizer_path})).expect("Failed to serialize options"))
+    let tokenizer_path = "./tokenizer.json";
+    let prompt = "Once upon a time, there existed a little girl,";
+    let args: Vec<String> = env::args().collect();
+    let model_name: &str = &args[1];
+    let graph = GraphBuilder::new(GraphEncoding::Mlx, ExecutionTarget::AUTO)
+        .config(serde_json::to_string(&json!({"model_type": "tiny_llama_1.1B_chat_v1.0", "tokenizer":tokenizer_path, "max_token":100})).expect("Failed to serialize options"))
         .build_from_cache(model_name)
         .expect("Failed to build graph");
-    let mut context = graph
-        .init_execution_context()
-        .expect("Failed to init context");
-    let tensor_data = prompt.as_bytes().to_vec();
-    context
-        .set_input(0, TensorType::U8, &[1], &tensor_data)
-        .expect("Failed to set input");
-    context.compute().expect("Failed to compute");
-    let output_bytes = get_output_from_context(&context);
+    let mut context = graph
+        .init_execution_context()
+        .expect("Failed to init context");
+    let tensor_data = prompt.as_bytes().to_vec();
+    context
+        .set_input(0, TensorType::U8, &[1], &tensor_data)
+        .expect("Failed to set input");
+    context.compute().expect("Failed to compute");
+    let output = get_output_from_context(&context);
 
-    println!("{}", output.trim());
-
-}
+    println!("{}", output.trim());
+}
```
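For readability, this is the resulting `src/main.rs` after the change, with brief explanatory comments added; the comments are not part of the committed file:

``` rust
use serde_json::json;
use std::env;
use wasmedge_wasi_nn::{
    self, ExecutionTarget, GraphBuilder, GraphEncoding, GraphExecutionContext, TensorType,
};

// Read the output tensor at `index` and decode it as a UTF-8 string.
fn get_data_from_context(context: &GraphExecutionContext, index: usize) -> String {
    // Preserve for 4096 tokens with average token length 6
    const MAX_OUTPUT_BUFFER_SIZE: usize = 4096 * 6;
    let mut output_buffer = vec![0u8; MAX_OUTPUT_BUFFER_SIZE];
    let mut output_size = context
        .get_output(index, &mut output_buffer)
        .expect("Failed to get output");
    output_size = std::cmp::min(MAX_OUTPUT_BUFFER_SIZE, output_size);

    return String::from_utf8_lossy(&output_buffer[..output_size]).to_string();
}

// The generated text is exposed at output index 0.
fn get_output_from_context(context: &GraphExecutionContext) -> String {
    get_data_from_context(context, 0)
}

fn main() {
    let tokenizer_path = "./tokenizer.json";
    let prompt = "Once upon a time, there existed a little girl,";
    // The model alias (e.g. "default") comes from the command line.
    let args: Vec<String> = env::args().collect();
    let model_name: &str = &args[1];
    // Build the graph from the model preloaded with --nn-preload.
    let graph = GraphBuilder::new(GraphEncoding::Mlx, ExecutionTarget::AUTO)
        .config(serde_json::to_string(&json!({"model_type": "tiny_llama_1.1B_chat_v1.0", "tokenizer":tokenizer_path, "max_token":100})).expect("Failed to serialize options"))
        .build_from_cache(model_name)
        .expect("Failed to build graph");
    let mut context = graph
        .init_execution_context()
        .expect("Failed to init context");
    // Feed the prompt as a UTF-8 byte tensor at input index 0.
    let tensor_data = prompt.as_bytes().to_vec();
    context
        .set_input(0, TensorType::U8, &[1], &tensor_data)
        .expect("Failed to set input");
    context.compute().expect("Failed to compute");
    let output = get_output_from_context(&context);

    println!("{}", output.trim());
}
```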
