Skip to content

Commit 7a323a7

Browse files
grorge123 authored and hydai committed
[Example] MLX: add mlx example
1 parent f1496d6 commit 7a323a7

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

wasmedge-mlx/Cargo.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
[package]
2+
name = "wasmedge-mlx"
3+
version = "0.1.0"
4+
edition = "2021"
5+
6+
[dependencies]
7+
serde_json = "1.0"
8+
wasmedge-wasi-nn = "0.8.0"

wasmedge-mlx/src/main.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
use tokenizers::tokenizer::Tokenizer;
2+
use serde_json::json;
3+
use wasmedge_wasi_nn::{
4+
self, ExecutionTarget, GraphBuilder, GraphEncoding, GraphExecutionContext,
5+
TensorType,
6+
};
7+
use std::env;
8+
fn get_data_from_context(context: &GraphExecutionContext, index: usize) -> Vec<u8> {
9+
// Preserve for 4096 tokens with average token length 8
10+
const MAX_OUTPUT_BUFFER_SIZE: usize = 4096 * 8;
11+
let mut output_buffer = vec![0u8; MAX_OUTPUT_BUFFER_SIZE];
12+
let _ = context
13+
.get_output(index, &mut output_buffer)
14+
.expect("Failed to get output");
15+
16+
return output_buffer;
17+
}
18+
19+
fn get_output_from_context(context: &GraphExecutionContext) -> Vec<u8> {
20+
get_data_from_context(context, 0)
21+
}
22+
fn main() {
23+
let tokenizer_path = "tokenizer.json";
24+
let prompt = "Once upon a time, there existed a little girl,";
25+
26+
let graph = GraphBuilder::new(GraphEncoding::Mlx, ExecutionTarget::AUTO)
27+
.config(serde_json::to_string(&json!({"tokenizer":tokenizer_path})).expect("Failed to serialize options"))
28+
.build_from_cache(model_name)
29+
.expect("Failed to build graph");
30+
let mut context = graph
31+
.init_execution_context()
32+
.expect("Failed to init context");
33+
let tensor_data = prompt.as_bytes().to_vec();
34+
context
35+
.set_input(0, TensorType::U8, &[1], &tensor_data)
36+
.expect("Failed to set input");
37+
context.compute().expect("Failed to compute");
38+
let output_bytes = get_output_from_context(&context);
39+
40+
println!("{}", output.trim());
41+
42+
}

0 commit comments

Comments
 (0)