@@ -1082,6 +1082,161 @@ let Predicates = [hasPTX<70>, hasSM<80>] in {
10821082 "mbarrier.pending_count.b64",
10831083 [(set i32:$res, (int_nvvm_mbarrier_pending_count i64:$state))]>;
10841084}
1085+
// MBAR_UTIL: string/predicate helper for one mbarrier instruction variant.
//
// Given the pieces that identify a variant, this class computes:
//   asm_str   - the PTX mnemonic (without operands),
//   intr_name - the name of the intrinsic record to !cast<Intrinsic>(),
//   preds     - the [PTX, SM] predicate pair for the variant.
//
// PTX ISA spells the mbarrier instructions as:
//   mbarrier.op.semantics.scope.space.b64 arg1, arg2 ...
// with:
//   op        -> arrive, expect_tx, complete_tx, arrive.expect_tx etc.
//   semantics -> acquire, release, relaxed (default depends on the op)
//   scope     -> cta or cluster (default is cta-scope)
//   space     -> shared::cta or shared::cluster (default is shared::cta)
//
// 'semantics' and 'scope' always travel together: naming one of them
// requires naming the other. For example:
//   (A) mbarrier.arrive <args>             (valid, release and cta are default)
//   (B) mbarrier.arrive.release.cta <args> (valid, sem/scope mentioned explicitly)
//   (C) mbarrier.arrive.release <args>     (invalid, needs scope)
//   (D) mbarrier.arrive.cta <args>         (invalid, needs semantics)
//
// Form (A) is preferred over form (B) wherever possible because it is
// accepted by early PTX versions, whereas an explicit scope usually
// needs a later PTX version.
//
// Parameters:
//   op     - operation name as written in PTX; '.' becomes '_' in intr_name.
//   scope  - "scope_cta" or "scope_cluster"; any other text is emitted as-is.
//   space  - "space_cta" or "space_cluster"; any other text is emitted as-is.
//   sem    - semantics qualifier for the mnemonic and intrinsic name
//            ("" leaves it out).
//   tl     - when set, "tl" is added to intr_name (asm_str is unaffected).
//   parity - when set, "parity" is added to both asm_str and intr_name.
class MBAR_UTIL<string op, string scope,
                string space = "", string sem = "",
                bit tl = 0, bit parity = 0> {
  // Scope qualifier for the mnemonic. The cta scope is spelled out only
  // when a semantics qualifier accompanies it (see (A)/(B) above).
  string _scope_asm =
      !if(!eq(scope, "scope_cluster"), "cluster",
          !if(!eq(scope, "scope_cta"),
              !if(!empty(sem), "", "cta"),
              scope));

  // State-space qualifier for the mnemonic.
  string _space_asm =
      !if(!eq(space, "space_cta"), "shared",
          !if(!eq(space, "space_cluster"), "shared::cluster",
              space));

  string _parity = !cond(parity : "parity", true : "");

  // Mnemonic: the qualifiers joined with '.'.
  // NOTE(review): StrJoin is defined outside this block; it is presumed to
  // skip empty components -- confirm against its definition.
  string asm_str = StrJoin<".", ["mbarrier", op, _parity,
                                 sem, _scope_asm, _space_asm, "b64"]>.ret;

  // Intrinsic record name: the raw 'scope'/'space' parameter text is used
  // here, not the PTX spellings computed above.
  string _intr_suffix = StrJoin<"_", [!subst(".", "_", op), _parity,
                                      !if(tl, "tl", ""),
                                      sem, scope, space]>.ret;
  string intr_name = "int_nvvm_mbarrier_" # _intr_suffix;

  // Predicate selection.
  // Only the "test_wait/try_wait" variants consume these, since those have
  // evolved since sm80 and are complex. The remaining instructions have
  // straightforward predicates that are applied directly at their
  // definitions.
  Predicate _sm_pred =
      !if(!or(!eq(sem, "relaxed"),
              !eq(scope, "scope_cluster"),
              !eq(op, "try_wait")),
          hasSM<90>,
          hasSM<80>);

  // The first matching condition wins, so the order below is significant.
  Predicate _ptx_pred =
      !if(!eq(sem, "relaxed"), hasPTX<86>,
          !if(!ne(_scope_asm, ""), hasPTX<80>,
              !if(!eq(op, "try_wait"), hasPTX<78>,
                  !if(parity, hasPTX<71>,
                      hasPTX<70>))));

  list<Predicate> preds = [_ptx_pred, _sm_pred];
}
1143+
// mbarrier.expect_tx / mbarrier.complete_tx
//
// Defines one instruction record, mbar_<op>_<scope>_<space>, for every
// (op, scope, space) combination. The intrinsic record name carries no
// semantics component, whereas the emitted mnemonic always spells out
// "relaxed" together with the scope.
foreach op = ["expect_tx", "complete_tx"] in {
  foreach scope = ["scope_cta", "scope_cluster"] in {
    foreach space = ["space_cta", "space_cluster"] in {
      defvar inst_suffix = StrJoin<"_", [op, scope, space]>.ret;
      defvar mnemonic = MBAR_UTIL<op, scope, space, "relaxed">.asm_str;
      defvar intrinsic =
          !cast<Intrinsic>(MBAR_UTIL<op, scope, space>.intr_name);

      def mbar_ # inst_suffix :
          BasicNVPTXInst<(outs), (ins ADDR:$addr, B32:$tx_count),
                         mnemonic,
                         [(intrinsic addr:$addr, i32:$tx_count)]>,
          Requires<[hasPTX<80>, hasSM<90>]>;
    } // space
  } // scope
} // op
1156+
// MBAR_ARR_INTR: instruction pair for one arrive-style mbarrier intrinsic.
//
// Parameters:
//   op    - operation name as written in PTX (e.g. "arrive.expect_tx").
//   scope - "scope_cta" or "scope_cluster".
//   sem   - semantics text used in the intrinsic name ("" for the default).
//   pred  - predicates attached to both generated instructions.
//
// Generates:
//   _CTA     - shared::cta barrier; the instruction produces a 64-bit
//              $state result.
//   _CLUSTER - shared::cluster barrier; no result, the destination is
//              written as the sink operand '_'.
multiclass MBAR_ARR_INTR<string op, string scope, string sem,
                         list<Predicate> pred = []> {
  // A non-default scope or semantics forces both to be written out, so
  // `release` is named explicitly in the mnemonic whenever the scope is
  // `cluster` and no semantics was requested. The intrinsic names below
  // keep using the caller's `sem` unchanged.
  defvar ptx_sem = !if(!eq(scope, "scope_cluster"),
                       !if(!empty(sem), "release", sem),
                       sem);

  // shared::cta flavour.
  defvar cta_intr =
      !cast<Intrinsic>(MBAR_UTIL<op, scope, "space_cta", sem>.intr_name);
  defvar cta_asm = MBAR_UTIL<op, scope, "space_cta", ptx_sem>.asm_str;

  // shared::cluster flavour.
  defvar cluster_intr =
      !cast<Intrinsic>(MBAR_UTIL<op, scope, "space_cluster", sem>.intr_name);
  defvar cluster_asm = MBAR_UTIL<op, scope, "space_cluster", ptx_sem>.asm_str;

  def _CTA :
      NVPTXInst<(outs B64:$state),
                (ins ADDR:$addr, B32:$tx_count),
                cta_asm # " $state, [$addr], $tx_count;",
                [(set i64:$state, (cta_intr addr:$addr, i32:$tx_count))]>,
      Requires<pred>;

  def _CLUSTER :
      NVPTXInst<(outs),
                (ins ADDR:$addr, B32:$tx_count),
                cluster_asm # " _, [$addr], $tx_count;",
                [(cluster_intr addr:$addr, i32:$tx_count)]>,
      Requires<pred>;
}
// mbarrier.arrive / arrive.expect_tx / arrive_drop / arrive_drop.expect_tx
//
// For every (op, scope) pair, instantiate MBAR_ARR_INTR twice: once with the
// default semantics (passed as "" and recorded as `_release` in the record
// name) and once with `relaxed`. The record names follow the same
// '_'-separated scheme as the other mbarrier definitions in this section:
//   mbar_<op>_<scope>_<sem>_{CTA,CLUSTER}
// where any '.' in the op is replaced by '_'.
foreach op = ["arrive", "arrive.expect_tx",
              "arrive_drop", "arrive_drop.expect_tx"] in {
  foreach scope = ["scope_cta", "scope_cluster"] in {
    // Join op and scope with '_' so the op name does not run into the
    // scope name (e.g. "arrive_scope_cta", not "arrivescope_cta").
    defvar suffix = StrJoin<"_", [!subst(".", "_", op), scope]>.ret;
    defm mbar_ # suffix # "_release" :
        MBAR_ARR_INTR<op, scope, "", [hasPTX<80>, hasSM<90>]>;
    defm mbar_ # suffix # "_relaxed" :
        MBAR_ARR_INTR<op, scope, "relaxed", [hasPTX<86>, hasSM<90>]>;
  } // scope
} // op
1192+
// MBAR_WAIT_INTR: instruction pair for one wait-style mbarrier intrinsic.
//
// Parameters:
//   op         - "test_wait" or "try_wait" at the visible instantiation sites.
//   scope      - "scope_cta" or "scope_cluster".
//   sem        - semantics text used in the intrinsic name ("" for default).
//   time_limit - when set, both instructions take an extra 32-bit $tl
//                operand and the intrinsic name gains a "tl" component.
//
// Generates (both on the shared::cta space, both producing a 1-bit $res):
//   _STATE  - waits on a 64-bit $state token.
//   _PARITY - waits on a 32-bit $phase value; uses the ".parity" mnemonic.
multiclass MBAR_WAIT_INTR<string op, string scope, string sem, bit time_limit> {
  // A non-default scope or semantics forces both to be written out, so
  // `acquire` is named explicitly in the mnemonic whenever the scope is
  // `cluster` and no semantics was requested. Predicates are derived from
  // this effective semantics; intrinsic names use the caller's `sem`.
  defvar ptx_sem = !if(!eq(scope, "scope_cluster"),
                       !if(!empty(sem), "acquire", sem),
                       sem);

  // State-token variant.
  defvar state_asm = MBAR_UTIL<op, scope, "space_cta", ptx_sem,
                               time_limit>.asm_str;
  defvar state_preds = MBAR_UTIL<op, scope, "space_cta", ptx_sem,
                                 time_limit>.preds;
  defvar state_intr = !cast<Intrinsic>(
      MBAR_UTIL<op, scope, "space_cta", sem, time_limit>.intr_name);

  // Phase-parity variant.
  defvar parity_asm = MBAR_UTIL<op, scope, "space_cta", ptx_sem,
                                time_limit, 1>.asm_str;
  defvar parity_preds = MBAR_UTIL<op, scope, "space_cta", ptx_sem,
                                  time_limit, 1>.preds;
  defvar parity_intr = !cast<Intrinsic>(
      MBAR_UTIL<op, scope, "space_cta", sem, time_limit, 1>.intr_name);

  // Operand lists and selection patterns, with the trailing time-limit
  // operand present only when requested.
  defvar state_ins = !if(time_limit,
                         (ins ADDR:$addr, B64:$state, B32:$tl),
                         (ins ADDR:$addr, B64:$state));
  defvar parity_ins = !if(time_limit,
                          (ins ADDR:$addr, B32:$phase, B32:$tl),
                          (ins ADDR:$addr, B32:$phase));
  defvar state_pat = !if(time_limit,
                         (state_intr addr:$addr, i64:$state, i32:$tl),
                         (state_intr addr:$addr, i64:$state));
  defvar parity_pat = !if(time_limit,
                          (parity_intr addr:$addr, i32:$phase, i32:$tl),
                          (parity_intr addr:$addr, i32:$phase));

  // Tail of the assembly string: optional time-limit operand, then ';'.
  defvar asm_tail = !if(time_limit, ", $tl;", ";");

  def _STATE :
      NVPTXInst<(outs B1:$res), state_ins,
                state_asm # " $res, [$addr], $state" # asm_tail,
                [(set i1:$res, state_pat)]>,
      Requires<state_preds>;

  def _PARITY :
      NVPTXInst<(outs B1:$res), parity_ins,
                parity_asm # " $res, [$addr], $phase" # asm_tail,
                [(set i1:$res, parity_pat)]>,
      Requires<parity_preds>;
}
// mbarrier.test_wait / mbarrier.try_wait
//
// For every (op, scope) pair, instantiate MBAR_WAIT_INTR with the default
// semantics (passed as "" and recorded as "_acquire" in the record name)
// and with "relaxed". Only try_wait additionally gets the time-limit
// variants; test_wait is instantiated without a time limit only.
foreach op = ["test_wait", "try_wait"] in {
  foreach scope = ["scope_cta", "scope_cluster"] in {
    foreach time_limit = !if(!eq(op, "try_wait"), [true, false], [false]) in {
      // The "tl" component is present only for the time-limit variants.
      defvar tl_part = !if(time_limit, "tl", "");
      defvar suffix = StrJoin<"_", [op, scope, tl_part]>.ret;

      defm mbar_ # suffix # "_acquire" :
          MBAR_WAIT_INTR<op, scope, "", time_limit>;
      defm mbar_ # suffix # "_relaxed" :
          MBAR_WAIT_INTR<op, scope, "relaxed", time_limit>;
    } // time_limit
  } // scope
} // op
1239+
10851240//-----------------------------------
10861241// Math Functions
10871242//-----------------------------------
0 commit comments