-use serde_json::Value;
-use std::collections::HashMap;
+use serde_json::{json, Value};
 use std::env;
 use std::io;
 use wasmedge_wasi_nn::{
@@ -19,6 +18,27 @@ fn read_input() -> String {
     }
 }
 
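+/// Reads generation options from the environment variables `enable_log`,
+/// `n_gpu_layers`, `ctx_size`, and `reverse_prompt` and collects them into
+/// a JSON object used to configure the graph.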
+fn get_options_from_env() -> Value {
+    let mut options = json!({});
+    if let Ok(val) = env::var("enable_log") {
+        options["enable-log"] = serde_json::from_str(val.as_str())
+            .expect("invalid value for enable-log option (true/false)")
+    }
+    if let Ok(val) = env::var("n_gpu_layers") {
+        options["n-gpu-layers"] =
+            serde_json::from_str(val.as_str()).expect("invalid ngl value (unsigned integer)")
+    }
+    if let Ok(val) = env::var("ctx_size") {
+        options["ctx-size"] =
+            serde_json::from_str(val.as_str()).expect("invalid ctx-size value (unsigned integer)")
+    }
+    if let Ok(val) = env::var("reverse_prompt") {
+        options["reverse-prompt"] = json!(val.as_str())
+    }
+
+    options
+}
+
 fn set_data_to_context(context: &mut GraphExecutionContext, data: Vec<u8>) -> Result<(), Error> {
     context.set_input(0, TensorType::U8, &[1], &data)
 }
@@ -56,11 +76,9 @@ fn main() {
     let args: Vec<String> = env::args().collect();
     let model_name: &str = &args[1];
 
-    // Set options for the graph. Check our README for more details.
-    let mut options = HashMap::new();
-    options.insert("enable-log", Value::from(false));
-    options.insert("n-gpu-layers", Value::from(0));
-    options.insert("ctx-size", Value::from(512));
+    // Set options for the graph. Check our README for more details:
+    // https://github.com/second-state/WasmEdge-WASINN-examples/tree/master/wasmedge-ggml#parameters
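+    // For example (the wasm file and model path here are illustrative):
+    //   n_gpu_layers=35 ctx_size=1024 wasmedge \
+    //     --nn-preload default:GGML:AUTO:model.gguf wasmedge-ggml.wasm default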
+    let options = get_options_from_env();
 
     // Create graph and initialize context.
     let graph = GraphBuilder::new(GraphEncoding::Ggml, ExecutionTarget::AUTO)
@@ -82,6 +100,48 @@ fn main() {
     // )
     // .expect("Failed to set metadata");
 
+    // If there is a third argument, use it as the prompt and enter non-interactive mode.
+    // This is mainly for the CI workflow.
+    if args.len() >= 3 {
+        let prompt = &args[2];
+        // Set the prompt.
+        println!("Prompt:\n{}", prompt);
+        let tensor_data = prompt.as_bytes().to_vec();
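+        // Pass the prompt bytes as a U8 tensor with dimensions [1].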
+        context
+            .set_input(0, TensorType::U8, &[1], &tensor_data)
+            .expect("Failed to set input");
+        println!("Response:");
+
+        // Get the number of input tokens and llama.cpp versions.
+        let input_metadata = get_metadata_from_context(&context);
+        println!("[INFO] llama_commit: {}", input_metadata["llama_commit"]);
+        println!(
+            "[INFO] llama_build_number: {}",
+            input_metadata["llama_build_number"]
+        );
+        println!(
+            "[INFO] Number of input tokens: {}",
+            input_metadata["input_tokens"]
+        );
+
+        // Get the output.
+        context.compute().expect("Failed to compute");
+        let output = get_output_from_context(&context);
+        println!("{}", output.trim());
+
+        // Retrieve the output metadata.
+        let metadata = get_metadata_from_context(&context);
+        println!(
+            "[INFO] Number of input tokens: {}",
+            metadata["input_tokens"]
+        );
+        println!(
+            "[INFO] Number of output tokens: {}",
+            metadata["output_tokens"]
+        );
+        std::process::exit(0);
+    }
+
     let mut saved_prompt = String::new();
     let system_prompt = String::from("You are a helpful, respectful and honest assistant. Always answer as short as possible, while being safe.");
 
@@ -101,18 +161,6 @@ fn main() {
         set_data_to_context(&mut context, saved_prompt.as_bytes().to_vec())
             .expect("Failed to set input");
 
-        // Get the number of input tokens and llama.cpp versions.
-        // let input_metadata = get_metadata_from_context(&context);
-        // println!("[INFO] llama_commit: {}", input_metadata["llama_commit"]);
-        // println!(
-        //     "[INFO] llama_build_number: {}",
-        //     input_metadata["llama_build_number"]
-        // );
-        // println!(
-        //     "[INFO] Number of input tokens: {}",
-        //     input_metadata["input_tokens"]
-        // );
-
         // Execute the inference.
         let mut reset_prompt = false;
         match context.compute() {
@@ -141,16 +189,5 @@ fn main() {
             output = output.trim().to_string();
             saved_prompt = format!("{}{}<|im_end|>\n", saved_prompt, output);
         }
-
-        // Retrieve the output metadata.
-        // let metadata = get_metadata_from_context(&context);
-        // println!(
-        //     "[INFO] Number of input tokens: {}",
-        //     metadata["input_tokens"]
-        // );
-        // println!(
-        //     "[INFO] Number of output tokens: {}",
-        //     metadata["output_tokens"]
-        // );
     }
 }