@@ -842,3 +842,124 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {
     tt.return
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
+#smem = #ttg.shared_memory
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, colStride = 1>
+#linear = #ttg.linear<{register = [[0, 1], [0, 2], [32, 0], [64, 0], [0, 4]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[0, 0], [0, 0]], block = []}>
+#tmem_scales = #ttng.tensor_memory_scales_encoding<>
+module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {
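+  // The outer-loop accumulator store below carries ttg.partition = 2, the same partition
+  // as the tc_gen5_mma; the CHECK lines expect the accumulator alloc to be double-buffered (2x128x128).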
+  // CHECK-LABEL: @nested_loop_yes_double_buffer
+  tt.func @nested_loop_yes_double_buffer(%lb: i32, %ub: i32, %step: i32, %ptr0: !tt.ptr<i32>) {
+    %true = arith.constant true
+    %false = arith.constant false
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
+    // CHECK: [[BUF:%.*]] = ttng.tmem_alloc : () -> !ttg.memdesc<2x128x128xf32, #tmem,
+    // CHECK-NEXT: [[AREF:%.*]] = nvws.aref.create [[BUF]]
+    %res, %tok = ttng.tmem_alloc : () -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
+    %toka = ttng.tmem_store %cst, %res[%tok], %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+    // CHECK: scf.for
+    %tok0 = scf.for %iv0 = %lb to %ub step %step iter_args(%tok1 = %toka) -> (!ttg.async.token) : i32 {
+      %tok1a = ttng.tmem_store %cst, %res[%tok1], %true {ttg.partition = array<i32: 2>} : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+      // CHECK: scf.for
+      %useD, %tok4 = scf.for %iv = %lb to %ub step %step iter_args(%useD = %false, %tok2 = %tok1a) -> (i1, !ttg.async.token) : i32 {
+        %sA = "load1"(%iv) {ttg.partition = array<i32: 1>} : (i32) -> !ttg.memdesc<128x64xf32, #shared, #smem>
+        %sB = "load2"(%iv) {ttg.partition = array<i32: 1>} : (i32) -> !ttg.memdesc<64x128xf32, #shared, #smem>
+        %tok3 = ttng.tc_gen5_mma %sA, %sB, %res[%tok2], %useD, %true {ttg.partition = array<i32: 2>} : !ttg.memdesc<128x64xf32, #shared, #smem>, !ttg.memdesc<64x128xf32, #shared, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+        scf.yield {ttg.partition = array<i32: 1, 2>} %true, %tok3 : i1, !ttg.async.token
+      } {ttg.partition = array<i32: 1, 2>, ttg.partition.outputs = [array<i32: 2>, array<i32: 2>]}
+      %val, %tok5 = ttng.tmem_load %res[%tok4] {ttg.partition = array<i32: 0>} : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
+      "use"(%val) {ttg.partition = array<i32: 0>} : (tensor<128x128xf32, #blocked>) -> ()
+      scf.yield {ttg.partition = array<i32: 0, 1, 2>} %tok5 : !ttg.async.token
+    } {tt.warp_specialize, ttg.partition = array<i32: 0, 1, 2>, ttg.partition.outputs = [array<i32: 2>], ttg.warp_specialize.tag = 0 : i32}
+    tt.return
+  }
+
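+  // Same structure as above, but the outer-loop store runs in partition 0 (the load/use
+  // partition) rather than the MMA partition; the CHECK lines expect the alloc to stay
+  // single-buffered (1x128x128).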
+  // CHECK-LABEL: @nested_loop_no_double_buffer
+  tt.func @nested_loop_no_double_buffer(%lb: i32, %ub: i32, %step: i32, %ptr0: !tt.ptr<i32>) {
+    %true = arith.constant true
+    %false = arith.constant false
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
+    // CHECK: [[BUF:%.*]] = ttng.tmem_alloc : () -> !ttg.memdesc<1x128x128xf32, #tmem,
+    // CHECK-NEXT: [[AREF:%.*]] = nvws.aref.create [[BUF]]
+    %res, %tok = ttng.tmem_alloc : () -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
+    %toka = ttng.tmem_store %cst, %res[%tok], %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+    // CHECK: scf.for
+    %tok0 = scf.for %iv0 = %lb to %ub step %step iter_args(%tok1 = %toka) -> (!ttg.async.token) : i32 {
+      %tok1a = ttng.tmem_store %cst, %res[%tok1], %true {ttg.partition = array<i32: 0>} : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+      // CHECK: scf.for
+      %useD, %tok4 = scf.for %iv = %lb to %ub step %step iter_args(%useD = %false, %tok2 = %tok1a) -> (i1, !ttg.async.token) : i32 {
+        %sA = "load1"(%iv) {ttg.partition = array<i32: 1>} : (i32) -> !ttg.memdesc<128x64xf32, #shared, #smem>
+        %sB = "load2"(%iv) {ttg.partition = array<i32: 1>} : (i32) -> !ttg.memdesc<64x128xf32, #shared, #smem>
+        %tok3 = ttng.tc_gen5_mma %sA, %sB, %res[%tok2], %useD, %true {ttg.partition = array<i32: 2>} : !ttg.memdesc<128x64xf32, #shared, #smem>, !ttg.memdesc<64x128xf32, #shared, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+        scf.yield {ttg.partition = array<i32: 1, 2>} %true, %tok3 : i1, !ttg.async.token
+      } {ttg.partition = array<i32: 1, 2>, ttg.partition.outputs = [array<i32: 2>, array<i32: 2>]}
+      %val, %tok5 = ttng.tmem_load %res[%tok4] {ttg.partition = array<i32: 0>} : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
+      "use"(%val) {ttg.partition = array<i32: 0>} : (tensor<128x128xf32, #blocked>) -> ()
+      scf.yield {ttg.partition = array<i32: 0, 1, 2>} %tok5 : !ttg.async.token
+    } {tt.warp_specialize, ttg.partition = array<i32: 0, 1, 2>, ttg.partition.outputs = [array<i32: 0>], ttg.warp_specialize.tag = 0 : i32}
+    tt.return
+  }
+
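+  // Scaled-MMA variant of the double-buffered case: tc_gen5_mma_scaled reads e4m3 scales
+  // from tensor memory, and the 128x128 accumulator is still expected to double-buffer.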
+  // CHECK-LABEL: @nested_loop_yes_double_buffer_scaled
+  tt.func @nested_loop_yes_double_buffer_scaled(%lb: i32, %ub: i32, %step: i32, %ptr0: !tt.ptr<i32>,
+                                                %scalesA: tensor<128x8xi8, #linear>, %scalesB: tensor<128x8xi8, #linear>) {
+    %true = arith.constant true
+    %false = arith.constant false
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
+    // CHECK: [[BUF:%.*]] = ttng.tmem_alloc : () -> !ttg.memdesc<2x128x128xf32, #tmem,
+    // CHECK-NEXT: [[AREF:%.*]] = nvws.aref.create [[BUF]]
+    %res, %tok = ttng.tmem_alloc : () -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
+    %toka = ttng.tmem_store %cst, %res[%tok], %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+    %lhs_scales = ttng.tmem_alloc %scalesA : (tensor<128x8xi8, #linear>) -> !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory>
+    %rhs_scales = ttng.tmem_alloc %scalesB : (tensor<128x8xi8, #linear>) -> !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory>
+    // CHECK: scf.for
+    %tok0 = scf.for %iv0 = %lb to %ub step %step iter_args(%tok1 = %toka) -> (!ttg.async.token) : i32 {
+      %tok1a = ttng.tmem_store %cst, %res[%tok1], %true {ttg.partition = array<i32: 2>} : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+      // CHECK: scf.for
+      %useD, %tok4 = scf.for %iv = %lb to %ub step %step iter_args(%useD = %false, %tok2 = %tok1a) -> (i1, !ttg.async.token) : i32 {
+        %sA = "load1"(%iv) {ttg.partition = array<i32: 1>} : (i32) -> !ttg.memdesc<128x64xf32, #shared, #smem>
+        %sB = "load2"(%iv) {ttg.partition = array<i32: 1>} : (i32) -> !ttg.memdesc<64x128xf32, #shared, #smem>
+        %tok3 = ttng.tc_gen5_mma_scaled %sA, %sB, %res[%tok2], %lhs_scales, %rhs_scales, %useD, %true lhs = e4m3 rhs = e4m3 {ttg.partition = array<i32: 2>} : !ttg.memdesc<128x64xf32, #shared, #smem>, !ttg.memdesc<64x128xf32, #shared, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory>, !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory>
+        scf.yield {ttg.partition = array<i32: 1, 2>} %true, %tok3 : i1, !ttg.async.token
+      } {ttg.partition = array<i32: 1, 2>, ttg.partition.outputs = [array<i32: 2>, array<i32: 2>]}
+      %val, %tok5 = ttng.tmem_load %res[%tok4] {ttg.partition = array<i32: 0>} : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
+      "use"(%val) {ttg.partition = array<i32: 0>} : (tensor<128x128xf32, #blocked>) -> ()
+      scf.yield {ttg.partition = array<i32: 0, 1, 2>} %tok5 : !ttg.async.token
+    } {tt.warp_specialize, ttg.partition = array<i32: 0, 1, 2>, ttg.partition.outputs = [array<i32: 2>], ttg.warp_specialize.tag = 0 : i32}
+    tt.return
+  }
+
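+  // Scaled variant with a wider 128x256 accumulator: same partitioning as the double-buffered
+  // cases above, but the CHECK lines expect a single-buffered 1x128x256 alloc (presumably the
+  // larger accumulator leaves no room to double-buffer in tensor memory alongside the scales).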
+  // CHECK-LABEL: @nested_loop_no_double_buffer_scaled
+  tt.func @nested_loop_no_double_buffer_scaled(%lb: i32, %ub: i32, %step: i32, %ptr0: !tt.ptr<i32>,
+                                               %scalesA: tensor<128x8xi8, #linear>, %scalesB: tensor<128x8xi8, #linear>) {
+    %true = arith.constant true
+    %false = arith.constant false
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked>
+    // CHECK: [[BUF:%.*]] = ttng.tmem_alloc : () -> !ttg.memdesc<1x128x256xf32, #tmem,
+    // CHECK-NEXT: [[AREF:%.*]] = nvws.aref.create [[BUF]]
+    %res, %tok = ttng.tmem_alloc : () -> (!ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
+    %toka = ttng.tmem_store %cst, %res[%tok], %true : tensor<128x256xf32, #blocked> -> !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>
+    %lhs_scales = ttng.tmem_alloc %scalesA : (tensor<128x8xi8, #linear>) -> !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory>
+    %rhs_scales = ttng.tmem_alloc %scalesB : (tensor<128x8xi8, #linear>) -> !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory>
+    // CHECK: scf.for
+    %tok0 = scf.for %iv0 = %lb to %ub step %step iter_args(%tok1 = %toka) -> (!ttg.async.token) : i32 {
+      %tok1a = ttng.tmem_store %cst, %res[%tok1], %true {ttg.partition = array<i32: 2>} : tensor<128x256xf32, #blocked> -> !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>
+      // CHECK: scf.for
+      %useD, %tok4 = scf.for %iv = %lb to %ub step %step iter_args(%useD = %false, %tok2 = %tok1a) -> (i1, !ttg.async.token) : i32 {
+        %sA = "load1"(%iv) {ttg.partition = array<i32: 1>} : (i32) -> !ttg.memdesc<128x64xf32, #shared, #smem>
+        %sB = "load2"(%iv) {ttg.partition = array<i32: 1>} : (i32) -> !ttg.memdesc<64x256xf32, #shared, #smem>
+        %tok3 = ttng.tc_gen5_mma_scaled %sA, %sB, %res[%tok2], %lhs_scales, %rhs_scales, %useD, %true lhs = e4m3 rhs = e4m3 {ttg.partition = array<i32: 2>} : !ttg.memdesc<128x64xf32, #shared, #smem>, !ttg.memdesc<64x256xf32, #shared, #smem>, !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory>, !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory>
+        scf.yield {ttg.partition = array<i32: 1, 2>} %true, %tok3 : i1, !ttg.async.token
+      } {ttg.partition = array<i32: 1, 2>, ttg.partition.outputs = [array<i32: 2>, array<i32: 2>]}
+      %val, %tok5 = ttng.tmem_load %res[%tok4] {ttg.partition = array<i32: 0>} : !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x256xf32, #blocked>
+      "use"(%val) {ttg.partition = array<i32: 0>} : (tensor<128x256xf32, #blocked>) -> ()
+      scf.yield {ttg.partition = array<i32: 0, 1, 2>} %tok5 : !ttg.async.token
+    } {tt.warp_specialize, ttg.partition = array<i32: 0, 1, 2>, ttg.partition.outputs = [array<i32: 2>], ttg.warp_specialize.tag = 0 : i32}
+    tt.return
+  }
+}