Skip to content

Commit b990378

Browse files
committed
reformat code
Signed-off-by: CaranLic <[email protected]>
1 parent a4db963 commit b990378

File tree

2 files changed: +18 −18 lines changed

tests/ut/quantization/test_w8a8.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -731,11 +731,11 @@ def test_softmax_scoring(self, mock_topk):
731731
0).contiguous())
732732

733733
weights, ids = select_experts(hidden_states=self.hidden_states,
734-
router_logits=self.router_logits,
735-
top_k=self.top_k,
736-
use_grouped_topk=False,
737-
renormalize=False,
738-
scoring_func="softmax")
734+
router_logits=self.router_logits,
735+
top_k=self.top_k,
736+
use_grouped_topk=False,
737+
renormalize=False,
738+
scoring_func="softmax")
739739

740740
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
741741
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
@@ -744,11 +744,11 @@ def test_sigmoid_scoring(self):
744744
"""Test sigmoid scoring function"""
745745

746746
weights, ids = select_experts(hidden_states=self.hidden_states,
747-
router_logits=self.router_logits,
748-
top_k=self.top_k,
749-
use_grouped_topk=False,
750-
renormalize=False,
751-
scoring_func="sigmoid")
747+
router_logits=self.router_logits,
748+
top_k=self.top_k,
749+
use_grouped_topk=False,
750+
renormalize=False,
751+
scoring_func="sigmoid")
752752

753753
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
754754
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
@@ -772,12 +772,12 @@ def test_grouped_topk(self, mock_topk):
772772
dtype=torch.long))
773773

774774
weights, ids = select_experts(hidden_states=self.hidden_states,
775-
router_logits=self.router_logits,
776-
top_k=self.top_k,
777-
use_grouped_topk=True,
778-
renormalize=False,
779-
topk_group=4,
780-
num_expert_group=2)
775+
router_logits=self.router_logits,
776+
top_k=self.top_k,
777+
use_grouped_topk=True,
778+
renormalize=False,
779+
topk_group=4,
780+
num_expert_group=2)
781781

782782
mock_topk.assert_called()
783783
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))

vllm_ascend/ops/moe/experts_selector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,8 @@ def _select_experts_with_fusion_ops(
176176
norm_type = 1
177177
if e_score_correction_bias is not None and \
178178
e_score_correction_bias.dtype != router_logits.dtype:
179-
e_score_correction_bias = e_score_correction_bias.to(router_logits.dtype)
179+
e_score_correction_bias = e_score_correction_bias.to(
180+
router_logits.dtype)
180181
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
181182
router_logits,
182183
k=top_k,
@@ -267,4 +268,3 @@ def _native_select_experts(
267268
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
268269

269270
return topk_weights, topk_ids
270-

Comments (0)