@@ -9,10 +9,10 @@
 from triton.experimental.gluon import language as ttgl
 from triton.experimental.gluon.language.nvidia import blackwell
 from triton.experimental.gluon.language.nvidia.blackwell import mbarrier, tma, TensorMemoryLayout
+from triton.experimental.gluon.nvidia.hopper import TensorDescriptor
 from triton._filecheck import filecheck_test, run_parser
 import triton.language as tl
 from triton._internal_testing import is_cuda
-from triton.tools.tensor_descriptor import TensorDescriptor
 from triton.compiler.errors import CompilationError
 
 TARGET_PAT = re.compile('ttg.target = "[^"]*"')
@@ -434,8 +434,8 @@ def test_tcgen05_mma(fresh_knobs):
 
 
 @gluon.jit
-def async_tma_kernel(input_desc, XBLOCK: ttgl.constexpr, smem_layout: ttgl.constexpr):
-    smem = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], smem_layout)
+def async_tma_kernel(input_desc, XBLOCK: ttgl.constexpr):
+    smem = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], input_desc.layout)
     bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
     mbarrier.init(bar, count=1)
 
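The kernel-side half of this change, restated as a minimal sketch: the shared-memory layout now travels on the tensor descriptor (input_desc.layout) instead of being threaded through as a separate constexpr. The sketch uses only calls visible in the hunk above; the TMA copy/wait body of the real test is elided here rather than guessed at.

    # Sketch only: the descriptor carries its layout, so the kernel
    # signature drops the smem_layout constexpr and reads input_desc.layout.
    @gluon.jit
    def sketch_kernel(input_desc, XBLOCK: ttgl.constexpr):
        smem = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], input_desc.layout)
        bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
        mbarrier.init(bar, count=1)
        # (TMA copy / barrier-wait calls elided in this capture)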
@@ -455,25 +455,25 @@ def test_async_tma(fresh_knobs):
 
     input = torch.randn((1024, 1024), device="cuda", dtype=torch.float16)
     XBLOCK = 128
-    input_desc = TensorDescriptor.from_tensor(input, [XBLOCK, XBLOCK])
     shared_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2)
+    input_desc = TensorDescriptor.from_tensor(input, [XBLOCK, XBLOCK], shared_layout)
 
-    h = async_tma_kernel.warmup(input_desc, XBLOCK, shared_layout, grid=(1, ), num_warps=4)
+    h = async_tma_kernel.warmup(input_desc, XBLOCK, grid=(1, ), num_warps=4)
     expecttest.assert_expected_inline(
         anonymize_ir(h.asm["source"]), """\
 #loc = loc(unknown)
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
 #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
-  tt.func public @async_tma_kernel(%arg0: !tt.tensordesc<tensor<128x128xf16>> loc(unknown), %arg1: i32 loc(unknown), %arg2: i32 loc(unknown), %arg3: i64 loc(unknown), %arg4: i64 loc(unknown)) attributes {noinline = false} {
+  tt.func public @async_tma_kernel(%arg0: !tt.tensordesc<tensor<128x128xf16, #shared>> loc(unknown), %arg1: i32 loc(unknown), %arg2: i32 loc(unknown), %arg3: i64 loc(unknown), %arg4: i64 loc(unknown)) attributes {noinline = false} {
     %0 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
     %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
     ttng.init_barrier %1, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
     %c0_i32 = arith.constant 0 : i32 loc(#loc)
     %c0_i32_0 = arith.constant 0 : i32 loc(#loc)
     %true = arith.constant true loc(#loc)
-    ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32_0] %0, %1, %true : !tt.tensordesc<tensor<128x128xf16>>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
+    ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32_0] %0, %1, %true : !tt.tensordesc<tensor<128x128xf16, #shared>>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
     %true_1 = arith.constant true loc(#loc)
     ttng.barrier_expect %1, 32768, %true_1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
     %c0_i32_2 = arith.constant 0 : i32 loc(#loc)
@@ -482,7 +482,7 @@ def test_async_tma(fresh_knobs):
     ttng.inval_barrier %1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
     %c0_i32_4 = arith.constant 0 : i32 loc(#loc)
     %c0_i32_5 = arith.constant 0 : i32 loc(#loc)
-    ttng.async_tma_copy_local_to_global %arg0[%c0_i32_4, %c0_i32_5] %0 : !tt.tensordesc<tensor<128x128xf16>>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
+    ttng.async_tma_copy_local_to_global %arg0[%c0_i32_4, %c0_i32_5] %0 : !tt.tensordesc<tensor<128x128xf16, #shared>>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
     ttng.async_tma_store_wait {pendings = 0 : i32} loc(#loc)
     tt.return loc(#loc)
   } loc(#loc)
@@ -491,8 +491,8 @@ def test_async_tma(fresh_knobs):
 
 
 @gluon.jit
-def async_tma_blackwell_kernel(input_desc, XBLOCK: ttgl.constexpr, smem_layout: ttgl.constexpr):
-    smem = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], smem_layout)
+def async_tma_blackwell_kernel(input_desc, XBLOCK: ttgl.constexpr):
+    smem = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], input_desc.layout)
     bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
     mbarrier.init(bar, count=1)
 
@@ -514,10 +514,10 @@ def test_async_tma_blackwell(fresh_knobs):
 
     input = torch.randn((1024, 1024), device="cuda", dtype=torch.float16)
     XBLOCK = 128
-    input_desc = TensorDescriptor.from_tensor(input, [1, XBLOCK])
     shared_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2)
+    input_desc = TensorDescriptor.from_tensor(input, [1, XBLOCK], shared_layout)
 
-    h = async_tma_blackwell_kernel.warmup(input_desc, XBLOCK, shared_layout, grid=(1, ), num_warps=4)
+    h = async_tma_blackwell_kernel.warmup(input_desc, XBLOCK, grid=(1, ), num_warps=4)
     expecttest.assert_expected_inline(
         anonymize_ir(h.asm["source"]), """\
 #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
@@ -526,22 +526,22 @@ def test_async_tma_blackwell(fresh_knobs):
 #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
-  tt.func public @async_tma_blackwell_kernel(%arg0: !tt.tensordesc<tensor<1x128xf16>> loc(unknown), %arg1: i32 loc(unknown), %arg2: i32 loc(unknown), %arg3: i64 loc(unknown), %arg4: i64 loc(unknown)) attributes {noinline = false} {
+  tt.func public @async_tma_blackwell_kernel(%arg0: !tt.tensordesc<tensor<1x128xf16, #shared>> loc(unknown), %arg1: i32 loc(unknown), %arg2: i32 loc(unknown), %arg3: i64 loc(unknown), %arg4: i64 loc(unknown)) attributes {noinline = false} {
     %0 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
     %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
     ttng.init_barrier %1, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
     %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
     %true = arith.constant true loc(#loc)
     %c0_i32 = arith.constant 0 : i32 loc(#loc)
-    ttng.async_tma_gather %arg0[%2, %c0_i32] %0, %1, %true : !tt.tensordesc<tensor<1x128xf16>>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<1xi64, #shared1, #smem, mutable>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, i1 loc(#loc)
+    ttng.async_tma_gather %arg0[%2, %c0_i32] %0, %1, %true : !tt.tensordesc<tensor<1x128xf16, #shared>>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<1xi64, #shared1, #smem, mutable>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, i1 loc(#loc)
     %true_0 = arith.constant true loc(#loc)
     ttng.barrier_expect %1, 32768, %true_0 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
     %c0_i32_1 = arith.constant 0 : i32 loc(#loc)
     %true_2 = arith.constant true loc(#loc)
     ttng.wait_barrier %1, %c0_i32_1, %true_2 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
     ttng.inval_barrier %1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
     %c0_i32_3 = arith.constant 0 : i32 loc(#loc)
-    ttng.async_tma_scatter %arg0[%2, %c0_i32_3] %0 : !tt.tensordesc<tensor<1x128xf16>>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
+    ttng.async_tma_scatter %arg0[%2, %c0_i32_3] %0 : !tt.tensordesc<tensor<1x128xf16, #shared>>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
     ttng.async_tma_store_wait {pendings = 0 : i32} loc(#loc)
     tt.return loc(#loc)
   } loc(#loc)
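And the host-side half, as a hedged sketch drawn only from the hunks above: TensorDescriptor is now imported from the Gluon namespace and takes the NVMMASharedLayout at construction time, so warmup no longer receives a separate layout argument. sketch_kernel refers to the illustrative kernel sketched earlier.

    import torch
    from triton.experimental.gluon import language as ttgl
    # New import path, per the first hunk (replaces triton.tools.tensor_descriptor):
    from triton.experimental.gluon.nvidia.hopper import TensorDescriptor

    x = torch.randn((1024, 1024), device="cuda", dtype=torch.float16)
    XBLOCK = 128
    # The shared-memory layout is attached to the descriptor up front...
    layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2)
    desc = TensorDescriptor.from_tensor(x, [XBLOCK, XBLOCK], layout)
    # ...so the launch drops the layout argument:
    h = sketch_kernel.warmup(desc, XBLOCK, grid=(1, ), num_warps=4)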