Skip to content

Commit 16e712e

Browse files
authored
[NVPTX] Allow directly storing immediates to improve readability (llvm#145552)
Allow directly storing an immediate instead of requiring that it first be moved into a register. This makes for more compact and readable PTX. An approach similar to this (using a ComplexPattern) this could be used for most PTX instructions to avoid the need for `_[ri]+` variants and boiler-plate.
1 parent 46c8cc7 commit 16e712e

19 files changed

+220
-186
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1339,20 +1339,18 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
13391339
SDValue Offset, Base;
13401340
SelectADDR(ST->getBasePtr(), Base, Offset);
13411341

1342-
SDValue Ops[] = {Value,
1342+
SDValue Ops[] = {selectPossiblyImm(Value),
13431343
getI32Imm(Ordering, DL),
13441344
getI32Imm(Scope, DL),
13451345
getI32Imm(CodeAddrSpace, DL),
1346-
getI32Imm(NVPTX::PTXLdStInstCode::Untyped, DL),
13471346
getI32Imm(ToTypeWidth, DL),
13481347
Base,
13491348
Offset,
13501349
Chain};
13511350

1352-
const MVT::SimpleValueType SourceVT =
1353-
Value.getNode()->getSimpleValueType(0).SimpleTy;
1354-
const std::optional<unsigned> Opcode = pickOpcodeForVT(
1355-
SourceVT, NVPTX::ST_i8, NVPTX::ST_i16, NVPTX::ST_i32, NVPTX::ST_i64);
1351+
const std::optional<unsigned> Opcode =
1352+
pickOpcodeForVT(Value.getSimpleValueType().SimpleTy, NVPTX::ST_i8,
1353+
NVPTX::ST_i16, NVPTX::ST_i32, NVPTX::ST_i64);
13561354
if (!Opcode)
13571355
return false;
13581356

@@ -1389,7 +1387,9 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
13891387

13901388
const unsigned NumElts = getLoadStoreVectorNumElts(ST);
13911389

1392-
SmallVector<SDValue, 16> Ops(ST->ops().slice(1, NumElts));
1390+
SmallVector<SDValue, 16> Ops;
1391+
for (auto &V : ST->ops().slice(1, NumElts))
1392+
Ops.push_back(selectPossiblyImm(V));
13931393
SDValue Addr = N->getOperand(NumElts + 1);
13941394
const unsigned ToTypeWidth = TotalWidth / NumElts;
13951395

@@ -1400,9 +1400,8 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
14001400
SelectADDR(Addr, Base, Offset);
14011401

14021402
Ops.append({getI32Imm(Ordering, DL), getI32Imm(Scope, DL),
1403-
getI32Imm(CodeAddrSpace, DL),
1404-
getI32Imm(NVPTX::PTXLdStInstCode::Untyped, DL),
1405-
getI32Imm(ToTypeWidth, DL), Base, Offset, Chain});
1403+
getI32Imm(CodeAddrSpace, DL), getI32Imm(ToTypeWidth, DL), Base,
1404+
Offset, Chain});
14061405

14071406
const MVT::SimpleValueType EltVT =
14081407
ST->getOperand(1).getSimpleValueType().SimpleTy;
@@ -2102,6 +2101,19 @@ bool NVPTXDAGToDAGISel::SelectADDR(SDValue Addr, SDValue &Base,
21022101
return true;
21032102
}
21042103

2104+
SDValue NVPTXDAGToDAGISel::selectPossiblyImm(SDValue V) {
2105+
if (V.getOpcode() == ISD::BITCAST)
2106+
V = V.getOperand(0);
2107+
2108+
if (auto *CN = dyn_cast<ConstantSDNode>(V))
2109+
return CurDAG->getTargetConstant(CN->getAPIntValue(), SDLoc(V),
2110+
V.getValueType());
2111+
if (auto *CN = dyn_cast<ConstantFPSDNode>(V))
2112+
return CurDAG->getTargetConstantFP(CN->getValueAPF(), SDLoc(V),
2113+
V.getValueType());
2114+
return V;
2115+
}
2116+
21052117
bool NVPTXDAGToDAGISel::ChkMemSDNodeAddressSpace(SDNode *N,
21062118
unsigned int spN) const {
21072119
const Value *Src = nullptr;

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
105105
}
106106

107107
bool SelectADDR(SDValue Addr, SDValue &Base, SDValue &Offset);
108+
SDValue selectPossiblyImm(SDValue V);
108109

109110
bool ChkMemSDNodeAddressSpace(SDNode *N, unsigned int spN) const;
110111

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 38 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,18 @@ class OneUse2<SDPatternOperator operator>
184184
class fpimm_pos_inf<ValueType vt>
185185
: FPImmLeaf<vt, [{ return Imm.isPosInfinity(); }]>;
186186

187+
188+
189+
// Operands which can hold a Register or an Immediate.
190+
//
191+
// Unfortunately, since most register classes can hold multiple types, we must
192+
// use the 'Any' type for these.
193+
194+
def RI1 : Operand<i1>;
195+
def RI16 : Operand<Any>;
196+
def RI32 : Operand<Any>;
197+
def RI64 : Operand<Any>;
198+
187199
// Utility class to wrap up information about a register and DAG type for more
188200
// convenient iteration and parameterization
189201
class RegTyInfo<ValueType ty, NVPTXRegClass rc, Operand imm, SDNode imm_node,
@@ -2276,19 +2288,20 @@ let mayLoad=1, hasSideEffects=0 in {
22762288
def LD_i64 : LD<B64>;
22772289
}
22782290

2279-
class ST<NVPTXRegClass regclass>
2291+
class ST<DAGOperand O>
22802292
: NVPTXInst<
22812293
(outs),
2282-
(ins regclass:$src, LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp,
2283-
LdStCode:$Sign, i32imm:$toWidth, ADDR:$addr),
2284-
"st${sem:sem}${scope:scope}${addsp:addsp}.${Sign:sign}$toWidth"
2294+
(ins O:$src,
2295+
LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, i32imm:$toWidth,
2296+
ADDR:$addr),
2297+
"st${sem:sem}${scope:scope}${addsp:addsp}.b$toWidth"
22852298
" \t[$addr], $src;", []>;
22862299

22872300
let mayStore=1, hasSideEffects=0 in {
2288-
def ST_i8 : ST<B16>;
2289-
def ST_i16 : ST<B16>;
2290-
def ST_i32 : ST<B32>;
2291-
def ST_i64 : ST<B64>;
2301+
def ST_i8 : ST<RI16>;
2302+
def ST_i16 : ST<RI16>;
2303+
def ST_i32 : ST<RI32>;
2304+
def ST_i64 : ST<RI64>;
22922305
}
22932306

22942307
// The following is used only in and after vector elementizations. Vector
@@ -2324,38 +2337,38 @@ let mayLoad=1, hasSideEffects=0 in {
23242337
defm LDV_i64 : LD_VEC<B64>;
23252338
}
23262339

2327-
multiclass ST_VEC<NVPTXRegClass regclass, bit support_v8 = false> {
2340+
multiclass ST_VEC<DAGOperand O, bit support_v8 = false> {
23282341
def _v2 : NVPTXInst<
23292342
(outs),
2330-
(ins regclass:$src1, regclass:$src2, LdStCode:$sem, LdStCode:$scope,
2331-
LdStCode:$addsp, LdStCode:$Sign, i32imm:$fromWidth,
2343+
(ins O:$src1, O:$src2,
2344+
LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, i32imm:$fromWidth,
23322345
ADDR:$addr),
2333-
"st${sem:sem}${scope:scope}${addsp:addsp}.v2.${Sign:sign}$fromWidth "
2346+
"st${sem:sem}${scope:scope}${addsp:addsp}.v2.b$fromWidth "
23342347
"\t[$addr], {{$src1, $src2}};", []>;
23352348
def _v4 : NVPTXInst<
23362349
(outs),
2337-
(ins regclass:$src1, regclass:$src2, regclass:$src3, regclass:$src4,
2338-
LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp,
2339-
LdStCode:$Sign, i32imm:$fromWidth, ADDR:$addr),
2340-
"st${sem:sem}${scope:scope}${addsp:addsp}.v4.${Sign:sign}$fromWidth "
2350+
(ins O:$src1, O:$src2, O:$src3, O:$src4,
2351+
LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, i32imm:$fromWidth,
2352+
ADDR:$addr),
2353+
"st${sem:sem}${scope:scope}${addsp:addsp}.v4.b$fromWidth "
23412354
"\t[$addr], {{$src1, $src2, $src3, $src4}};", []>;
23422355
if support_v8 then
23432356
def _v8 : NVPTXInst<
23442357
(outs),
2345-
(ins regclass:$src1, regclass:$src2, regclass:$src3, regclass:$src4,
2346-
regclass:$src5, regclass:$src6, regclass:$src7, regclass:$src8,
2347-
LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, LdStCode:$Sign,
2348-
i32imm:$fromWidth, ADDR:$addr),
2349-
"st${sem:sem}${scope:scope}${addsp:addsp}.v8.${Sign:sign}$fromWidth "
2358+
(ins O:$src1, O:$src2, O:$src3, O:$src4,
2359+
O:$src5, O:$src6, O:$src7, O:$src8,
2360+
LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, i32imm:$fromWidth,
2361+
ADDR:$addr),
2362+
"st${sem:sem}${scope:scope}${addsp:addsp}.v8.b$fromWidth "
23502363
"\t[$addr], "
23512364
"{{$src1, $src2, $src3, $src4, $src5, $src6, $src7, $src8}};", []>;
23522365
}
23532366

23542367
let mayStore=1, hasSideEffects=0 in {
2355-
defm STV_i8 : ST_VEC<B16>;
2356-
defm STV_i16 : ST_VEC<B16>;
2357-
defm STV_i32 : ST_VEC<B32, support_v8 = true>;
2358-
defm STV_i64 : ST_VEC<B64>;
2368+
defm STV_i8 : ST_VEC<RI16>;
2369+
defm STV_i16 : ST_VEC<RI16>;
2370+
defm STV_i32 : ST_VEC<RI32, support_v8 = true>;
2371+
defm STV_i64 : ST_VEC<RI64>;
23592372
}
23602373

23612374
//---- Conversion ----

llvm/test/CodeGen/NVPTX/access-non-generic.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,7 @@ define void @nested_const_expr() {
107107
; PTX-LABEL: nested_const_expr(
108108
; store 1 to bitcast(gep(addrspacecast(array), 0, 1))
109109
store i32 1, ptr getelementptr ([10 x float], ptr addrspacecast (ptr addrspace(3) @array to ptr), i64 0, i64 1), align 4
110-
; PTX: mov.b32 %r1, 1;
111-
; PTX-NEXT: st.shared.b32 [array+4], %r1;
110+
; PTX: st.shared.b32 [array+4], 1;
112111
ret void
113112
}
114113

llvm/test/CodeGen/NVPTX/bf16-instructions.ll

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1605,5 +1605,23 @@ define <2 x bfloat> @test_maxnum_v2(<2 x bfloat> %a, <2 x bfloat> %b) {
16051605
ret <2 x bfloat> %r
16061606
}
16071607

1608+
define void @store_bf16(ptr %p1, ptr %p2, bfloat %v) {
1609+
; CHECK-LABEL: store_bf16(
1610+
; CHECK: {
1611+
; CHECK-NEXT: .reg .b16 %rs<2>;
1612+
; CHECK-NEXT: .reg .b64 %rd<3>;
1613+
; CHECK-EMPTY:
1614+
; CHECK-NEXT: // %bb.0:
1615+
; CHECK-NEXT: ld.param.b64 %rd1, [store_bf16_param_0];
1616+
; CHECK-NEXT: ld.param.b16 %rs1, [store_bf16_param_2];
1617+
; CHECK-NEXT: st.b16 [%rd1], %rs1;
1618+
; CHECK-NEXT: ld.param.b64 %rd2, [store_bf16_param_1];
1619+
; CHECK-NEXT: st.b16 [%rd2], 0x3F80;
1620+
; CHECK-NEXT: ret;
1621+
store bfloat %v, ptr %p1
1622+
store bfloat 1.0, ptr %p2
1623+
ret void
1624+
}
1625+
16081626
declare bfloat @llvm.maximum.bf16(bfloat, bfloat)
16091627
declare <2 x bfloat> @llvm.maximum.v2bf16(<2 x bfloat>, <2 x bfloat>)

llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,3 +723,20 @@ define <2 x bfloat> @test_copysign(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
723723
ret <2 x bfloat> %r
724724
}
725725

726+
define void @test_store_bf16x2(ptr %p1, ptr %p2, <2 x bfloat> %v) {
727+
; CHECK-LABEL: test_store_bf16x2(
728+
; CHECK: {
729+
; CHECK-NEXT: .reg .b32 %r<2>;
730+
; CHECK-NEXT: .reg .b64 %rd<3>;
731+
; CHECK-EMPTY:
732+
; CHECK-NEXT: // %bb.0:
733+
; CHECK-NEXT: ld.param.b64 %rd1, [test_store_bf16x2_param_0];
734+
; CHECK-NEXT: ld.param.b32 %r1, [test_store_bf16x2_param_2];
735+
; CHECK-NEXT: st.b32 [%rd1], %r1;
736+
; CHECK-NEXT: ld.param.b64 %rd2, [test_store_bf16x2_param_1];
737+
; CHECK-NEXT: st.b32 [%rd2], 1065369472;
738+
; CHECK-NEXT: ret;
739+
store <2 x bfloat> %v, ptr %p1
740+
store <2 x bfloat> <bfloat 1.0, bfloat 1.0>, ptr %p2
741+
ret void
742+
}

llvm/test/CodeGen/NVPTX/chain-different-as.ll

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,13 @@
44
define i64 @test() nounwind readnone {
55
; CHECK-LABEL: test(
66
; CHECK: {
7-
; CHECK-NEXT: .reg .b64 %rd<4>;
7+
; CHECK-NEXT: .reg .b64 %rd<3>;
88
; CHECK-EMPTY:
99
; CHECK-NEXT: // %bb.0:
1010
; CHECK-NEXT: mov.b64 %rd1, 1;
11-
; CHECK-NEXT: mov.b64 %rd2, 42;
12-
; CHECK-NEXT: st.b64 [%rd1], %rd2;
13-
; CHECK-NEXT: ld.global.b64 %rd3, [%rd1];
14-
; CHECK-NEXT: st.param.b64 [func_retval0], %rd3;
11+
; CHECK-NEXT: st.b64 [%rd1], 42;
12+
; CHECK-NEXT: ld.global.b64 %rd2, [%rd1];
13+
; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
1514
; CHECK-NEXT: ret;
1615
%addr0 = inttoptr i64 1 to ptr
1716
%addr1 = inttoptr i64 1 to ptr addrspace(1)

llvm/test/CodeGen/NVPTX/demote-vars.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,7 @@ define void @define_private_global(i64 %val) {
6767
; Also check that the if-then is still here, otherwise we may not be testing
6868
; the "more-than-one-use" part.
6969
; CHECK: st.shared.b64 [private_global_used_more_than_once_in_same_fct],
70-
; CHECK: mov.b64 %[[VAR:.*]], 25
71-
; CHECK: st.shared.b64 [private_global_used_more_than_once_in_same_fct], %[[VAR]]
70+
; CHECK: st.shared.b64 [private_global_used_more_than_once_in_same_fct], 25
7271
define void @define_private_global_more_than_one_use(i64 %val, i1 %cond) {
7372
store i64 %val, ptr addrspace(3) @private_global_used_more_than_once_in_same_fct
7473
br i1 %cond, label %then, label %end

llvm/test/CodeGen/NVPTX/f16x2-instructions.ll

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2295,5 +2295,25 @@ define <2 x half> @test_uitofp_2xi16_to_2xhalf(<2 x i16> %a) #0 {
22952295
ret <2 x half> %r
22962296
}
22972297

2298+
define void @test_store_2xhalf(ptr %p1, ptr %p2, <2 x half> %v) {
2299+
; CHECK-LABEL: test_store_2xhalf(
2300+
; CHECK: {
2301+
; CHECK-NEXT: .reg .b32 %r<2>;
2302+
; CHECK-NEXT: .reg .b64 %rd<3>;
2303+
; CHECK-EMPTY:
2304+
; CHECK-NEXT: // %bb.0:
2305+
; CHECK-NEXT: ld.param.b32 %r1, [test_store_2xhalf_param_2];
2306+
; CHECK-NEXT: ld.param.b64 %rd2, [test_store_2xhalf_param_1];
2307+
; CHECK-NEXT: ld.param.b64 %rd1, [test_store_2xhalf_param_0];
2308+
; CHECK-NEXT: st.b32 [%rd1], %r1;
2309+
; CHECK-NEXT: st.b32 [%rd2], 1006648320;
2310+
; CHECK-NEXT: ret;
2311+
store <2 x half> %v, ptr %p1
2312+
store <2 x half> <half 1.0, half 1.0>, ptr %p2
2313+
ret void
2314+
}
2315+
2316+
2317+
22982318
attributes #0 = { nounwind }
22992319
attributes #1 = { "unsafe-fp-math" = "true" }

llvm/test/CodeGen/NVPTX/i1-load-lower.ll

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,13 @@ target triple = "nvptx-nvidia-cuda"
1010
define void @foo() {
1111
; CHECK-LABEL: foo(
1212
; CHECK: .reg .pred %p<2>;
13-
; CHECK: .reg .b16 %rs<4>;
13+
; CHECK: .reg .b16 %rs<3>;
1414
; CHECK-EMPTY:
1515
; CHECK: ld.global.b8 %rs1, [i1g];
1616
; CHECK: and.b16 %rs2, %rs1, 1;
1717
; CHECK: setp.ne.b16 %p1, %rs2, 0;
1818
; CHECK: @%p1 bra $L__BB0_2;
19-
; CHECK: mov.b16 %rs3, 1;
20-
; CHECK: st.global.b8 [i1g], %rs3;
19+
; CHECK: st.global.b8 [i1g], 1;
2120
; CHECK: ret;
2221
%tmp = load i1, ptr addrspace(1) @i1g, align 2
2322
br i1 %tmp, label %if.end, label %if.then

0 commit comments

Comments
 (0)