@@ -731,11 +731,11 @@ def test_softmax_scoring(self, mock_topk):
731
731
0 ).contiguous ())
732
732
733
733
weights , ids = select_experts (hidden_states = self .hidden_states ,
734
- router_logits = self .router_logits ,
735
- top_k = self .top_k ,
736
- use_grouped_topk = False ,
737
- renormalize = False ,
738
- scoring_func = "softmax" )
734
+ router_logits = self .router_logits ,
735
+ top_k = self .top_k ,
736
+ use_grouped_topk = False ,
737
+ renormalize = False ,
738
+ scoring_func = "softmax" )
739
739
740
740
self .assertEqual (weights .shape , (self .num_tokens , self .top_k ))
741
741
self .assertEqual (ids .shape , (self .num_tokens , self .top_k ))
@@ -744,11 +744,11 @@ def test_sigmoid_scoring(self):
744
744
"""Test sigmoid scoring function"""
745
745
746
746
weights , ids = select_experts (hidden_states = self .hidden_states ,
747
- router_logits = self .router_logits ,
748
- top_k = self .top_k ,
749
- use_grouped_topk = False ,
750
- renormalize = False ,
751
- scoring_func = "sigmoid" )
747
+ router_logits = self .router_logits ,
748
+ top_k = self .top_k ,
749
+ use_grouped_topk = False ,
750
+ renormalize = False ,
751
+ scoring_func = "sigmoid" )
752
752
753
753
self .assertEqual (weights .shape , (self .num_tokens , self .top_k ))
754
754
self .assertEqual (ids .shape , (self .num_tokens , self .top_k ))
@@ -772,12 +772,12 @@ def test_grouped_topk(self, mock_topk):
772
772
dtype = torch .long ))
773
773
774
774
weights , ids = select_experts (hidden_states = self .hidden_states ,
775
- router_logits = self .router_logits ,
776
- top_k = self .top_k ,
777
- use_grouped_topk = True ,
778
- renormalize = False ,
779
- topk_group = 4 ,
780
- num_expert_group = 2 )
775
+ router_logits = self .router_logits ,
776
+ top_k = self .top_k ,
777
+ use_grouped_topk = True ,
778
+ renormalize = False ,
779
+ topk_group = 4 ,
780
+ num_expert_group = 2 )
781
781
782
782
mock_topk .assert_called ()
783
783
self .assertEqual (weights .shape , (self .num_tokens , self .top_k ))
0 commit comments