Commit 9d89967

feat(tts): add tts_speaker_file support (#179)
* feat(tts): add `tts_speaker_file` support
* feat(llama): add n_predict support

Signed-off-by: dm4 <[email protected]>
1 parent 52e6365 commit 9d89967

File tree: 6 files changed, +73 −68 lines changed

.github/workflows/llama.yml

Lines changed: 3 additions & 0 deletions
@@ -256,16 +256,19 @@ jobs:
           cd wasmedge-ggml/tts
           curl -LO https://huggingface.co/second-state/OuteTTS-0.2-500M-GGUF/resolve/main/OuteTTS-0.2-500M-Q5_K_M.gguf
           curl -LO https://huggingface.co/second-state/OuteTTS-0.2-500M-GGUF/resolve/main/wavtokenizer-large-75-ggml-f16.gguf
+          curl -LO https://raw.githubusercontent.com/edwko/OuteTTS/refs/heads/main/outetts/version/v1/default_speakers/en_male_1.json
           cargo build --target wasm32-wasip1 --release
           time wasmedge --dir .:. \
             --env n_gpu_layers="$NGL" \
             --nn-preload default:GGML:AUTO:OuteTTS-0.2-500M-Q5_K_M.gguf \
             --env tts=true \
             --env tts_output_file=output.wav \
+            --env tts_speaker_file=en_male_1.json \
             --env model_vocoder=wavtokenizer-large-75-ggml-f16.gguf \
             target/wasm32-wasip1/release/wasmedge-ggml-tts.wasm \
             default \
             'Hello, world.'
+          sha1sum *.wav

       - name: Build llama-stream
         run: |

wasmedge-ggml/llama/src/main.rs

Lines changed: 5 additions & 1 deletion
@@ -39,6 +39,10 @@ fn get_options_from_env() -> Value {
     } else {
         options["n-gpu-layers"] = serde_json::from_str("0").unwrap()
     }
+    if let Ok(val) = env::var("n_predict") {
+        options["n-predict"] =
+            serde_json::from_str(val.as_str()).expect("invalid n_predict value (unsigned integer)")
+    }
     options["ctx-size"] = serde_json::from_str("1024").unwrap();
 
     options
@@ -143,7 +147,7 @@ fn main() {
             "[INFO] Number of output tokens: {}",
             metadata["output_tokens"]
         );
-        std::process::exit(0);
+        return;
     }
 
     let mut saved_prompt = String::new();
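The new n_predict handling mirrors the existing n-gpu-layers plumbing: the environment variable string is parsed as a JSON value and stored under llama.cpp's n-predict option key, so a non-numeric value fails fast. A minimal standalone sketch of that pattern (option names are taken from the diff; the surrounding program is illustrative only):

```rust
use serde_json::{json, Value};
use std::env;

fn main() {
    let mut options: Value = json!({});

    // Optional cap on the number of generated tokens; if unset, the backend default applies.
    if let Ok(val) = env::var("n_predict") {
        // Parses e.g. "128" into a JSON number; any non-JSON value panics with this message.
        options["n-predict"] = serde_json::from_str(val.as_str())
            .expect("invalid n_predict value (unsigned integer)");
    }
    options["ctx-size"] = serde_json::from_str("1024").unwrap();

    println!("{}", options);
}
```

Note that, as in the diff, any valid JSON literal is accepted here (including a negative number); stricter validation would require an explicit parse to an unsigned integer type.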
Binary file not shown (−1.56 MB).

wasmedge-ggml/tts/README.md

Lines changed: 10 additions & 0 deletions
@@ -7,12 +7,22 @@
 wget https://huggingface.co/second-state/OuteTTS-0.2-500M-GGUF/resolve/main/OuteTTS-0.2-500M-Q5_K_M.gguf
 wget https://huggingface.co/second-state/OuteTTS-0.2-500M-GGUF/resolve/main/wavtokenizer-large-75-ggml-f16.gguf
 ```
 
+## Speaker Profile Download
+
+```console
+wget https://raw.githubusercontent.com/edwko/OuteTTS/refs/heads/main/outetts/version/v1/default_speakers/en_male_1.json
+```
+
+> [!NOTE]
+> The default speaker profile of the plugin is `en_female_1.json`.
+
 ### Execution
 
 ```console
 $ wasmedge --dir .:. \
   --env tts=true \
   --env tts_output_file=output.wav \
+  --env tts_speaker_file=en_male_1.json \
   --env model_vocoder=wavtokenizer-large-75-ggml-f16.gguf \
   --nn-preload default:GGML:AUTO:OuteTTS-0.2-500M-Q5_K_M.gguf \
   wasmedge-ggml-tts.wasm default 'Hello, world.'

wasmedge-ggml/tts/src/main.rs

Lines changed: 55 additions & 67 deletions
@@ -1,24 +1,13 @@
 use serde_json::Value;
 use std::collections::HashMap;
 use std::env;
-use std::io;
+use std::fs::File;
+use std::io::{self, Write};
 use wasmedge_wasi_nn::{
     self, BackendError, Error, ExecutionTarget, GraphBuilder, GraphEncoding, GraphExecutionContext,
     TensorType,
 };
 
-fn read_input() -> String {
-    loop {
-        let mut answer = String::new();
-        io::stdin()
-            .read_line(&mut answer)
-            .expect("Failed to read line");
-        if !answer.is_empty() && answer != "\n" && answer != "\r\n" {
-            return answer.trim().to_string();
-        }
-    }
-}
-
 fn get_options_from_env() -> HashMap<&'static str, Value> {
     let mut options = HashMap::new();
 
@@ -36,6 +25,10 @@ fn get_options_from_env() -> HashMap<&'static str, Value> {
         eprintln!("Failed to get vocoder model.");
         std::process::exit(1);
     }
+    // Speaker profile is optional.
+    if let Ok(val) = env::var("tts_speaker_file") {
+        options.insert("tts-speaker-file", Value::from(val.as_str()));
+    }
 
     // Optional parameters
     if let Ok(val) = env::var("enable_log") {
@@ -79,31 +72,39 @@ fn get_options_from_env() -> HashMap<&'static str, Value> {
     if let Ok(val) = env::var("seed") {
         options.insert("seed", serde_json::from_str(val.as_str()).unwrap());
     }
+    if let Ok(val) = env::var("temp") {
+        options.insert("temp", serde_json::from_str(val.as_str()).unwrap());
+    }
     options
 }
 
 fn set_data_to_context(context: &mut GraphExecutionContext, data: Vec<u8>) -> Result<(), Error> {
     context.set_input(0, TensorType::U8, &[1], &data)
 }
 
-fn get_data_from_context(context: &GraphExecutionContext, index: usize) -> String {
-    // Preserve for 4096 tokens with average token length 6
-    const MAX_OUTPUT_BUFFER_SIZE: usize = 4096 * 6;
+fn get_data_from_context(context: &GraphExecutionContext, index: usize) -> Vec<u8> {
+    // Use 1MB as the maximum output buffer size for audio output.
+    const MAX_OUTPUT_BUFFER_SIZE: usize = 1024 * 1024;
     let mut output_buffer = vec![0u8; MAX_OUTPUT_BUFFER_SIZE];
     let mut output_size = context
         .get_output(index, &mut output_buffer)
         .expect("Failed to get output");
     output_size = std::cmp::min(MAX_OUTPUT_BUFFER_SIZE, output_size);
 
-    String::from_utf8_lossy(&output_buffer[..output_size]).to_string()
+    output_buffer[..output_size].to_vec()
 }
 
 fn get_metadata_from_context(context: &GraphExecutionContext) -> Value {
-    serde_json::from_str(&get_data_from_context(context, 1)).expect("Failed to get metadata")
+    serde_json::from_str(&String::from_utf8_lossy(&get_data_from_context(context, 1)).to_string())
+        .expect("Failed to get metadata")
 }
 
 fn main() {
     let args: Vec<String> = env::args().collect();
+    if args.len() < 3 {
+        println!("Usage: {} <nn-preload-model> <prompt>", args[0]);
+        return;
+    }
     let model_name: &str = &args[1];
 
     // Set options for the graph. Check our README for more details:
@@ -121,53 +122,40 @@ fn main() {
 
     // If there is a third argument, use it as the prompt and enter non-interactive mode.
     // This is mainly for the CI workflow.
-    if args.len() >= 3 {
-        let prompt = &args[2];
-        // Set the prompt.
-        println!("Prompt:\n{}", prompt);
-        let tensor_data = prompt.as_bytes().to_vec();
-        context
-            .set_input(0, TensorType::U8, &[1], &tensor_data)
-            .expect("Failed to set input");
-
-        // Get the number of input tokens and llama.cpp versions.
-        let input_metadata = get_metadata_from_context(&context);
-        println!("[INFO] llama_commit: {}", input_metadata["llama_commit"]);
-        println!(
-            "[INFO] llama_build_number: {}",
-            input_metadata["llama_build_number"]
-        );
-        println!(
-            "[INFO] Number of input tokens: {}",
-            input_metadata["input_tokens"]
-        );
-
-        context.compute().expect("Failed to compute");
-        println!("[INFO] Write output file to {}", options["tts-output-file"]);
-
-        return;
-    }
-
-    println!("Text:");
-    let input = read_input();
-
-    // Set prompt to the input tensor.
-    set_data_to_context(&mut context, input.as_bytes().to_vec()).expect("Failed to set input");
-
-    // Execute the inference.
-    match context.compute() {
-        Ok(_) => (),
-        Err(Error::BackendError(BackendError::ContextFull)) => {
-            println!("\n[INFO] Context full.");
-        }
-        Err(Error::BackendError(BackendError::PromptTooLong)) => {
-            println!("\n[INFO] Prompt too long.");
-        }
-        Err(err) => {
-            println!("\n[ERROR] {}", err);
-            std::process::exit(1);
-        }
-    }
-
-    println!("[INFO] Write output file to {}", options["tts-output-file"]);
+    let prompt = &args[2];
+    // Set the prompt.
+    println!("Prompt:\n{}", prompt);
+    let tensor_data = prompt.as_bytes().to_vec();
+    context
+        .set_input(0, TensorType::U8, &[1], &tensor_data)
+        .expect("Failed to set input");
+
+    // Get the number of input tokens and llama.cpp versions.
+    let input_metadata = get_metadata_from_context(&context);
+    println!("[INFO] llama_commit: {}", input_metadata["llama_commit"]);
+    println!(
+        "[INFO] llama_build_number: {}",
+        input_metadata["llama_build_number"]
+    );
+    println!(
+        "[INFO] Number of input tokens: {}",
+        input_metadata["input_tokens"]
+    );
+
+    context.compute().expect("Failed to compute");
+    println!(
+        "[INFO] Plugin writes output to file {}",
+        options["tts-output-file"]
+    );
+
+    // Write the output buffer to a file; it should match the output file given in the options.
+    let output_filename = "output-buffer.wav";
+    let output_bytes = get_data_from_context(&context, 0);
+    let mut output_file = File::create(output_filename).expect("Failed to create output file");
+    output_file
+        .write_all(&output_bytes)
+        .expect("Failed to write output file");
+    println!("[INFO] Write output buffer to file {}", output_filename);
+
+    return;
 }
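After this change the example produces two WAV files: output.wav, written by the wasi-nn plugin (tts_output_file), and output-buffer.wav, written by the example from the output tensor; the CI step above compares them via sha1sum. A byte-for-byte check of the same invariant could look like the sketch below (file names come from the diff; the checker program itself is illustrative and not part of the commit):

```rust
use std::fs;

fn main() {
    // output.wav is written by the wasi-nn plugin (tts_output_file);
    // output-buffer.wav is written by the example from the output tensor.
    let plugin_wav = fs::read("output.wav").expect("Failed to read output.wav");
    let buffer_wav = fs::read("output-buffer.wav").expect("Failed to read output-buffer.wav");

    if plugin_wav == buffer_wav {
        println!("OK: both files contain the same {} bytes", plugin_wav.len());
    } else {
        eprintln!("Mismatch: {} vs {} bytes", plugin_wav.len(), buffer_wav.len());
        std::process::exit(1);
    }
}
```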
Binary file not shown (7.52 KB).

0 commit comments
