11use serde_json:: Value ;
22use std:: collections:: HashMap ;
33use std:: env;
4- use std:: io;
4+ use std:: fs:: File ;
5+ use std:: io:: { self , Write } ;
56use wasmedge_wasi_nn:: {
67 self , BackendError , Error , ExecutionTarget , GraphBuilder , GraphEncoding , GraphExecutionContext ,
78 TensorType ,
89} ;
910
/// Read a non-empty line of text from stdin, prompting until one is given.
///
/// Whitespace-only lines are rejected along with empty ones, and the
/// returned string has surrounding whitespace trimmed.
///
/// Exits the process on stdin EOF, since no further input can ever arrive.
fn read_input() -> String {
    loop {
        let mut answer = String::new();
        let bytes_read = io::stdin()
            .read_line(&mut answer)
            .expect("Failed to read line");
        // read_line returns Ok(0) at EOF; without this check the loop
        // would busy-spin forever once stdin is closed.
        if bytes_read == 0 {
            eprintln!("Unexpected end of input.");
            std::process::exit(1);
        }
        let trimmed = answer.trim();
        // Reject blank and whitespace-only lines instead of returning "".
        if !trimmed.is_empty() {
            return trimmed.to_string();
        }
    }
}
21-
2211fn get_options_from_env ( ) -> HashMap < & ' static str , Value > {
2312 let mut options = HashMap :: new ( ) ;
2413
@@ -36,6 +25,10 @@ fn get_options_from_env() -> HashMap<&'static str, Value> {
3625 eprintln ! ( "Failed to get vocoder model." ) ;
3726 std:: process:: exit ( 1 ) ;
3827 }
28+ // Speaker profile is optional.
29+ if let Ok ( val) = env:: var ( "tts_speaker_file" ) {
30+ options. insert ( "tts-speaker-file" , Value :: from ( val. as_str ( ) ) ) ;
31+ }
3932
4033 // Optional parameters
4134 if let Ok ( val) = env:: var ( "enable_log" ) {
@@ -79,31 +72,39 @@ fn get_options_from_env() -> HashMap<&'static str, Value> {
7972 if let Ok ( val) = env:: var ( "seed" ) {
8073 options. insert ( "seed" , serde_json:: from_str ( val. as_str ( ) ) . unwrap ( ) ) ;
8174 }
75+ if let Ok ( val) = env:: var ( "temp" ) {
76+ options. insert ( "temp" , serde_json:: from_str ( val. as_str ( ) ) . unwrap ( ) ) ;
77+ }
8278 options
8379}
8480
/// Copy `data` into the graph's input tensor at index 0.
///
/// The tensor is declared as a single-element U8 tensor (`&[1]`); the raw
/// byte buffer carries the payload (here, the prompt text). Returns any
/// error reported by the WASI-NN backend.
fn set_data_to_context(context: &mut GraphExecutionContext, data: Vec<u8>) -> Result<(), Error> {
    context.set_input(0, TensorType::U8, &[1], &data)
}
8884
89- fn get_data_from_context ( context : & GraphExecutionContext , index : usize ) -> String {
90- // Preserve for 4096 tokens with average token length 6
91- const MAX_OUTPUT_BUFFER_SIZE : usize = 4096 * 6 ;
85+ fn get_data_from_context ( context : & GraphExecutionContext , index : usize ) -> Vec < u8 > {
86+ // Use 1MB as the maximum output buffer size for audio output.
87+ const MAX_OUTPUT_BUFFER_SIZE : usize = 1024 * 1024 ;
9288 let mut output_buffer = vec ! [ 0u8 ; MAX_OUTPUT_BUFFER_SIZE ] ;
9389 let mut output_size = context
9490 . get_output ( index, & mut output_buffer)
9591 . expect ( "Failed to get output" ) ;
9692 output_size = std:: cmp:: min ( MAX_OUTPUT_BUFFER_SIZE , output_size) ;
9793
98- String :: from_utf8_lossy ( & output_buffer[ ..output_size] ) . to_string ( )
94+ output_buffer[ ..output_size] . to_vec ( )
9995}
10096
10197fn get_metadata_from_context ( context : & GraphExecutionContext ) -> Value {
102- serde_json:: from_str ( & get_data_from_context ( context, 1 ) ) . expect ( "Failed to get metadata" )
98+ serde_json:: from_str ( & String :: from_utf8_lossy ( & get_data_from_context ( context, 1 ) ) . to_string ( ) )
99+ . expect ( "Failed to get metadata" )
103100}
104101
105102fn main ( ) {
106103 let args: Vec < String > = env:: args ( ) . collect ( ) ;
104+ if args. len ( ) < 3 {
105+ println ! ( "Usage: {} <nn-preload-model> <prompt>" , args[ 0 ] ) ;
106+ return ;
107+ }
107108 let model_name: & str = & args[ 1 ] ;
108109
109110 // Set options for the graph. Check our README for more details:
@@ -121,53 +122,40 @@ fn main() {
121122
122123 // If there is a third argument, use it as the prompt and enter non-interactive mode.
123124 // This is mainly for the CI workflow.
124- if args. len ( ) >= 3 {
125- let prompt = & args[ 2 ] ;
126- // Set the prompt.
127- println ! ( "Prompt:\n {}" , prompt) ;
128- let tensor_data = prompt. as_bytes ( ) . to_vec ( ) ;
129- context
130- . set_input ( 0 , TensorType :: U8 , & [ 1 ] , & tensor_data)
131- . expect ( "Failed to set input" ) ;
132-
133- // Get the number of input tokens and llama.cpp versions.
134- let input_metadata = get_metadata_from_context ( & context) ;
135- println ! ( "[INFO] llama_commit: {}" , input_metadata[ "llama_commit" ] ) ;
136- println ! (
137- "[INFO] llama_build_number: {}" ,
138- input_metadata[ "llama_build_number" ]
139- ) ;
140- println ! (
141- "[INFO] Number of input tokens: {}" ,
142- input_metadata[ "input_tokens" ]
143- ) ;
144-
145- context. compute ( ) . expect ( "Failed to compute" ) ;
146- println ! ( "[INFO] Write output file to {}" , options[ "tts-output-file" ] ) ;
147-
148- return ;
149- }
150-
151- println ! ( "Text:" ) ;
152- let input = read_input ( ) ;
153-
154- // Set prompt to the input tensor.
155- set_data_to_context ( & mut context, input. as_bytes ( ) . to_vec ( ) ) . expect ( "Failed to set input" ) ;
156-
157- // Execute the inference.
158- match context. compute ( ) {
159- Ok ( _) => ( ) ,
160- Err ( Error :: BackendError ( BackendError :: ContextFull ) ) => {
161- println ! ( "\n [INFO] Context full." ) ;
162- }
163- Err ( Error :: BackendError ( BackendError :: PromptTooLong ) ) => {
164- println ! ( "\n [INFO] Prompt too long." ) ;
165- }
166- Err ( err) => {
167- println ! ( "\n [ERROR] {}" , err) ;
168- std:: process:: exit ( 1 ) ;
169- }
170- }
171-
172- println ! ( "[INFO] Write output file to {}" , options[ "tts-output-file" ] ) ;
125+ let prompt = & args[ 2 ] ;
126+ // Set the prompt.
127+ println ! ( "Prompt:\n {}" , prompt) ;
128+ let tensor_data = prompt. as_bytes ( ) . to_vec ( ) ;
129+ context
130+ . set_input ( 0 , TensorType :: U8 , & [ 1 ] , & tensor_data)
131+ . expect ( "Failed to set input" ) ;
132+
133+ // Get the number of input tokens and llama.cpp versions.
134+ let input_metadata = get_metadata_from_context ( & context) ;
135+ println ! ( "[INFO] llama_commit: {}" , input_metadata[ "llama_commit" ] ) ;
136+ println ! (
137+ "[INFO] llama_build_number: {}" ,
138+ input_metadata[ "llama_build_number" ]
139+ ) ;
140+ println ! (
141+ "[INFO] Number of input tokens: {}" ,
142+ input_metadata[ "input_tokens" ]
143+ ) ;
144+
145+ context. compute ( ) . expect ( "Failed to compute" ) ;
146+ println ! (
147+ "[INFO] Plugin writes output to file {}" ,
148+ options[ "tts-output-file" ]
149+ ) ;
150+
151+ // Write output buffer to file, should be the same as the output file in the options.
152+ let output_filename = "output-buffer.wav" ;
153+ let output_bytes = get_data_from_context ( & context, 0 ) ;
154+ let mut output_file = File :: create ( output_filename) . expect ( "Failed to create output file" ) ;
155+ output_file
156+ . write_all ( & output_bytes)
157+ . expect ( "Failed to write output file" ) ;
158+ println ! ( "[INFO] Write output buffer to file {}" , output_filename) ;
159+
160+ return ;
173161}
0 commit comments