@@ -31,72 +31,12 @@ Module Contents
 
 .. py:function:: gemm_v1(A, B, C, transpose_A=False, transpose_B=False, policy=GemmWarpPolicy.Square, clear_accum=False, k_pack=1, wg_wait=0, mbar=None)
 
-   Perform a General Matrix Multiplication (GEMM) operation.
-
-   This function computes C = A @ B where A and B can optionally be transposed.
-   The operation supports various warp policies and accumulation modes.
-
-   :param A: First input matrix
-   :type A: Union[tir.Buffer, tir.Var]
-   :param B: Second input matrix
-   :type B: Union[tir.Buffer, tir.Var]
-   :param C: Output matrix for results
-   :type C: Union[tir.Buffer, tir.Var]
-   :param transpose_A: Whether to transpose matrix A. Defaults to False.
-   :type transpose_A: bool, optional
-   :param transpose_B: Whether to transpose matrix B. Defaults to False.
-   :type transpose_B: bool, optional
-   :param policy: Warp execution policy. Defaults to GemmWarpPolicy.Square.
-   :type policy: GemmWarpPolicy, optional
-   :param clear_accum: Whether to clear accumulator before computation. Defaults to False.
-   :type clear_accum: bool, optional
-   :param k_pack: Number of k dimensions packed into a single warp. Defaults to 1.
-   :type k_pack: int, optional
-   :param wg_wait: Warp group wait count. Defaults to 0.
-                   On hopper it is equivalent to `wgmma.wait_group.sync.aligned <wg_wait>` if wg_wait is not -1
-                   On sm100, `wg_wait` can only be 0 or -1. `mbarrier_wait(TCGEN5MMA barrier)` will be appended if wg_wait is 0.
-   :type wg_wait: int, optional
-   :param mbar: mbarrier for TCGEN5MMA synchronization
-   :type mbar: tir.Buffer, optional
-
-   :returns: A handle to the GEMM operation
-   :rtype: tir.Call
-
-   :raises AssertionError: If the K dimensions of matrices A and B don't match
+   GEMM v1: use op tl.gemm.
 
 
 .. py:function:: gemm_v2(A, B, C, transpose_A=False, transpose_B=False, policy=GemmWarpPolicy.Square, clear_accum=False, k_pack=1, wg_wait=0, mbar=None)
 
-   Perform a General Matrix Multiplication (GEMM) operation.
-
-   This function computes C = A @ B where A and B can optionally be transposed.
-   The operation supports various warp policies and accumulation modes.
-
-   :param A: First input matrix
-   :type A: Union[tir.Buffer, tir.Var]
-   :param B: Second input matrix
-   :type B: Union[tir.Buffer, tir.Var]
-   :param C: Output matrix for results
-   :type C: Union[tir.Buffer, tir.Var]
-   :param transpose_A: Whether to transpose matrix A. Defaults to False.
-   :type transpose_A: bool, optional
-   :param transpose_B: Whether to transpose matrix B. Defaults to False.
-   :type transpose_B: bool, optional
-   :param policy: Warp execution policy. Defaults to GemmWarpPolicy.Square.
-   :type policy: GemmWarpPolicy, optional
-   :param clear_accum: Whether to clear accumulator before computation. Defaults to False.
-   :type clear_accum: bool, optional
-   :param k_pack: Number of k dimensions packed into a single warp. Defaults to 1.
-   :type k_pack: int, optional
-   :param wg_wait: Warp group wait count. Defaults to 0.
-   :type wg_wait: int, optional
-   :param mbar: mbarrier for TCGEN5MMA synchronization
-   :type mbar: tir.Buffer, optional
-
-   :returns: A handle to the GEMM operation
-   :rtype: tir.Call
-
-   :raises AssertionError: If the K dimensions of matrices A and B don't match
+   GEMM v2: use op tl.gemm_py.
 
 
 .. py:data:: gemm
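For readers who no longer see the longer docstrings, the semantics they described (C = A @ B with optional transposes, an optional accumulator clear, and an assertion on mismatched K dimensions) can be modeled in plain Python. This is only a rough sketch, not the library's implementation: `gemm_reference` is a hypothetical name, matrices are nested lists rather than `tir.Buffer` objects, the hardware parameters (`policy`, `k_pack`, `wg_wait`, `mbar`) are ignored, and accumulating into `C` when `clear_accum` is False is an assumption read from the parameter description.

```python
def gemm_reference(A, B, C, transpose_A=False, transpose_B=False, clear_accum=False):
    """Hypothetical reference model: C += op(A) @ op(B), updated in place.

    Assumption: when clear_accum is False, results accumulate into C;
    when True, C is zeroed first.
    """
    def transpose(M):
        return [list(row) for row in zip(*M)]

    a = transpose(A) if transpose_A else A
    b = transpose(B) if transpose_B else B

    m, k = len(a), len(a[0])
    k_b, n = len(b), len(b[0])
    # Mirrors the documented AssertionError on mismatched K dimensions.
    assert k == k_b, "K dimensions of A and B don't match"

    if clear_accum:
        for row in C:
            for j in range(n):
                row[j] = 0

    for i in range(m):
        for j in range(n):
            C[i][j] += sum(a[i][p] * b[p][j] for p in range(k))
    return C


A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]
plain = gemm_reference(A, B, [[0, 0], [0, 0]])
with_bt = gemm_reference(A, B, [[0, 0], [0, 0]], transpose_B=True)
accumulated = gemm_reference(A, B, [[1, 1], [1, 1]])
cleared = gemm_reference(A, B, [[1, 1], [1, 1]], clear_accum=True)
```

Here `plain` and `cleared` hold the same product, while `accumulated` differs from it by the initial contents of `C`.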