4 files changed: +444 −47 lines

@@ -55,15 +55,9 @@ torch::Tensor fused_moe(
     int topk_group,
     double route_scale,
     int start_expert_id,
-    int block_n,
     bool avg_moe,
-    const std::optional<torch::Tensor>& class_reduce_weight,
-    const std::optional<torch::Tensor>& class_expert_id,
     const std::optional<torch::List<int64_t>>& w1_quant_flag,
-    const std::optional<torch::List<int64_t>>& w2_quant_flag,
-    int world_size,
-    int shared_expert_num,
-    const std::string& parallel_mode) {
+    const std::optional<torch::List<int64_t>>& w2_quant_flag) {
   auto dtype = hidden_states.dtype();
   auto ori_input_shape = hidden_states.sizes();
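Net effect of the hunk above: fused_moe drops block_n, class_reduce_weight, class_expert_id, world_size, shared_expert_num, and parallel_mode, leaving w2_quant_flag as the final parameter. For reference, a sketch of the resulting signature tail — the parameters before topk_group sit outside the diff context, so this is a partial reconstruction, not the full declaration:

// Sketch of the post-patch signature tail; leading parameters are
// elided because they are not visible in this hunk.
torch::Tensor fused_moe(
    /* ...leading parameters not shown in the diff... */
    int topk_group,
    double route_scale,
    int start_expert_id,
    bool avg_moe,
    const std::optional<torch::List<int64_t>>& w1_quant_flag,
    const std::optional<torch::List<int64_t>>& w2_quant_flag);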

@@ -160,15 +160,9 @@ torch::Tensor fused_moe(
     int topk_group,
     double route_scale,
     int start_expert_id,
-    int block_n,
     bool avg_moe,
-    const std::optional<torch::Tensor>& class_reduce_weight,
-    const std::optional<torch::Tensor>& class_expert_id,
     const std::optional<torch::List<int64_t>>& w1_quant_flag,
-    const std::optional<torch::List<int64_t>>& w2_quant_flag,
-    int world_size,
-    int shared_expert_num,
-    const std::string& parallel_mode);
+    const std::optional<torch::List<int64_t>>& w2_quant_flag);

 std::tuple<torch::Tensor, torch::Tensor> scaled_quantize(
     const torch::Tensor& x,
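Callers that previously passed the removed arguments now stop at the two quant-flag lists. A hypothetical call under the trimmed declaration — the leading arguments and all values shown are illustrative assumptions, not taken from this diff:

// Hypothetical usage sketch; std::nullopt (from <optional>) leaves
// both optional quantization-flag lists unset.
auto out = fused_moe(
    /* ...leading arguments elided, as in the hunks above... */
    /*topk_group=*/1,
    /*route_scale=*/1.0,
    /*start_expert_id=*/0,
    /*avg_moe=*/false,
    /*w1_quant_flag=*/std::nullopt,
    /*w2_quant_flag=*/std::nullopt);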

@@ -231,15 +231,9 @@ torch::Tensor fused_moe(FusedMoEParams& params) {
       params.topk_group,
       params.route_scale,
       params.start_expert_id,
-      params.block_n,
       params.avg_moe,
-      params.class_reduce_weight,
-      params.class_expert_id,
       params.w1_quant_flag,
-      params.w2_quant_flag,
-      params.world_size,
-      params.shared_expert_num,
-      params.parallel_mode);
+      params.w2_quant_flag);
 #elif defined(USE_CUDA)
   LOG(FATAL) << "fused_moe for cuda not implemented";
 #else
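Since this call site no longer reads them, the matching fields of FusedMoEParams are presumably dead after this change. A speculative sketch of the trimmed struct — its real definition is not part of this diff, so every member below is inferred from the call site, not confirmed:

// Speculative: FusedMoEParams with only the members still referenced
// at this call site (the actual struct is defined elsewhere and may
// retain more fields).
struct FusedMoEParams {
  // ...tensor and routing members not visible in this diff...
  int topk_group;
  double route_scale;
  int start_expert_id;
  bool avg_moe;
  std::optional<torch::List<int64_t>> w1_quant_flag;
  std::optional<torch::List<int64_t>> w2_quant_flag;
  // No longer read here: block_n, class_reduce_weight, class_expert_id,
  // world_size, shared_expert_num, parallel_mode.
};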