Skip to content

Commit 3284605

Browse files
committed
feat: new record api
1 parent e08d456 commit 3284605

File tree

5 files changed

+144
-48
lines changed

5 files changed

+144
-48
lines changed

src/config.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,19 @@ pub struct Config {
143143

144144
pub hello_wav: Option<String>,
145145

146+
#[serde(default)]
147+
pub record: RecordConfig,
148+
146149
#[serde(flatten)]
147150
pub config: AIConfig,
148151
}
149152

153+
#[derive(Debug, Default, Clone, serde::Serialize, serde::Deserialize)]
154+
pub struct RecordConfig {
155+
#[serde(default)]
156+
pub callback_url: Option<String>,
157+
}
158+
150159
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
151160
#[serde(untagged)]
152161
pub enum AIConfig {

src/main.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,19 @@ async fn routes(
9090
let mut router = Router::new()
9191
// .route("/", get(handler))
9292
.route("/ws/{id}", any(services::ws::ws_handler))
93-
.nest("/record", services::file::new_file_service("./record"))
93+
.route("/v1/chat/{id}", any(services::ws::ws_handler))
94+
.route("/v1/record/{id}", any(services::ws_record::ws_handler))
95+
.nest("/downloads", services::file::new_file_service("./record"))
9496
.layer(axum::Extension(Arc::new(services::ws::WsSetting::new(
9597
hello_wav,
9698
config.config,
9799
tool_set,
98-
))));
100+
))))
101+
.layer(axum::Extension(Arc::new(
102+
services::ws_record::WsRecordSetting {
103+
record_callback_url: config.record.callback_url,
104+
},
105+
)));
99106

100107
if let Some(real_config) = real_config {
101108
log::info!(

src/services/ws.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -892,7 +892,7 @@ pub enum ClientMsg {
892892
// return: wav data
893893
async fn process_socket_io(
894894
rx: &mut WsRx,
895-
audio_tx: tokio::sync::mpsc::Sender<ClientMsg>,
895+
audio_tx: ClientTx,
896896
socket: &mut WebSocket,
897897
) -> anyhow::Result<()> {
898898
loop {

src/services/ws_record.rs

Lines changed: 61 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -10,54 +10,66 @@ use axum::{
1010
};
1111
use bytes::Bytes;
1212

13-
use crate::services::ws::WsSetting;
13+
pub struct WsRecordSetting {
14+
pub record_callback_url: Option<String>,
15+
}
16+
17+
async fn post_to_callback_url(callback_url: &str, id: &str, file_path: &str) -> anyhow::Result<()> {
18+
let client = reqwest::Client::new();
19+
20+
let resp = client
21+
.post(callback_url)
22+
.json(&serde_json::json!({
23+
"id": id,
24+
"download_uri": format!("/downloads/{}", file_path),
25+
}))
26+
.send()
27+
.await?;
28+
29+
log::info!(
30+
"[Record] {} callback to {} success: {}",
31+
id,
32+
callback_url,
33+
resp.status()
34+
);
35+
36+
Ok(())
37+
}
1438

1539
pub async fn ws_handler(
16-
Extension(pool): Extension<Arc<WsSetting>>,
40+
Extension(setting): Extension<Arc<WsRecordSetting>>,
1741
ws: WebSocketUpgrade,
1842
Path(id): Path<String>,
1943
) -> impl IntoResponse {
2044
let request_id = uuid::Uuid::new_v4().as_u128();
2145
log::info!("[Record] {id}:{request_id:x} connected.");
2246

2347
ws.on_upgrade(move |socket| async move {
24-
let id = id.clone();
25-
let pool = pool.clone();
26-
if let Err(e) = handle_socket(socket, &id, pool.clone()).await {
27-
log::error!("{id}:{request_id:x} error: {e}");
28-
};
48+
match handle_socket(socket, &id).await {
49+
Ok(file_path) => {
50+
if let Some(callback_url) = &setting.record_callback_url {
51+
if let Err(e) = post_to_callback_url(callback_url, &id, &file_path).await {
52+
log::error!("[Record] {} callback to {} failed: {}", id, callback_url, e);
53+
}
54+
}
55+
}
56+
Err(e) => {
57+
log::error!("{id}:{request_id:x} error: {e}");
58+
}
59+
}
2960
log::info!("{id}:{request_id:x} disconnected.");
3061
})
3162
}
3263

3364
enum ProcessMessageResult {
3465
Audio(Bytes),
35-
Submit,
36-
Text(String),
37-
StartRecord,
38-
StartChat,
3966
Close,
4067
Skip,
4168
}
4269

4370
fn process_message(msg: Message) -> ProcessMessageResult {
4471
match msg {
45-
Message::Text(t) => {
46-
if let Ok(cmd) = serde_json::from_str::<crate::protocol::ClientCommand>(&t) {
47-
match cmd {
48-
crate::protocol::ClientCommand::StartRecord => {
49-
ProcessMessageResult::StartRecord
50-
}
51-
crate::protocol::ClientCommand::StartChat => ProcessMessageResult::StartChat,
52-
crate::protocol::ClientCommand::Submit => ProcessMessageResult::Submit,
53-
crate::protocol::ClientCommand::Text { input } => {
54-
ProcessMessageResult::Text(input)
55-
}
56-
}
57-
} else {
58-
ProcessMessageResult::Skip
59-
}
60-
}
72+
Message::Text(_) => ProcessMessageResult::Skip,
6173
Message::Binary(d) => ProcessMessageResult::Audio(d),
6274
Message::Close(c) => {
6375
if let Some(cf) = c {
@@ -76,29 +88,35 @@ fn process_message(msg: Message) -> ProcessMessageResult {
7688
}
7789
}
7890

79-
// TODO: implement recording logic
80-
async fn handle_socket(
81-
mut socket: WebSocket,
82-
id: &str,
83-
pool: Arc<WsSetting>,
84-
) -> anyhow::Result<()> {
85-
std::fs::create_dir_all(format!("./record/{id}"))?;
91+
async fn handle_socket(mut socket: WebSocket, id: &str) -> anyhow::Result<String> {
92+
let now = chrono::Local::now();
93+
let date_str = now.format("%Y%m%d_%H%M%S%z").to_string();
94+
let file_path = format!("{id}/record_{date_str}.wav");
95+
let path = format!("./record/{file_path}");
8696

87-
while let Some(message) = socket.recv().await {
88-
let message = message.map_err(|e| anyhow::anyhow!("recv ws error: {e}"))?;
97+
let mut wav_file = crate::util::UnlimitedWavFileWriter::new(
98+
&path,
99+
crate::util::WavConfig {
100+
sample_rate: 16000,
101+
channels: 1,
102+
bits_per_sample: 16,
103+
},
104+
)
105+
.await?;
89106

107+
while let Ok(Some(Ok(message))) =
108+
tokio::time::timeout(std::time::Duration::from_secs(60), socket.recv()).await
109+
{
90110
match process_message(message) {
91-
ProcessMessageResult::Audio(_) => {}
92-
ProcessMessageResult::Submit => {}
93-
ProcessMessageResult::Text(_) => {}
94-
ProcessMessageResult::Skip => {}
95-
ProcessMessageResult::StartRecord => {}
96-
ProcessMessageResult::StartChat => {}
111+
ProcessMessageResult::Audio(chunk) => {
112+
wav_file.write_pcm_data(&chunk).await?;
113+
}
97114
ProcessMessageResult::Close => {
98115
return Err(anyhow::anyhow!("ws closed"));
99116
}
117+
ProcessMessageResult::Skip => {}
100118
}
101119
}
102120

103-
Ok(())
121+
Ok(file_path)
104122
}

src/util.rs

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ pub fn pcm_to_wav(pcm_data: &[u8], config: WavConfig) -> Vec<u8> {
3434
cursor.write_all(&file_size.to_le_bytes()).unwrap(); // ChunkSize (little-endian)
3535
cursor.write_all(b"WAVE").unwrap(); // Format
3636

37-
// fmt 子块
37+
// fmt subchunk
3838
cursor.write_all(b"fmt ").unwrap(); // Subchunk1ID
3939
cursor.write_all(&16u32.to_le_bytes()).unwrap(); // Subchunk1Size (PCM = 16)
4040
cursor.write_all(&1u16.to_le_bytes()).unwrap(); // AudioFormat (PCM = 1)
@@ -46,7 +46,7 @@ pub fn pcm_to_wav(pcm_data: &[u8], config: WavConfig) -> Vec<u8> {
4646
.write_all(&config.bits_per_sample.to_le_bytes())
4747
.unwrap(); // BitsPerSample
4848

49-
// data 子块
49+
// data subchunk
5050
cursor.write_all(b"data").unwrap(); // Subchunk2ID
5151
cursor.write_all(&data_size.to_le_bytes()).unwrap(); // Subchunk2Size
5252

@@ -279,3 +279,65 @@ pub fn get_samples_i16(reader: &mut wav_io::reader::Reader) -> Result<Vec<i16>,
279279
}
280280
Ok(result)
281281
}
282+
283+
pub struct UnlimitedWavFileWriter {
284+
pub config: WavConfig,
285+
pub file: tokio::fs::File,
286+
}
287+
288+
impl UnlimitedWavFileWriter {
289+
pub async fn new(path: &str, config: WavConfig) -> anyhow::Result<Self> {
290+
let file = tokio::fs::File::create_new(path).await.map_err(|e| {
291+
anyhow::anyhow!(
292+
"Failed to create wav file at path {}: {}",
293+
path,
294+
e.to_string()
295+
)
296+
})?;
297+
Ok(Self { config, file })
298+
}
299+
300+
pub async fn write_wav_header(&mut self) -> anyhow::Result<()> {
301+
use tokio::io::AsyncWriteExt;
302+
303+
let bytes_per_sample = self.config.bits_per_sample / 8;
304+
let byte_rate =
305+
self.config.sample_rate * self.config.channels as u32 * bytes_per_sample as u32;
306+
let block_align = self.config.channels * bytes_per_sample;
307+
let data_size = 0xFFFFFFFFu32; // unknown data size
308+
let file_size = 0x7FFFFFFFu32;
309+
310+
self.file.write_all(b"RIFF").await?;
311+
self.file.write_all(&file_size.to_le_bytes()).await?; // ChunkSize (little-endian)
312+
self.file.write_all(b"WAVE").await?; // Format
313+
314+
// fmt subchunk
315+
self.file.write_all(b"fmt ").await?; // Subchunk1ID
316+
self.file.write_all(&16u32.to_le_bytes()).await?; // Subchunk1Size (PCM = 16)
317+
self.file.write_all(&1u16.to_le_bytes()).await?; // AudioFormat (PCM = 1)
318+
self.file
319+
.write_all(&self.config.channels.to_le_bytes())
320+
.await?; // NumChannels
321+
self.file
322+
.write_all(&self.config.sample_rate.to_le_bytes())
323+
.await?; // SampleRate
324+
self.file.write_all(&byte_rate.to_le_bytes()).await?; // ByteRate
325+
self.file.write_all(&block_align.to_le_bytes()).await?; // BlockAlign
326+
self.file
327+
.write_all(&self.config.bits_per_sample.to_le_bytes())
328+
.await?; // BitsPerSample
329+
330+
// data subchunk
331+
self.file.write_all(b"data").await?; // Subchunk2ID
332+
self.file.write_all(&data_size.to_le_bytes()).await?; // Subchunk2Size
333+
334+
Ok(())
335+
}
336+
337+
pub async fn write_pcm_data(&mut self, pcm_data: &[u8]) -> anyhow::Result<()> {
338+
use tokio::io::AsyncWriteExt;
339+
340+
self.file.write_all(pcm_data).await?;
341+
Ok(())
342+
}
343+
}

0 commit comments

Comments
 (0)