Skip to content

Commit a29dedc

Browse files
grorge123hydai
authored andcommitted
[Example] ChatTTS: add example
1 parent de9c3e6 commit a29dedc

File tree

8 files changed

+145
-0
lines changed

8 files changed

+145
-0
lines changed

wasmedge-chatTTS/Cargo.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
[package]
2+
name = "wasmedge-chattts"
3+
version = "0.1.0"
4+
edition = "2021"
5+
6+
[dependencies]
7+
serde_json = "1.0"
8+
wasmedge-wasi-nn = {path = "../../wasmedge-wasi-nn/rust", version = "0.8.0"}
9+
hound = "3.4"

wasmedge-chatTTS/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
cargo build --target wasm32-wasi --release
2+
3+
wasmedge --dir .:. ./target/wasm32-wasi/release/wasmedge-chattts.wasm
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
2+
3+
dim: 384
4+
5+
decoder_config:
6+
idim: ${dim}
7+
odim: ${dim}
8+
hidden: 512
9+
n_layer: 12
10+
bn_dim: 128
11+
12+
vq_config: null

wasmedge-chatTTS/config/dvae.yaml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
2+
3+
dim: 512
4+
decoder_config:
5+
idim: ${dim}
6+
odim: ${dim}
7+
n_layer: 12
8+
bn_dim: 128
9+
10+
vq_config:
11+
dim: 1024
12+
levels: [5,5,5,5]
13+
G: 2
14+
R: 2

wasmedge-chatTTS/config/gpt.yaml

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
2+
3+
num_audio_tokens: 626
4+
num_text_tokens: 21178
5+
6+
gpt_config:
7+
hidden_size: 768
8+
intermediate_size: 3072
9+
num_attention_heads: 12
10+
num_hidden_layers: 20
11+
use_cache: False
12+
max_position_embeddings: 4096
13+
# attn_implementation: flash_attention_2
14+
15+
spk_emb_dim: 192
16+
spk_KL: False
17+
num_audio_tokens: 626
18+
num_text_tokens: null
19+
num_vq: 4
20+

wasmedge-chatTTS/config/path.yaml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
2+
3+
vocos_config_path: config/vocos.yaml
4+
vocos_ckpt_path: asset/Vocos.pt
5+
dvae_config_path: config/dvae.yaml
6+
dvae_ckpt_path: asset/DVAE.pt
7+
gpt_config_path: config/gpt.yaml
8+
gpt_ckpt_path: asset/GPT.pt
9+
decoder_config_path: config/decoder.yaml
10+
decoder_ckpt_path: asset/Decoder.pt
11+
tokenizer_path: asset/tokenizer.pt

wasmedge-chatTTS/config/vocos.yaml

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
feature_extractor:
2+
class_path: vocos.feature_extractors.MelSpectrogramFeatures
3+
init_args:
4+
sample_rate: 24000
5+
n_fft: 1024
6+
hop_length: 256
7+
n_mels: 100
8+
padding: center
9+
10+
backbone:
11+
class_path: vocos.models.VocosBackbone
12+
init_args:
13+
input_channels: 100
14+
dim: 512
15+
intermediate_dim: 1536
16+
num_layers: 8
17+
18+
head:
19+
class_path: vocos.heads.ISTFTHead
20+
init_args:
21+
dim: 512
22+
n_fft: 1024
23+
hop_length: 256
24+
padding: center

wasmedge-chatTTS/src/main.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
use wasmedge_wasi_nn::{
2+
self, ExecutionTarget, GraphBuilder, GraphEncoding, GraphExecutionContext,
3+
TensorType,
4+
};
5+
use hound;
6+
7+
fn get_data_from_context(context: &GraphExecutionContext, index: usize, limit: usize) -> Vec<u8> {
8+
// Preserve for 4096 tokens with average token length 8
9+
const MAX_OUTPUT_BUFFER_SIZE: usize = 4096 * 4096;
10+
let mut output_buffer = vec![0u8; MAX_OUTPUT_BUFFER_SIZE];
11+
let _ = context
12+
.get_output(index, &mut output_buffer)
13+
.expect("Failed to get output");
14+
15+
return output_buffer[..limit].to_vec();
16+
}
17+
18+
fn main() {
19+
let prompt = "It is a test sentence.";
20+
let tensor_data = prompt.as_bytes().to_vec();
21+
let empty_vec: Vec<Vec<u8>> = Vec::new();
22+
let graph = GraphBuilder::new(GraphEncoding::ChatTTS, ExecutionTarget::CPU)
23+
.build_from_bytes(empty_vec)
24+
.expect("Failed to build graph");
25+
let mut context = graph
26+
.init_execution_context()
27+
.expect("Failed to init context");
28+
context
29+
.set_input(0, TensorType::U8, &[1], &tensor_data)
30+
.expect("Failed to set input");
31+
context.compute().expect("Failed to compute");
32+
let bytes_written = get_data_from_context(&context, 1, 4);
33+
let bytes_written = usize::from_le_bytes(bytes_written.as_slice().try_into().unwrap());
34+
println!("Byte: {}", bytes_written);
35+
let output_bytes = get_data_from_context(&context, 0, bytes_written);
36+
let spec = hound::WavSpec {
37+
channels: 1,
38+
sample_rate: 24000,
39+
bits_per_sample: 32,
40+
sample_format: hound::SampleFormat::Float,
41+
};
42+
let mut writer = hound::WavWriter::create("output1.wav", spec).unwrap();
43+
let samples: Vec<f32> = output_bytes
44+
.chunks_exact(4)
45+
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
46+
.collect();
47+
for sample in samples {
48+
writer.write_sample(sample).unwrap();
49+
}
50+
writer.finalize().unwrap();
51+
graph.unload().expect("Failed to free resource");
52+
}

0 commit comments

Comments
 (0)