Skip to content

Commit fd7d9e6

Browse files
juntao and claude
committed
Fix MLX 1-token output, tracing to stderr, and shallow_clone perf
- Fix ConvSubsampling NHWC flatten order: transpose (T',F,C) → (T',C,F) before flattening so the linear projection receives features in the same order as PyTorch's NCHW layout. This was the root cause of the model generating only 1 token on the MLX backend.

- Direct tracing output to stderr in both CLI and server binaries so transcript text on stdout is not contaminated by log lines.

- Replace shallow_clone() CPU round-trip (to_vec_f32 + from_data_f32) with mlx_array_set() for O(1) ref-counted sharing, eliminating the ~75s encoder construction overhead.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 21a2bee commit fd7d9e6

File tree

4 files changed

+11
-12
lines changed

4 files changed

+11
-12
lines changed

src/bin/server.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,7 @@ async fn main() -> Result<()> {
535535
_ => "trace",
536536
};
537537
tracing_subscriber::fmt()
538+
.with_writer(std::io::stderr)
538539
.with_env_filter(
539540
tracing_subscriber::EnvFilter::try_from_default_env()
540541
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new(log_level)),

src/main.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ fn main() -> Result<()> {
6060
_ => "trace",
6161
};
6262
tracing_subscriber::fmt()
63+
.with_writer(std::io::stderr)
6364
.with_env_filter(
6465
tracing_subscriber::EnvFilter::try_from_default_env()
6566
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new(log_level)),

src/mlx/array.rs

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -165,17 +165,12 @@ impl Drop for Array {
165165
// C API. Provide an explicit method instead of implementing Clone to avoid
166166
// accidental copies.
167167
impl Array {
168-
/// Shallow copy — wraps the same storage.
169-
/// The caller is responsible for ensuring the original outlives the copy
170-
/// (or that eval has been called, materialising the data).
171-
///
172-
/// TODO: use mlx_array_retain if/when available in mlx-c.
168+
/// Shallow copy — uses `mlx_array_set` to share the underlying storage
169+
/// with reference counting. This is O(1) and avoids the expensive CPU
170+
/// round-trip that the previous `to_vec_f32` + `from_data_f32` approach used.
173171
pub fn shallow_clone(&self) -> Self {
174-
// Re-create from data to guarantee independent ownership.
175-
// This is the safe fallback — the eval round-trip is acceptable for
176-
// weight tensors that are only cloned once at load time.
177-
let data = self.to_vec_f32();
178-
let shape = self.shape();
179-
Self::from_data_f32(&data, &shape)
172+
let mut new = Self::empty();
173+
unsafe { ffi::mlx_array_set(&mut new.ptr, self.ptr) };
174+
new
180175
}
181176
}

src/mlx/encoder.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,9 @@ impl ConvSubsampling {
9797
let x = add_bias_nhwc(&x, &self.c6_b);
9898
let x = ops::relu(&x);
9999

100-
// x: (1, T', n_mels/8, 256) → (1, T', 256 * n_mels/8)
100+
// x: (1, T', n_mels/8, 256) in NHWC — transpose to match PyTorch's
101+
// NCHW flatten order: (1, T', 256, n_mels/8) → (1, T', 256*n_mels/8)
102+
let x = ops::transpose(&x, &[0, 1, 3, 2]);
101103
let t_prime = x.dim(1);
102104
let feat = x.dim(2) * x.dim(3);
103105
let x = ops::reshape(&x, &[1, t_prime, feat]);

0 commit comments

Comments (0)