@@ -1915,13 +1915,44 @@ def kernel(a_ptr, b_ptr, c_ptr, out_ptr):
19151915 ttgl .store (out_ptr + offs , out )
19161916
19171917 a = torch .rand ((B , B ), dtype = torch .float32 , device = "cuda" )
1918- b = torch .ones ((B , B ), dtype = torch .float32 , device = "cuda" )
1918+ b = torch .rand ((B , B ), dtype = torch .float32 , device = "cuda" )
19191919 c = torch .rand ((B , B ), dtype = torch .float32 , device = "cuda" )
19201920 out = torch .empty ((B , B ), dtype = torch .float32 , device = "cuda" )
19211921 kernel [(1 , )](a , b , c , out )
19221922 torch .testing .assert_close (out , torch .addmm (c , a , b ), atol = 1e-2 , rtol = 1e-2 )
19231923
19241924
def test_dot3d_fma():
    """Batched (rank-3) FMA dot product: verifies out == a @ b + c for a
    stack of BATCH independent B x B tiles, all handled by a single program."""
    torch.manual_seed(42)
    B = ttgl.constexpr(32)
    BATCH = ttgl.constexpr(8)
    threads_per_warp = ttgl.constexpr(THREADS_PER_WARP)

    @gluon.jit
    def kernel(a_ptr, b_ptr, c_ptr, out_ptr):
        # Rank-3 blocked layout over (BATCH, B, B), with dim order [2, 1, 0].
        layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1, 1], [1, threads_per_warp, 1],
                                                    [ttgl.num_warps(), 1, 1], [2, 1, 0])
        lhs_layout: ttgl.constexpr = ttgl.DotOperandLayout(parent=layout, operand_index=0, k_width=0)
        rhs_layout: ttgl.constexpr = ttgl.DotOperandLayout(parent=layout, operand_index=1, k_width=0)

        # Build flattened element offsets into the contiguous (BATCH, B, B) buffers.
        batch_idx = ttgl.arange(0, BATCH, layout=ttgl.SliceLayout(1, ttgl.SliceLayout(2, layout)))[:, None, None]
        row_idx = ttgl.arange(0, B, layout=ttgl.SliceLayout(0, ttgl.SliceLayout(2, layout)))[None, :, None]
        col_idx = ttgl.arange(0, B, layout=ttgl.SliceLayout(0, ttgl.SliceLayout(1, layout)))[None, None, :]
        offs = batch_idx * B * B + row_idx * B + col_idx

        # Operands must carry dot-operand layouts before the FMA dot.
        a = ttgl.convert_layout(ttgl.load(a_ptr + offs), lhs_layout)
        b = ttgl.convert_layout(ttgl.load(b_ptr + offs), rhs_layout)
        c = ttgl.load(c_ptr + offs)
        ttgl.store(out_ptr + offs, ttgl.dot_fma(a, b, c))

    a = torch.rand((BATCH, B, B), dtype=torch.float32, device="cuda")
    b = torch.rand((BATCH, B, B), dtype=torch.float32, device="cuda")
    c = torch.rand((BATCH, B, B), dtype=torch.float32, device="cuda")
    out = torch.empty((BATCH, B, B), dtype=torch.float32, device="cuda")
    kernel[(1, )](a, b, c, out)
    torch.testing.assert_close(out, torch.matmul(a, b) + c, atol=1e-2, rtol=1e-2)
1954+
1955+
19251956@gluon .jit
19261957def kernel_auto_layout_constant (threads_per_warp : ttgl .constexpr ):
19271958 BLOCK : ttgl .constexpr = 16
0 commit comments