Skip to content

Commit e044ece

Browse files
juntao authored and claude committed
Fix MLX layer_norm calls: add missing eps parameter
ops::layer_norm takes 4 arguments (x, weight, bias, eps), but all 10 call sites in encoder.rs and decoder.rs were passing only 3. Add the standard eps = 1e-5 to all calls.

Signed-off-by: Michael Yuan <michael@secondstate.io>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent d5d6db9 commit e044ece

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

src/mlx/decoder.rs

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -186,7 +186,7 @@ impl DecoderLayer {
186186
) -> (Array, Array, Array) {
187187
// --- Self-attention ---
188188
let (nw1, nb1) = &self.norm1;
189-
let normed = ops::layer_norm(hidden, nw1, nb1);
189+
let normed = ops::layer_norm(hidden, nw1, nb1, 1e-5);
190190
let (k_new, v_new) = self.self_attn.project_kv(&normed);
191191

192192
let (k_full, v_full) = match (self_k_cache, self_v_cache) {
@@ -202,13 +202,13 @@ impl DecoderLayer {
202202

203203
// --- Cross-attention ---
204204
let (nw2, nb2) = &self.norm2;
205-
let normed2 = ops::layer_norm(&hidden, nw2, nb2);
205+
let normed2 = ops::layer_norm(&hidden, nw2, nb2, 1e-5);
206206
let cross_out = self.cross_attn.forward(&normed2, cross_k, cross_v, None);
207207
let hidden = ops::add(&hidden, &cross_out);
208208

209209
// --- FFN ---
210210
let (nw3, nb3) = &self.norm3;
211-
let normed3 = ops::layer_norm(&hidden, nw3, nb3);
211+
let normed3 = ops::layer_norm(&hidden, nw3, nb3, 1e-5);
212212
let ffn_out = self.ffn.forward(&normed3);
213213
let hidden = ops::add(&hidden, &ffn_out);
214214

@@ -341,7 +341,7 @@ impl TransformerDecoder {
341341
let pe = ops::reshape(&pe, &[1, 1, self.hidden]);
342342

343343
let x = ops::add(&emb, &pe);
344-
let x = ops::layer_norm(&x, &self.emb_norm_w, &self.emb_norm_b);
344+
let x = ops::layer_norm(&x, &self.emb_norm_w, &self.emb_norm_b, 1e-5);
345345

346346
let mut new_kv: Vec<(Option<Array>, Option<Array>)> = Vec::with_capacity(self.layers.len());
347347
let mut hidden = x;
@@ -362,7 +362,7 @@ impl TransformerDecoder {
362362
}
363363

364364
// Final layer norm + classification head
365-
let hidden = ops::layer_norm(&hidden, &self.final_ln_w, &self.final_ln_b);
365+
let hidden = ops::layer_norm(&hidden, &self.final_ln_w, &self.final_ln_b, 1e-5);
366366
// Squeeze T dim: (1, 1, hidden) → (1, hidden)
367367
let hidden = ops::squeeze(&hidden, &[1]);
368368
let logits = ops::linear(&hidden, &self.head_w, &self.head_b); // (1, vocab)

src/mlx/encoder.rs

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -513,23 +513,23 @@ impl ConformerLayer {
513513

514514
fn forward(&self, x: &Array, pos_emb: &Array) -> Array {
515515
let (nw1, nb1) = &self.norm_ff1;
516-
let ff1_out = self.ff1.forward(&ops::layer_norm(x, nw1, nb1));
516+
let ff1_out = self.ff1.forward(&ops::layer_norm(x, nw1, nb1, 1e-5));
517517
let x = ops::add(x, &ops::scale(&ff1_out, 0.5));
518518

519519
let (nw2, nb2) = &self.norm_self_att;
520-
let attn_out = self.self_attn.forward(&ops::layer_norm(&x, nw2, nb2), pos_emb);
520+
let attn_out = self.self_attn.forward(&ops::layer_norm(&x, nw2, nb2, 1e-5), pos_emb);
521521
let x = ops::add(&x, &attn_out);
522522

523523
let (nw3, nb3) = &self.norm_conv;
524-
let conv_out = self.conv.forward(&ops::layer_norm(&x, nw3, nb3));
524+
let conv_out = self.conv.forward(&ops::layer_norm(&x, nw3, nb3, 1e-5));
525525
let x = ops::add(&x, &conv_out);
526526

527527
let (nw4, nb4) = &self.norm_ff2;
528-
let ff2_out = self.ff2.forward(&ops::layer_norm(&x, nw4, nb4));
528+
let ff2_out = self.ff2.forward(&ops::layer_norm(&x, nw4, nb4, 1e-5));
529529
let x = ops::add(&x, &ops::scale(&ff2_out, 0.5));
530530

531531
let (nw5, nb5) = &self.norm_out;
532-
ops::layer_norm(&x, nw5, nb5)
532+
ops::layer_norm(&x, nw5, nb5, 1e-5)
533533
}
534534
}
535535

0 commit comments

Comments
 (0)