Skip to content

Commit 87349d3

Browse files
authored
Add grouped b2b GEMM (PaddlePaddle#970)
1 parent (fde824a) · commit 87349d3

15 files changed: +1644 additions, −107 deletions

examples/13_two_tensor_op_fusion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ endforeach()
 foreach(FUSION_GEMM_EXAMPLE
   fused_two_gemms_f16_sm75_rf
   fused_two_gemms_f16_sm75_shmem
+  fused_two_gemms_grouped_f16_sm80_rf
   fused_two_gemms_f16_sm80_rf
   fused_two_gemms_f16_sm80_shmem
   fused_two_gemms_s8_sm75_rf

examples/13_two_tensor_op_fusion/b2b_grouped_gemm_run.h

Lines changed: 450 additions & 0 deletions
Large diffs are not rendered by default.

examples/13_two_tensor_op_fusion/device/b2b_gemm.h

Lines changed: 1 addition & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -185,96 +185,7 @@ class B2bGemm {
     SmemAccumulator
   >::B2bGemmKernel;

-  /// Argument structure
-  struct Arguments {
-
-    //
-    // Data members
-    //
-
-    GemmUniversalMode mode;
-    GemmCoord problem_size_0;
-    GemmCoord problem_size_1;
-    TensorRef<ElementA const, LayoutA> ref_A0;
-    TensorRef<ElementB const, LayoutB> ref_B0;
-    TensorRef<ElementC const, LayoutC> ref_C0;
-    TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Scale0;
-    TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Bias0;
-    TensorRef<ElementB const, LayoutB> ref_B1;
-    TensorRef<ElementC const, LayoutC> ref_C1;
-    TensorRef<ElementC, LayoutC> ref_D1;
-    int64_t batch_stride_A0;
-    int64_t batch_stride_B0;
-    int64_t batch_stride_B1;
-    int64_t batch_stride_C1;
-    int64_t batch_stride_D1;
-    int64_t batch_stride_Bias0;
-    int64_t batch_stride_Scale0;
-    typename EpilogueOutputOp0::Params epilogue0;
-    typename EpilogueOutputOp1::Params epilogue1;
-    int batch_count;
-
-    //
-    // Methods
-    //
-
-    /// Default ctor
-    CUTLASS_HOST_DEVICE
-    Arguments(): mode(mode), problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), batch_count(1) {
-
-    }
-
-    /// Constructs an Arguments structure
-    CUTLASS_HOST_DEVICE
-    Arguments(
-      GemmUniversalMode mode_,
-      GemmCoord problem_size_0_,
-      GemmCoord problem_size_1_,
-      TensorRef<ElementA const, LayoutA> ref_A0_,
-      TensorRef<ElementB const, LayoutB> ref_B0_,
-      TensorRef<ElementC const, LayoutC> ref_C0_,
-      TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Scale0_,
-      TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Bias0_,
-      TensorRef<ElementB const, LayoutB> ref_B1_,
-      TensorRef<ElementC const, LayoutC> ref_C1_,
-      TensorRef<ElementC, LayoutC> ref_D1_,
-      int64_t batch_stride_A0_,
-      int64_t batch_stride_B0_,
-      int64_t batch_stride_B1_,
-      int64_t batch_stride_C1_,
-      int64_t batch_stride_D1_,
-      int64_t batch_stride_Bias0_,
-      int64_t batch_stride_Scale0_,
-      typename EpilogueOutputOp0::Params epilogue0_ =
-        typename EpilogueOutputOp0::Params(),
-      typename EpilogueOutputOp1::Params epilogue1_ =
-        typename EpilogueOutputOp1::Params(),
-      int batch_count_ = 1
-    ):
-      mode(mode_),
-      problem_size_0(problem_size_0_),
-      problem_size_1(problem_size_1_),
-      ref_A0(ref_A0_),
-      ref_B0(ref_B0_),
-      ref_C0(ref_C0_),
-      ref_Scale0(ref_Scale0_),
-      ref_Bias0(ref_Bias0_),
-      ref_B1(ref_B1_),
-      ref_C1(ref_C1_),
-      ref_D1(ref_D1_),
-      batch_stride_A0(batch_stride_A0_),
-      batch_stride_B0(batch_stride_B0_),
-      batch_stride_B1(batch_stride_B1_),
-      batch_stride_C1(batch_stride_C1_),
-      batch_stride_D1(batch_stride_D1_),
-      batch_stride_Bias0(batch_stride_Bias0_),
-      batch_stride_Scale0(batch_stride_Scale0_),
-      epilogue0(epilogue0_),
-      epilogue1(epilogue1_),
-      batch_count(batch_count_) {
-
-    }
-  };
+  using Arguments = typename B2bGemmKernel::Arguments;

 private:


0 commit comments

Comments (0)