@@ -185,96 +185,7 @@ class B2bGemm {
185185 SmemAccumulator
186186 >::B2bGemmKernel;
187187
188- // / Argument structure
189- struct Arguments {
190-
191- //
192- // Data members
193- //
194-
195- GemmUniversalMode mode;
196- GemmCoord problem_size_0;
197- GemmCoord problem_size_1;
198- TensorRef<ElementA const , LayoutA> ref_A0;
199- TensorRef<ElementB const , LayoutB> ref_B0;
200- TensorRef<ElementC const , LayoutC> ref_C0;
201- TensorRef<ElementScaleBias const , LayoutScaleBias> ref_Scale0;
202- TensorRef<ElementScaleBias const , LayoutScaleBias> ref_Bias0;
203- TensorRef<ElementB const , LayoutB> ref_B1;
204- TensorRef<ElementC const , LayoutC> ref_C1;
205- TensorRef<ElementC, LayoutC> ref_D1;
206- int64_t batch_stride_A0;
207- int64_t batch_stride_B0;
208- int64_t batch_stride_B1;
209- int64_t batch_stride_C1;
210- int64_t batch_stride_D1;
211- int64_t batch_stride_Bias0;
212- int64_t batch_stride_Scale0;
213- typename EpilogueOutputOp0::Params epilogue0;
214- typename EpilogueOutputOp1::Params epilogue1;
215- int batch_count;
216-
217- //
218- // Methods
219- //
220-
221- // / Default ctor
222- CUTLASS_HOST_DEVICE
223- Arguments (): mode(mode), problem_size_0(0 , 0 , 0 ), problem_size_1(0 , 0 , 0 ), batch_count(1 ) {
224-
225- }
226-
227- // / Constructs an Arguments structure
228- CUTLASS_HOST_DEVICE
229- Arguments (
230- GemmUniversalMode mode_,
231- GemmCoord problem_size_0_,
232- GemmCoord problem_size_1_,
233- TensorRef<ElementA const , LayoutA> ref_A0_,
234- TensorRef<ElementB const , LayoutB> ref_B0_,
235- TensorRef<ElementC const , LayoutC> ref_C0_,
236- TensorRef<ElementScaleBias const , LayoutScaleBias> ref_Scale0_,
237- TensorRef<ElementScaleBias const , LayoutScaleBias> ref_Bias0_,
238- TensorRef<ElementB const , LayoutB> ref_B1_,
239- TensorRef<ElementC const , LayoutC> ref_C1_,
240- TensorRef<ElementC, LayoutC> ref_D1_,
241- int64_t batch_stride_A0_,
242- int64_t batch_stride_B0_,
243- int64_t batch_stride_B1_,
244- int64_t batch_stride_C1_,
245- int64_t batch_stride_D1_,
246- int64_t batch_stride_Bias0_,
247- int64_t batch_stride_Scale0_,
248- typename EpilogueOutputOp0::Params epilogue0_ =
249- typename EpilogueOutputOp0::Params (),
250- typename EpilogueOutputOp1::Params epilogue1_ =
251- typename EpilogueOutputOp1::Params(),
252- int batch_count_ = 1
253- ):
254- mode(mode_),
255- problem_size_0(problem_size_0_),
256- problem_size_1(problem_size_1_),
257- ref_A0(ref_A0_),
258- ref_B0(ref_B0_),
259- ref_C0(ref_C0_),
260- ref_Scale0(ref_Scale0_),
261- ref_Bias0(ref_Bias0_),
262- ref_B1(ref_B1_),
263- ref_C1(ref_C1_),
264- ref_D1(ref_D1_),
265- batch_stride_A0(batch_stride_A0_),
266- batch_stride_B0(batch_stride_B0_),
267- batch_stride_B1(batch_stride_B1_),
268- batch_stride_C1(batch_stride_C1_),
269- batch_stride_D1(batch_stride_D1_),
270- batch_stride_Bias0(batch_stride_Bias0_),
271- batch_stride_Scale0(batch_stride_Scale0_),
272- epilogue0(epilogue0_),
273- epilogue1(epilogue1_),
274- batch_count(batch_count_) {
275-
276- }
277- };
188+ using Arguments = typename B2bGemmKernel::Arguments;
278189
279190private:
280191
0 commit comments