
Commit 70af7bf

juntao and claude committed:
Apply cargo fmt formatting across all source files
Signed-off-by: Michael Yuan <michael@secondstate.io>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 66076fc commit 70af7bf

File tree: 15 files changed, +244 −172 lines

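
Note: every change below is mechanical. A commit like this is produced by running `cargo fmt --all` at the workspace root and verified in CI with `cargo fmt --all -- --check`, which exits non-zero if any file would be rewritten.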

build.rs

Lines changed: 2 additions & 2 deletions
@@ -55,8 +55,8 @@ fn build_mlx() {
     // Ensure CMake and Rust agree on the macOS deployment target.
     // Without this, CMake may compile C++ for macOS 15.x while Rust links
     // for macOS 11.0, causing `___isPlatformVersionAtLeast` linker errors.
-    let deployment_target = std::env::var("MACOSX_DEPLOYMENT_TARGET")
-        .unwrap_or_else(|_| "14.0".to_string());
+    let deployment_target =
+        std::env::var("MACOSX_DEPLOYMENT_TARGET").unwrap_or_else(|_| "14.0".to_string());

     // Build mlx-c via CMake (fetches and builds MLX C++ as a dependency)
     let dst = cmake::Config::new(&mlx_c_dir)
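
For context on this hunk: the deployment target read from the environment only helps if the CMake build receives the same value. A minimal sketch of that hand-off with the `cmake` crate (the `define` call is real `cmake` crate API; wiring it exactly this way is an assumption, not necessarily what this build.rs does):

    // Hedged sketch: pass the agreed target to CMake so the C++ objects and
    // the Rust link step use the same macOS minimum version.
    // CMAKE_OSX_DEPLOYMENT_TARGET is CMake's standard variable for this.
    let dst = cmake::Config::new(&mlx_c_dir)
        .define("CMAKE_OSX_DEPLOYMENT_TARGET", &deployment_target)
        .build();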

src/audio.rs

Lines changed: 14 additions & 7 deletions
@@ -71,7 +71,12 @@ pub fn load_audio(path: impl AsRef<Path>, target_sr: usize) -> Result<Vec<f32>>
     }

     let probed = symphonia::default::get_probe()
-        .format(&hint, mss, &FormatOptions::default(), &MetadataOptions::default())
+        .format(
+            &hint,
+            mss,
+            &FormatOptions::default(),
+            &MetadataOptions::default(),
+        )
         .context("Unsupported audio format")?;

     let mut format = probed.format;
@@ -85,11 +90,7 @@ pub fn load_audio(path: impl AsRef<Path>, target_sr: usize) -> Result<Vec<f32>>
         .codec_params
         .sample_rate
         .context("Unknown sample rate")? as usize;
-    let channels = track
-        .codec_params
-        .channels
-        .map(|c| c.count())
-        .unwrap_or(1);
+    let channels = track.codec_params.channels.map(|c| c.count()).unwrap_or(1);

     let track_id = track.id;
     let mut decoder = symphonia::default::get_codecs()
@@ -226,7 +227,13 @@ fn resample(input: &[f32], src_sr: usize, dst_sr: usize) -> Result<Vec<f32>> {

 /// Compute mel filterbank matrix: (n_mels, n_fft/2+1).
 /// Follows librosa.filters.mel with norm='slaney'.
-pub fn mel_filterbank(sample_rate: usize, n_fft: usize, n_mels: usize, fmin: f64, fmax: f64) -> Vec<f32> {
+pub fn mel_filterbank(
+    sample_rate: usize,
+    n_fft: usize,
+    n_mels: usize,
+    fmin: f64,
+    fmax: f64,
+) -> Vec<f32> {
     let n_freqs = n_fft / 2 + 1;

     let hz_to_mel = |f: f64| -> f64 { 2595.0 * (1.0 + f / 700.0).log10() };
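
The last context line is the standard 2595 · log10(1 + f/700) mel conversion. As a quick sanity check of that formula (standalone sketch, not part of the crate):

    // Standalone sketch of the mel conversion used in mel_filterbank,
    // plus its inverse for recovering filter center frequencies.
    fn hz_to_mel(f: f64) -> f64 {
        2595.0 * (1.0 + f / 700.0).log10()
    }

    fn mel_to_hz(m: f64) -> f64 {
        700.0 * (10f64.powf(m / 2595.0) - 1.0)
    }

    fn main() {
        println!("{:.1}", hz_to_mel(1000.0));            // ~1000.0: 1 kHz anchors the scale
        println!("{:.1}", mel_to_hz(hz_to_mel(8000.0))); // 8000.0: round-trip is exact
    }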

src/bin/server.rs

Lines changed: 16 additions & 5 deletions
@@ -257,7 +257,12 @@ async fn transcribe_audio(
     let mut pos = 0;
     while pos < samples.len() {
         let end = (pos + max_samples).min(samples.len());
-        parts.push(transcribe_chunk(&samples[pos..end], &state, &lang, punctuation)?);
+        parts.push(transcribe_chunk(
+            &samples[pos..end],
+            &state,
+            &lang,
+            punctuation,
+        )?);
         if end >= samples.len() {
             break;
         }
@@ -303,7 +308,10 @@ async fn transcribe_audio(
         );
         (
             StatusCode::OK,
-            [(axum::http::header::CONTENT_TYPE, "text/plain; charset=utf-8")],
+            [(
+                axum::http::header::CONTENT_TYPE,
+                "text/plain; charset=utf-8",
+            )],
             srt,
         )
             .into_response()
@@ -421,8 +429,7 @@ fn add_dither(samples: &[f32], dither: f32, seed: u64) -> Vec<f32> {
             .wrapping_mul(6364136223846793005)
             .wrapping_add(1442695040888963407);
         let v = (rng >> 33) as f32 / (u32::MAX as f32);
-        let noise =
-            (-2.0 * u.max(1e-38).ln()).sqrt() * (2.0 * std::f32::consts::PI * v).cos();
+        let noise = (-2.0 * u.max(1e-38).ln()).sqrt() * (2.0 * std::f32::consts::PI * v).cos();
         *s += dither * noise;
     }
     out
@@ -435,7 +442,11 @@ fn add_dither(samples: &[f32], dither: f32, seed: u64) -> Vec<f32> {
 fn load_model(args: &Args) -> Result<ModelState> {
     let model_dir = &args.model_dir;

-    anyhow::ensure!(model_dir.exists(), "Model directory not found: {:?}", model_dir);
+    anyhow::ensure!(
+        model_dir.exists(),
+        "Model directory not found: {:?}",
+        model_dir
+    );
     for f in &["config.json", "model.safetensors", "vocab.json"] {
         anyhow::ensure!(
             model_dir.join(f).exists(),
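
The single-line `noise` expression reflowed above is one branch of the Box-Muller transform: given independent uniforms u, v in (0, 1], `sqrt(-2 ln u) * cos(2π v)` is a standard-normal sample, which `add_dither` then scales by the dither amplitude. A self-contained illustration (names here are illustrative, not the server's):

    // Box-Muller sketch: two uniform samples in (0, 1] -> one N(0, 1) sample.
    // The u.max(1e-38) clamp mirrors the server code and guards against ln(0).
    fn box_muller(u: f32, v: f32) -> f32 {
        (-2.0 * u.max(1e-38).ln()).sqrt() * (2.0 * std::f32::consts::PI * v).cos()
    }

    fn main() {
        // u = 1 makes the radial term sqrt(-2 ln 1) = 0, so the sample is ~0.
        println!("{}", box_muller(1.0, 0.3));
    }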

src/decoder.rs

Lines changed: 15 additions & 16 deletions
@@ -81,19 +81,18 @@ impl DecoderAttn {
         let (b, t, _) = hidden_states.size3().unwrap();
         let s = source.size()[1];

-        let reshape_q = |z: &Tensor| -> Tensor {
-            z.view([b, t, self.n_heads, self.head_dim]).transpose(1, 2)
-        };
+        let reshape_q =
+            |z: &Tensor| -> Tensor { z.view([b, t, self.n_heads, self.head_dim]).transpose(1, 2) };
         let reshape_kv = |z: &Tensor, seq: i64| -> Tensor {
-            z.view([b, seq, self.n_heads, self.head_dim]).transpose(1, 2)
+            z.view([b, seq, self.n_heads, self.head_dim])
+                .transpose(1, 2)
         };

         let q = reshape_q(&linear(hidden_states, &self.q_w, &self.q_b));
         let k = reshape_kv(&linear(source, &self.k_w, &self.k_b), s);
         let v = reshape_kv(&linear(source, &self.v_w, &self.v_b), s);
         (q, k, v)
     }
-
 }

 // ---------------------------------------------------------------------------
@@ -120,7 +119,11 @@ impl DecoderFFN {
     }

     fn forward(&self, x: &Tensor) -> Tensor {
-        linear(&linear(x, &self.dense_in_w, &self.dense_in_b).relu(), &self.dense_out_w, &self.dense_out_b)
+        linear(
+            &linear(x, &self.dense_in_w, &self.dense_in_b).relu(),
+            &self.dense_out_w,
+            &self.dense_out_b,
+        )
     }
 }

@@ -187,10 +190,7 @@ impl DecoderLayer {
         let (q_new, k_new, v_new) = self.self_attn.project_qkv(&normed, &normed);

         let (k_full, v_full) = match (self_k_cache, self_v_cache) {
-            (Some(kc), Some(vc)) => (
-                Tensor::cat(&[kc, &k_new], 2),
-                Tensor::cat(&[vc, &v_new], 2),
-            ),
+            (Some(kc), Some(vc)) => (Tensor::cat(&[kc, &k_new], 2), Tensor::cat(&[vc, &v_new], 2)),
             _ => (k_new.shallow_clone(), v_new.shallow_clone()),
         };

@@ -251,7 +251,7 @@ impl FixedPosEnc {
 // TransformerDecoder (public)
 // ---------------------------------------------------------------------------
 pub struct TransformerDecoder {
-    token_emb: Tensor,      // (vocab, hidden)
+    token_emb: Tensor, // (vocab, hidden)
     pos_enc: FixedPosEnc,
     emb_norm_w: Tensor,
     emb_norm_b: Tensor,
@@ -302,9 +302,7 @@ impl TransformerDecoder {
         let head_w = weights
             .get("log_softmax.mlp.layer0.weight")?
             .shallow_clone();
-        let head_b = weights
-            .get("log_softmax.mlp.layer0.bias")?
-            .shallow_clone();
+        let head_b = weights.get("log_softmax.mlp.layer0.bias")?.shallow_clone();

         Ok(Self {
             token_emb,
@@ -356,10 +354,11 @@ impl TransformerDecoder {
     ) -> (Vec<f32>, Vec<(Option<Tensor>, Option<Tensor>)>) {
         let ids = Tensor::from_slice(&[token_id]);
         let emb = self.token_emb.index_select(0, &ids).unsqueeze(0); // (1, 1, hidden)
-        let pe = self.pos_enc.forward(&[position]).unsqueeze(0);     // (1, 1, hidden)
+        let pe = self.pos_enc.forward(&[position]).unsqueeze(0); // (1, 1, hidden)
         let x = layer_norm(&(emb + pe), &self.emb_norm_w, &self.emb_norm_b);

-        let mut new_kv: Vec<(Option<Tensor>, Option<Tensor>)> = Vec::with_capacity(self.layers.len());
+        let mut new_kv: Vec<(Option<Tensor>, Option<Tensor>)> =
+            Vec::with_capacity(self.layers.len());
         let mut hidden = x;

         for (i, layer) in self.layers.iter().enumerate() {
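
The `match` collapsed in the third hunk is the usual KV-cache append: on every decode step the new key/value projections (sequence length 1) are concatenated onto the cached ones along the time axis. A minimal sketch of that pattern with `tch` (the helper name is hypothetical):

    use tch::Tensor;

    // Hypothetical helper showing the cache-append pattern from DecoderLayer.
    // Shapes are (batch, heads, seq, head_dim); dim 2 is the sequence axis.
    fn extend_cache(cache: Option<&Tensor>, new: &Tensor) -> Tensor {
        match cache {
            Some(prev) => Tensor::cat(&[prev, new], 2), // append this step's K or V
            None => new.shallow_clone(),                // first step: cache = new
        }
    }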

src/encoder.rs

Lines changed: 56 additions & 30 deletions
@@ -68,7 +68,9 @@ struct ConvSubsampling {

 impl ConvSubsampling {
     fn load(weights: &Weights, prefix: &str) -> Result<Self> {
-        let w = |n: &str| -> Result<Tensor> { Ok(weights.get(&format!("{}{}", prefix, n))?.shallow_clone()) };
+        let w = |n: &str| -> Result<Tensor> {
+            Ok(weights.get(&format!("{}{}", prefix, n))?.shallow_clone())
+        };
         Ok(Self {
             c0_w: w("conv.0.weight")?,
             c0_b: w("conv.0.bias")?,
@@ -169,7 +171,9 @@ struct FeedForward {

 impl FeedForward {
     fn load(weights: &Weights, prefix: &str) -> Result<Self> {
-        let w = |n: &str| -> Result<Tensor> { Ok(weights.get(&format!("{}{}", prefix, n))?.shallow_clone()) };
+        let w = |n: &str| -> Result<Tensor> {
+            Ok(weights.get(&format!("{}{}", prefix, n))?.shallow_clone())
+        };
         Ok(Self {
             l1_w: w("linear1.weight")?,
             l1_b: w("linear1.bias")?,
@@ -204,7 +208,9 @@ struct ConformerConv {

 impl ConformerConv {
     fn load(weights: &Weights, prefix: &str, d_model: i64) -> Result<Self> {
-        let w = |n: &str| -> Result<Tensor> { Ok(weights.get(&format!("{}{}", prefix, n))?.shallow_clone()) };
+        let w = |n: &str| -> Result<Tensor> {
+            Ok(weights.get(&format!("{}{}", prefix, n))?.shallow_clone())
+        };
         Ok(Self {
             pw1_w: w("pointwise_conv1.weight")?,
             pw1_b: w("pointwise_conv1.bias")?,
@@ -237,8 +243,14 @@ impl ConformerConv {

         // Depthwise conv
         let pad = (kernel_size - 1) / 2;
-        let x =
-            x.conv1d(&self.dw_w, Some(&self.dw_b), &[1], &[pad], &[1], self.d_model);
+        let x = x.conv1d(
+            &self.dw_w,
+            Some(&self.dw_b),
+            &[1],
+            &[pad],
+            &[1],
+            self.d_model,
+        );

         // BatchNorm (eval mode)
         let x = batch_norm_eval(&x, &self.bn_w, &self.bn_b, &self.bn_rm, &self.bn_rv);
@@ -277,7 +289,9 @@ struct RelPosAttn {
 impl RelPosAttn {
     fn load(weights: &Weights, prefix: &str, n_heads: i64, d_model: i64) -> Result<Self> {
         let d_k = d_model / n_heads;
-        let w = |n: &str| -> Result<Tensor> { Ok(weights.get(&format!("{}{}", prefix, n))?.shallow_clone()) };
+        let w = |n: &str| -> Result<Tensor> {
+            Ok(weights.get(&format!("{}{}", prefix, n))?.shallow_clone())
+        };
         Ok(Self {
             q_w: w("linear_q.weight")?,
             q_b: w("linear_q.bias")?,
@@ -309,19 +323,22 @@ impl RelPosAttn {
     fn forward(&self, x: &Tensor, pos_emb: &Tensor) -> Tensor {
         let (b, t, _) = x.size3().unwrap();

-        let reshape = |z: &Tensor| -> Tensor {
-            z.view([b, t, self.n_heads, self.d_k]).transpose(1, 2)
-        };
+        let reshape =
+            |z: &Tensor| -> Tensor { z.view([b, t, self.n_heads, self.d_k]).transpose(1, 2) };

         let q = reshape(&linear(x, &self.q_w, &self.q_b));
         let k = reshape(&linear(x, &self.k_w, &self.k_b));
         let v = reshape(&linear(x, &self.v_w, &self.v_b));

         // pos_emb: (1, 2T-1, d_model) → (1, 2T-1, H, d_k) → (1, H, 2T-1, d_k)
         let n_pos = pos_emb.size()[1];
-        let p = linear(pos_emb, &self.pos_w, &Tensor::zeros(&[1], (Kind::Float, x.device())))
-            .view([1, n_pos, self.n_heads, self.d_k])
-            .transpose(1, 2);
+        let p = linear(
+            pos_emb,
+            &self.pos_w,
+            &Tensor::zeros(&[1], (Kind::Float, x.device())),
+        )
+        .view([1, n_pos, self.n_heads, self.d_k])
+        .transpose(1, 2);

         // pos_bias_u/v: (n_heads, d_k) → (1, n_heads, 1, d_k) for broadcasting
         let u = self.pos_bias_u.view([1, self.n_heads, 1, self.d_k]);
@@ -384,11 +401,7 @@ impl ConformerLayer {
                 d_model,
             )?,
             norm_conv: norm("norm_conv")?,
-            conv: ConformerConv::load(
-                weights,
-                &format!("{}conv.", prefix),
-                d_model,
-            )?,
+            conv: ConformerConv::load(weights, &format!("{}conv.", prefix), d_model)?,
             norm_ff2: norm("norm_feed_forward2")?,
             ff2: FeedForward::load(weights, &format!("{}feed_forward2.", prefix))?,
             norm_out: norm("norm_out")?,
@@ -397,13 +410,27 @@ impl ConformerLayer {

     fn forward(&self, x: &Tensor, pos_emb: &Tensor) -> Tensor {
         // FF1 (½-scaled)
-        let x = x + 0.5 * self.ff1.forward(&layer_norm(x, &self.norm_ff1.0, &self.norm_ff1.1));
+        let x = x + 0.5
+            * self
+                .ff1
+                .forward(&layer_norm(x, &self.norm_ff1.0, &self.norm_ff1.1));
         // Self-attention
-        let x = &x + self.self_attn.forward(&layer_norm(&x, &self.norm_self_att.0, &self.norm_self_att.1), pos_emb);
+        let x = &x
+            + self.self_attn.forward(
+                &layer_norm(&x, &self.norm_self_att.0, &self.norm_self_att.1),
+                pos_emb,
+            );
         // Conformer conv
-        let x = &x + self.conv.forward(&layer_norm(&x, &self.norm_conv.0, &self.norm_conv.1));
+        let x = &x
+            + self
+                .conv
+                .forward(&layer_norm(&x, &self.norm_conv.0, &self.norm_conv.1));
         // FF2 (½-scaled)
-        let x = &x + 0.5 * self.ff2.forward(&layer_norm(&x, &self.norm_ff2.0, &self.norm_ff2.1));
+        let x = &x
+            + 0.5
+                * self
+                    .ff2
+                    .forward(&layer_norm(&x, &self.norm_ff2.0, &self.norm_ff2.1));
         // Final norm
         layer_norm(&x, &self.norm_out.0, &self.norm_out.1)
     }
@@ -437,15 +464,14 @@ impl ConformerEncoder {
         }

         // Encoder→decoder projection (Linear 1280 → 1024)
-        let (enc_dec_proj_w, enc_dec_proj_b) =
-            if let (Ok(w), Ok(b)) = (
-                weights.get("encoder_decoder_proj.weight"),
-                weights.get("encoder_decoder_proj.bias"),
-            ) {
-                (Some(w.shallow_clone()), Some(b.shallow_clone()))
-            } else {
-                (None, None)
-            };
+        let (enc_dec_proj_w, enc_dec_proj_b) = if let (Ok(w), Ok(b)) = (
+            weights.get("encoder_decoder_proj.weight"),
+            weights.get("encoder_decoder_proj.bias"),
+        ) {
+            (Some(w.shallow_clone()), Some(b.shallow_clone()))
+        } else {
+            (None, None)
+        };

         Ok(Self {
             pre_encode,
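
The reformatted `forward` makes the block structure easy to read: it is the Macaron-style Conformer layer, with two half-scaled feed-forward residuals sandwiching attention and convolution. Stripped of tensors, the wiring is (sub-modules stubbed as identities; purely illustrative):

    // Identity stubs stand in for the real sub-modules; only the residual
    // wiring and the 0.5 scales are the point of this sketch.
    fn norm(x: f64) -> f64 { x }
    fn ff1(x: f64) -> f64 { x }
    fn attn(x: f64) -> f64 { x }
    fn conv(x: f64) -> f64 { x }
    fn ff2(x: f64) -> f64 { x }

    fn conformer_block(x: f64) -> f64 {
        let x = x + 0.5 * ff1(norm(x)); // FF1, half-scaled residual
        let x = x + attn(norm(x));      // relative-position self-attention residual
        let x = x + conv(norm(x));      // convolution-module residual
        let x = x + 0.5 * ff2(norm(x)); // FF2, half-scaled residual
        norm(x)                         // final layer norm
    }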

src/inference.rs

Lines changed: 2 additions & 4 deletions
@@ -39,8 +39,7 @@ pub fn transcribe(

     for (i, &token_id) in prompt.iter().enumerate() {
         let position = i as i64;
-        let (logits, new_kv) =
-            decoder.step(token_id, position, &self_kv_cache, &cross_kv);
+        let (logits, new_kv) = decoder.step(token_id, position, &self_kv_cache, &cross_kv);
         self_kv_cache = new_kv;
         last_logits = logits;
     }
@@ -59,8 +58,7 @@ pub fn transcribe(
         }
         generated.push(next_token);

-        let (logits, new_kv) =
-            decoder.step(next_token, position, &self_kv_cache, &cross_kv);
+        let (logits, new_kv) = decoder.step(next_token, position, &self_kv_cache, &cross_kv);
         self_kv_cache = new_kv;
         last_logits = logits;
         position += 1;
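
Both hunks reflow the same call inside `transcribe`'s two phases: priming the KV cache on the prompt, then greedy generation. The overall loop shape, with `step`, `eos_id`, and `max_len` as stand-ins rather than this crate's API:

    // Illustrative greedy-decoding skeleton matching the loop structure above.
    fn greedy_decode(
        mut step: impl FnMut(i64, i64) -> Vec<f32>, // (token, position) -> logits
        prompt: &[i64],
        eos_id: i64,
        max_len: usize,
    ) -> Vec<i64> {
        let mut last_logits = Vec::new();
        let mut position = 0i64;
        for &t in prompt {
            last_logits = step(t, position); // prime the KV cache on the prompt
            position += 1;
        }
        let mut generated = Vec::new();
        while generated.len() < max_len {
            // argmax over the vocabulary picks the next token
            let next = last_logits
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .map(|(i, _)| i as i64)
                .unwrap();
            if next == eos_id {
                break;
            }
            generated.push(next);
            last_logits = step(next, position);
            position += 1;
        }
        generated
    }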

src/lib.rs

Lines changed: 3 additions & 3 deletions
@@ -8,13 +8,13 @@ pub mod config;
 pub mod tokenizer;

 #[cfg(feature = "tch-backend")]
-pub mod weights;
+pub mod decoder;
 #[cfg(feature = "tch-backend")]
 pub mod encoder;
 #[cfg(feature = "tch-backend")]
-pub mod decoder;
-#[cfg(feature = "tch-backend")]
 pub mod inference;
+#[cfg(feature = "tch-backend")]
+pub mod weights;

 #[cfg(feature = "mlx")]
 pub mod mlx;
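
This hunk is rustfmt's `reorder_modules` behavior (on by default): module declarations within a group are sorted alphabetically, so `decoder` and `weights` swap positions while each `#[cfg]` attribute stays attached to its module. For reference, the gating pattern itself (hypothetical module names):

    // Hypothetical illustration of feature-gated modules: each module compiles
    // only when its Cargo feature is enabled, e.g.
    // `cargo build --features tch-backend`.
    #[cfg(feature = "tch-backend")]
    pub mod tch_backend {
        pub fn name() -> &'static str { "tch" }
    }

    #[cfg(feature = "mlx")]
    pub mod mlx_backend {
        pub fn name() -> &'static str { "mlx" }
    }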
