@@ -68,7 +68,9 @@ struct ConvSubsampling {
6868
6969impl ConvSubsampling {
7070 fn load ( weights : & Weights , prefix : & str ) -> Result < Self > {
71- let w = |n : & str | -> Result < Tensor > { Ok ( weights. get ( & format ! ( "{}{}" , prefix, n) ) ?. shallow_clone ( ) ) } ;
71+ let w = |n : & str | -> Result < Tensor > {
72+ Ok ( weights. get ( & format ! ( "{}{}" , prefix, n) ) ?. shallow_clone ( ) )
73+ } ;
7274 Ok ( Self {
7375 c0_w : w ( "conv.0.weight" ) ?,
7476 c0_b : w ( "conv.0.bias" ) ?,
@@ -169,7 +171,9 @@ struct FeedForward {
169171
170172impl FeedForward {
171173 fn load ( weights : & Weights , prefix : & str ) -> Result < Self > {
172- let w = |n : & str | -> Result < Tensor > { Ok ( weights. get ( & format ! ( "{}{}" , prefix, n) ) ?. shallow_clone ( ) ) } ;
174+ let w = |n : & str | -> Result < Tensor > {
175+ Ok ( weights. get ( & format ! ( "{}{}" , prefix, n) ) ?. shallow_clone ( ) )
176+ } ;
173177 Ok ( Self {
174178 l1_w : w ( "linear1.weight" ) ?,
175179 l1_b : w ( "linear1.bias" ) ?,
@@ -204,7 +208,9 @@ struct ConformerConv {
204208
205209impl ConformerConv {
206210 fn load ( weights : & Weights , prefix : & str , d_model : i64 ) -> Result < Self > {
207- let w = |n : & str | -> Result < Tensor > { Ok ( weights. get ( & format ! ( "{}{}" , prefix, n) ) ?. shallow_clone ( ) ) } ;
211+ let w = |n : & str | -> Result < Tensor > {
212+ Ok ( weights. get ( & format ! ( "{}{}" , prefix, n) ) ?. shallow_clone ( ) )
213+ } ;
208214 Ok ( Self {
209215 pw1_w : w ( "pointwise_conv1.weight" ) ?,
210216 pw1_b : w ( "pointwise_conv1.bias" ) ?,
@@ -237,8 +243,14 @@ impl ConformerConv {
237243
238244 // Depthwise conv
239245 let pad = ( kernel_size - 1 ) / 2 ;
240- let x =
241- x. conv1d ( & self . dw_w , Some ( & self . dw_b ) , & [ 1 ] , & [ pad] , & [ 1 ] , self . d_model ) ;
246+ let x = x. conv1d (
247+ & self . dw_w ,
248+ Some ( & self . dw_b ) ,
249+ & [ 1 ] ,
250+ & [ pad] ,
251+ & [ 1 ] ,
252+ self . d_model ,
253+ ) ;
242254
243255 // BatchNorm (eval mode)
244256 let x = batch_norm_eval ( & x, & self . bn_w , & self . bn_b , & self . bn_rm , & self . bn_rv ) ;
@@ -277,7 +289,9 @@ struct RelPosAttn {
277289impl RelPosAttn {
278290 fn load ( weights : & Weights , prefix : & str , n_heads : i64 , d_model : i64 ) -> Result < Self > {
279291 let d_k = d_model / n_heads;
280- let w = |n : & str | -> Result < Tensor > { Ok ( weights. get ( & format ! ( "{}{}" , prefix, n) ) ?. shallow_clone ( ) ) } ;
292+ let w = |n : & str | -> Result < Tensor > {
293+ Ok ( weights. get ( & format ! ( "{}{}" , prefix, n) ) ?. shallow_clone ( ) )
294+ } ;
281295 Ok ( Self {
282296 q_w : w ( "linear_q.weight" ) ?,
283297 q_b : w ( "linear_q.bias" ) ?,
@@ -309,19 +323,22 @@ impl RelPosAttn {
309323 fn forward ( & self , x : & Tensor , pos_emb : & Tensor ) -> Tensor {
310324 let ( b, t, _) = x. size3 ( ) . unwrap ( ) ;
311325
312- let reshape = |z : & Tensor | -> Tensor {
313- z. view ( [ b, t, self . n_heads , self . d_k ] ) . transpose ( 1 , 2 )
314- } ;
326+ let reshape =
327+ |z : & Tensor | -> Tensor { z. view ( [ b, t, self . n_heads , self . d_k ] ) . transpose ( 1 , 2 ) } ;
315328
316329 let q = reshape ( & linear ( x, & self . q_w , & self . q_b ) ) ;
317330 let k = reshape ( & linear ( x, & self . k_w , & self . k_b ) ) ;
318331 let v = reshape ( & linear ( x, & self . v_w , & self . v_b ) ) ;
319332
320333 // pos_emb: (1, 2T-1, d_model) → (1, 2T-1, H, d_k) → (1, H, 2T-1, d_k)
321334 let n_pos = pos_emb. size ( ) [ 1 ] ;
322- let p = linear ( pos_emb, & self . pos_w , & Tensor :: zeros ( & [ 1 ] , ( Kind :: Float , x. device ( ) ) ) )
323- . view ( [ 1 , n_pos, self . n_heads , self . d_k ] )
324- . transpose ( 1 , 2 ) ;
335+ let p = linear (
336+ pos_emb,
337+ & self . pos_w ,
338+ & Tensor :: zeros ( & [ 1 ] , ( Kind :: Float , x. device ( ) ) ) ,
339+ )
340+ . view ( [ 1 , n_pos, self . n_heads , self . d_k ] )
341+ . transpose ( 1 , 2 ) ;
325342
326343 // pos_bias_u/v: (n_heads, d_k) → (1, n_heads, 1, d_k) for broadcasting
327344 let u = self . pos_bias_u . view ( [ 1 , self . n_heads , 1 , self . d_k ] ) ;
@@ -384,11 +401,7 @@ impl ConformerLayer {
384401 d_model,
385402 ) ?,
386403 norm_conv : norm ( "norm_conv" ) ?,
387- conv : ConformerConv :: load (
388- weights,
389- & format ! ( "{}conv." , prefix) ,
390- d_model,
391- ) ?,
404+ conv : ConformerConv :: load ( weights, & format ! ( "{}conv." , prefix) , d_model) ?,
392405 norm_ff2 : norm ( "norm_feed_forward2" ) ?,
393406 ff2 : FeedForward :: load ( weights, & format ! ( "{}feed_forward2." , prefix) ) ?,
394407 norm_out : norm ( "norm_out" ) ?,
@@ -397,13 +410,27 @@ impl ConformerLayer {
397410
398411 fn forward ( & self , x : & Tensor , pos_emb : & Tensor ) -> Tensor {
399412 // FF1 (½-scaled)
400- let x = x + 0.5 * self . ff1 . forward ( & layer_norm ( x, & self . norm_ff1 . 0 , & self . norm_ff1 . 1 ) ) ;
413+ let x = x + 0.5
414+ * self
415+ . ff1
416+ . forward ( & layer_norm ( x, & self . norm_ff1 . 0 , & self . norm_ff1 . 1 ) ) ;
401417 // Self-attention
402- let x = & x + self . self_attn . forward ( & layer_norm ( & x, & self . norm_self_att . 0 , & self . norm_self_att . 1 ) , pos_emb) ;
418+ let x = & x
419+ + self . self_attn . forward (
420+ & layer_norm ( & x, & self . norm_self_att . 0 , & self . norm_self_att . 1 ) ,
421+ pos_emb,
422+ ) ;
403423 // Conformer conv
404- let x = & x + self . conv . forward ( & layer_norm ( & x, & self . norm_conv . 0 , & self . norm_conv . 1 ) ) ;
424+ let x = & x
425+ + self
426+ . conv
427+ . forward ( & layer_norm ( & x, & self . norm_conv . 0 , & self . norm_conv . 1 ) ) ;
405428 // FF2 (½-scaled)
406- let x = & x + 0.5 * self . ff2 . forward ( & layer_norm ( & x, & self . norm_ff2 . 0 , & self . norm_ff2 . 1 ) ) ;
429+ let x = & x
430+ + 0.5
431+ * self
432+ . ff2
433+ . forward ( & layer_norm ( & x, & self . norm_ff2 . 0 , & self . norm_ff2 . 1 ) ) ;
407434 // Final norm
408435 layer_norm ( & x, & self . norm_out . 0 , & self . norm_out . 1 )
409436 }
@@ -437,15 +464,14 @@ impl ConformerEncoder {
437464 }
438465
439466 // Encoder→decoder projection (Linear 1280 → 1024)
440- let ( enc_dec_proj_w, enc_dec_proj_b) =
441- if let ( Ok ( w) , Ok ( b) ) = (
442- weights. get ( "encoder_decoder_proj.weight" ) ,
443- weights. get ( "encoder_decoder_proj.bias" ) ,
444- ) {
445- ( Some ( w. shallow_clone ( ) ) , Some ( b. shallow_clone ( ) ) )
446- } else {
447- ( None , None )
448- } ;
467+ let ( enc_dec_proj_w, enc_dec_proj_b) = if let ( Ok ( w) , Ok ( b) ) = (
468+ weights. get ( "encoder_decoder_proj.weight" ) ,
469+ weights. get ( "encoder_decoder_proj.bias" ) ,
470+ ) {
471+ ( Some ( w. shallow_clone ( ) ) , Some ( b. shallow_clone ( ) ) )
472+ } else {
473+ ( None , None )
474+ } ;
449475
450476 Ok ( Self {
451477 pre_encode,
0 commit comments