use serde_json::json;
use std::env;
use tokenizers::tokenizer::Tokenizer;
use wasmedge_wasi_nn::{
    self, ExecutionTarget, GraphBuilder, GraphEncoding, GraphExecutionContext, TensorType,
};
8+ fn get_data_from_context ( context : & GraphExecutionContext , index : usize ) -> Vec < u8 > {
9+ // Preserve for 4096 tokens with average token length 8
10+ const MAX_OUTPUT_BUFFER_SIZE : usize = 4096 * 8 ;
11+ let mut output_buffer = vec ! [ 0u8 ; MAX_OUTPUT_BUFFER_SIZE ] ;
12+ let _ = context
13+ . get_output ( index, & mut output_buffer)
14+ . expect ( "Failed to get output" ) ;
15+
16+ return output_buffer;
17+ }
18+
19+ fn get_output_from_context ( context : & GraphExecutionContext ) -> Vec < u8 > {
20+ get_data_from_context ( context, 0 )
21+ }
22+ fn main ( ) {
23+ let tokenizer_path = "tokenizer.json" ;
24+ let prompt = "Once upon a time, there existed a little girl," ;
25+
26+ let graph = GraphBuilder :: new ( GraphEncoding :: Mlx , ExecutionTarget :: AUTO )
27+ . config ( serde_json:: to_string ( & json ! ( { "tokenizer" : tokenizer_path} ) ) . expect ( "Failed to serialize options" ) )
28+ . build_from_cache ( model_name)
29+ . expect ( "Failed to build graph" ) ;
30+ let mut context = graph
31+ . init_execution_context ( )
32+ . expect ( "Failed to init context" ) ;
33+ let tensor_data = prompt. as_bytes ( ) . to_vec ( ) ;
34+ context
35+ . set_input ( 0 , TensorType :: U8 , & [ 1 ] , & tensor_data)
36+ . expect ( "Failed to set input" ) ;
37+ context. compute ( ) . expect ( "Failed to compute" ) ;
38+ let output_bytes = get_output_from_context ( & context) ;
39+
40+ println ! ( "{}" , output. trim( ) ) ;
41+
42+ }