Skip to content

Commit 92c6074

Browse files
committed
[SPIRV] Add legalization for long vectors
This patch introduces the infrastructure needed to legalize operations on vectors that are longer than the SPIR-V target supports. For instance, shaders only support vectors of up to 4 elements. Legalization is done by splitting the long vectors into smaller vectors of a legal size.

Specifically, this patch does the following:
- Introduces the `vectorElementCountIsGreaterThan` and `vectorElementCountIsLessThanOrEqualTo` legality predicates.
- Adds legalization rules for `G_SHUFFLE_VECTOR`, `G_EXTRACT_VECTOR_ELT`, `G_BUILD_VECTOR`, `G_CONCAT_VECTORS`, `G_SPLAT_VECTOR`, and `G_UNMERGE_VALUES`.
- Handles bitcasts of long vectors: an `@llvm.spv.bitcast` intrinsic call whose vector types need legalization is converted to `G_BITCAST`, legalized with the generic rules, and then converted back to the intrinsic.
- Updates `selectUnmergeValues` to handle extraction of both scalars and vectors from a larger vector, using `OpCompositeExtract` and `OpVectorShuffle` respectively.

Fixes: llvm#165444
1 parent e7bcd80 commit 92c6074

File tree

6 files changed

+299
-24
lines changed

6 files changed

+299
-24
lines changed

llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,16 @@ LLVM_ABI LegalityPredicate scalarWiderThan(unsigned TypeIdx, unsigned Size);
314314
LLVM_ABI LegalityPredicate scalarOrEltNarrowerThan(unsigned TypeIdx,
315315
unsigned Size);
316316

317+
/// True iff the specified type index is a fixed-length vector with an element
318+
/// count that's greater than the given size.
319+
LLVM_ABI LegalityPredicate vectorElementCountIsGreaterThan(unsigned TypeIdx,
320+
unsigned Size);
321+
322+
/// True iff the specified type index is a fixed-length vector with an element
323+
/// count that's less than or equal to the given size.
324+
LLVM_ABI LegalityPredicate
325+
vectorElementCountIsLessThanOrEqualTo(unsigned TypeIdx, unsigned Size);
326+
317327
/// True iff the specified type index is a scalar or a vector with an element
318328
/// type that's wider than the given size.
319329
LLVM_ABI LegalityPredicate scalarOrEltWiderThan(unsigned TypeIdx,

llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,26 @@ LegalityPredicate LegalityPredicates::scalarOrEltNarrowerThan(unsigned TypeIdx,
155155
};
156156
}
157157

158+
/// Builds a predicate that holds when the type at index \p TypeIdx of the
/// query is a fixed-length vector with strictly more than \p Size elements.
/// Scalars and non-fixed vectors never match.
LegalityPredicate
LegalityPredicates::vectorElementCountIsGreaterThan(unsigned TypeIdx,
                                                    unsigned Size) {
  return [TypeIdx, Size](const LegalityQuery &Query) {
    const LLT Ty = Query.Types[TypeIdx];
    // Anything that is not a fixed-length vector is rejected up front.
    if (!Ty.isFixedVector())
      return false;
    return Ty.getNumElements() > Size;
  };
}
167+
168+
/// Builds a predicate that holds when the type at index \p TypeIdx of the
/// query is a fixed-length vector with at most \p Size elements. Scalars and
/// non-fixed vectors never match.
LegalityPredicate
LegalityPredicates::vectorElementCountIsLessThanOrEqualTo(unsigned TypeIdx,
                                                          unsigned Size) {
  return [TypeIdx, Size](const LegalityQuery &Query) {
    const LLT Ty = Query.Types[TypeIdx];
    // Anything that is not a fixed-length vector is rejected up front.
    if (!Ty.isFixedVector())
      return false;
    return Ty.getNumElements() <= Size;
  };
}
177+
158178
LegalityPredicate LegalityPredicates::scalarOrEltWiderThan(unsigned TypeIdx,
159179
unsigned Size) {
160180
return [=](const LegalityQuery &Query) {

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1781,33 +1781,57 @@ bool SPIRVInstructionSelector::selectUnmergeValues(MachineInstr &I) const {
17811781
unsigned ArgI = I.getNumOperands() - 1;
17821782
Register SrcReg =
17831783
I.getOperand(ArgI).isReg() ? I.getOperand(ArgI).getReg() : Register(0);
1784-
SPIRVType *DefType =
1784+
SPIRVType *SrcType =
17851785
SrcReg.isValid() ? GR.getSPIRVTypeForVReg(SrcReg) : nullptr;
1786-
if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector)
1786+
if (!SrcType || SrcType->getOpcode() != SPIRV::OpTypeVector)
17871787
report_fatal_error(
17881788
"cannot select G_UNMERGE_VALUES with a non-vector argument");
17891789

17901790
SPIRVType *ScalarType =
1791-
GR.getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
1791+
GR.getSPIRVTypeForVReg(SrcType->getOperand(1).getReg());
17921792
MachineBasicBlock &BB = *I.getParent();
17931793
bool Res = false;
1794+
unsigned CurrentIndex = 0;
17941795
for (unsigned i = 0; i < I.getNumDefs(); ++i) {
17951796
Register ResVReg = I.getOperand(i).getReg();
17961797
SPIRVType *ResType = GR.getSPIRVTypeForVReg(ResVReg);
17971798
if (!ResType) {
1798-
// There was no "assign type" actions, let's fix this now
1799-
ResType = ScalarType;
1799+
LLT ResLLT = MRI->getType(ResVReg);
1800+
assert(ResLLT.isValid());
1801+
if (ResLLT.isVector()) {
1802+
ResType = GR.getOrCreateSPIRVVectorType(
1803+
ScalarType, ResLLT.getNumElements(), I, TII);
1804+
} else {
1805+
ResType = ScalarType;
1806+
}
18001807
MRI->setRegClass(ResVReg, GR.getRegClass(ResType));
1801-
MRI->setType(ResVReg, LLT::scalar(GR.getScalarOrVectorBitWidth(ResType)));
18021808
GR.assignSPIRVTypeToVReg(ResType, ResVReg, *GR.CurMF);
18031809
}
1804-
auto MIB =
1805-
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
1806-
.addDef(ResVReg)
1807-
.addUse(GR.getSPIRVTypeID(ResType))
1808-
.addUse(SrcReg)
1809-
.addImm(static_cast<int64_t>(i));
1810-
Res |= MIB.constrainAllUses(TII, TRI, RBI);
1810+
1811+
if (ResType->getOpcode() == SPIRV::OpTypeVector) {
1812+
Register UndefReg = GR.getOrCreateUndef(I, SrcType, TII);
1813+
auto MIB =
1814+
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpVectorShuffle))
1815+
.addDef(ResVReg)
1816+
.addUse(GR.getSPIRVTypeID(ResType))
1817+
.addUse(SrcReg)
1818+
.addUse(UndefReg);
1819+
unsigned NumElements = GR.getScalarOrVectorComponentCount(ResType);
1820+
for (unsigned j = 0; j < NumElements; ++j) {
1821+
MIB.addImm(CurrentIndex + j);
1822+
}
1823+
CurrentIndex += NumElements;
1824+
Res |= MIB.constrainAllUses(TII, TRI, RBI);
1825+
} else {
1826+
auto MIB =
1827+
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
1828+
.addDef(ResVReg)
1829+
.addUse(GR.getSPIRVTypeID(ResType))
1830+
.addUse(SrcReg)
1831+
.addImm(CurrentIndex);
1832+
CurrentIndex++;
1833+
Res |= MIB.constrainAllUses(TII, TRI, RBI);
1834+
}
18111835
}
18121836
return Res;
18131837
}

llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp

Lines changed: 159 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,22 @@
1414
#include "SPIRV.h"
1515
#include "SPIRVGlobalRegistry.h"
1616
#include "SPIRVSubtarget.h"
17+
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
1718
#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
1819
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
1920
#include "llvm/CodeGen/MachineInstr.h"
2021
#include "llvm/CodeGen/MachineRegisterInfo.h"
2122
#include "llvm/CodeGen/TargetOpcodes.h"
23+
#include "llvm/IR/IntrinsicsSPIRV.h"
24+
#include "llvm/Support/Debug.h"
25+
#include "llvm/Support/MathExtras.h"
2226

2327
using namespace llvm;
2428
using namespace llvm::LegalizeActions;
2529
using namespace llvm::LegalityPredicates;
2630

31+
#define DEBUG_TYPE "spirv-legalizer"
32+
2733
LegalityPredicate typeOfExtendedScalars(unsigned TypeIdx, bool IsExtendedInts) {
2834
return [IsExtendedInts, TypeIdx](const LegalityQuery &Query) {
2935
const LLT Ty = Query.Types[TypeIdx];
@@ -101,6 +107,10 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
101107
v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1,
102108
v16s8, v16s16, v16s32, v16s64};
103109

110+
auto allShaderVectors = {v2s1, v2s8, v2s16, v2s32, v2s64,
111+
v3s1, v3s8, v3s16, v3s32, v3s64,
112+
v4s1, v4s8, v4s16, v4s32, v4s64};
113+
104114
auto allScalarsAndVectors = {
105115
s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,
106116
v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64,
@@ -126,6 +136,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
126136

127137
auto allPtrs = {p0, p1, p2, p3, p4, p5, p6, p7, p8, p10, p11, p12};
128138

139+
auto &allowedVectorTypes = ST.isShader() ? allShaderVectors : allVectors;
140+
129141
bool IsExtendedInts =
130142
ST.canUseExtension(
131143
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
@@ -148,14 +160,65 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
148160
return IsExtendedInts && Ty.isValid();
149161
};
150162

151-
for (auto Opc : getTypeFoldingSupportedOpcodes())
152-
getActionDefinitionsBuilder(Opc).custom();
163+
uint32_t MaxVectorSize = ST.isShader() ? 4 : 16;
153164

154-
getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();
165+
for (auto Opc : getTypeFoldingSupportedOpcodes()) {
166+
if (Opc != G_EXTRACT_VECTOR_ELT)
167+
getActionDefinitionsBuilder(Opc).custom();
168+
}
155169

156-
// TODO: add proper rules for vectors legalization.
157-
getActionDefinitionsBuilder(
158-
{G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR})
170+
getActionDefinitionsBuilder(G_INTRINSIC_W_SIDE_EFFECTS).custom();
171+
172+
getActionDefinitionsBuilder(G_SHUFFLE_VECTOR)
173+
.legalForCartesianProduct(allowedVectorTypes, allowedVectorTypes)
174+
.moreElementsToNextPow2(0)
175+
.lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize))
176+
.moreElementsToNextPow2(1)
177+
.lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize))
178+
.alwaysLegal();
179+
180+
getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT)
181+
.legalIf(vectorElementCountIsLessThanOrEqualTo(1, MaxVectorSize))
182+
.moreElementsToNextPow2(1)
183+
.fewerElementsIf(vectorElementCountIsGreaterThan(1, MaxVectorSize),
184+
LegalizeMutations::changeElementCountTo(
185+
1, ElementCount::getFixed(MaxVectorSize)))
186+
.custom();
187+
188+
// Illegal G_UNMERGE_VALUES instructions should be handled
189+
// during the combine phase.
190+
getActionDefinitionsBuilder(G_BUILD_VECTOR)
191+
.legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
192+
.fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
193+
LegalizeMutations::changeElementCountTo(
194+
0, ElementCount::getFixed(MaxVectorSize)));
195+
196+
// When entering the legalizer, there should be no G_BITCAST instructions.
197+
// They should all be calls to the `spv_bitcast` intrinsic. The call to
198+
// the intrinsic will be converted to a G_BITCAST during legalization if
199+
// the vectors are not legal. After using the rules to legalize a G_BITCAST,
200+
// we turn it back into a call to the intrinsic with a custom rule to avoid
201+
// potential machine verifier failures.
202+
getActionDefinitionsBuilder(G_BITCAST)
203+
.moreElementsToNextPow2(0)
204+
.moreElementsToNextPow2(1)
205+
.fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
206+
LegalizeMutations::changeElementCountTo(
207+
0, ElementCount::getFixed(MaxVectorSize)))
208+
.lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize))
209+
.custom();
210+
211+
getActionDefinitionsBuilder(G_CONCAT_VECTORS)
212+
.legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
213+
.moreElementsToNextPow2(0)
214+
.lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize))
215+
.alwaysLegal();
216+
217+
// Splats that already fit in MaxVectorSize elements are legal. Longer ones
// are first widened to a power-of-two length and then split into pieces of
// MaxVectorSize elements. The split must change the element *count*, as the
// sibling vector rules do: changeElementSizeTo takes a source type index as
// its second argument, so passing MaxVectorSize there would be read as a type
// index rather than a length.
getActionDefinitionsBuilder(G_SPLAT_VECTOR)
    .legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
    .moreElementsToNextPow2(0)
    .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
                     LegalizeMutations::changeElementCountTo(
                         0, ElementCount::getFixed(MaxVectorSize)))
    .alwaysLegal();
160223

161224
// Vector Reduction Operations
@@ -164,17 +227,18 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
164227
G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN,
165228
G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM,
166229
G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR})
167-
.legalFor(allVectors)
230+
.legalFor(allowedVectorTypes)
168231
.scalarize(1)
169232
.lower();
170233

171234
getActionDefinitionsBuilder({G_VECREDUCE_SEQ_FADD, G_VECREDUCE_SEQ_FMUL})
172235
.scalarize(2)
173236
.lower();
174237

175-
// Merge/Unmerge
176-
// TODO: add proper legalization rules.
177-
getActionDefinitionsBuilder(G_UNMERGE_VALUES).alwaysLegal();
238+
// Illegal G_UNMERGE_VALUES instructions should be handled
239+
// during the combine phase.
240+
getActionDefinitionsBuilder(G_UNMERGE_VALUES)
241+
.legalIf(vectorElementCountIsLessThanOrEqualTo(1, MaxVectorSize));
178242

179243
getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
180244
.legalIf(all(typeInSet(0, allPtrs), typeInSet(1, allPtrs)));
@@ -228,7 +292,14 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
228292
all(typeInSet(0, allPtrsScalarsAndVectors),
229293
typeInSet(1, allPtrsScalarsAndVectors)));
230294

231-
getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}).alwaysLegal();
295+
getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE})
296+
.legalFor({s1})
297+
.legalFor(allFloatAndIntScalarsAndPtrs)
298+
.legalFor(allowedVectorTypes)
299+
.moreElementsToNextPow2(0)
300+
.fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
301+
LegalizeMutations::changeElementCountTo(
302+
0, ElementCount::getFixed(MaxVectorSize)));
232303

233304
getActionDefinitionsBuilder({G_STACKSAVE, G_STACKRESTORE}).alwaysLegal();
234305

@@ -287,6 +358,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
287358
// Pointer-handling.
288359
getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
289360

361+
getActionDefinitionsBuilder(G_GLOBAL_VALUE).legalFor(allPtrs);
362+
290363
// Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
291364
getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});
292365

@@ -374,6 +447,11 @@ bool SPIRVLegalizerInfo::legalizeCustom(
374447
default:
375448
// TODO: implement legalization for other opcodes.
376449
return true;
450+
case TargetOpcode::G_BITCAST:
451+
return legalizeBitcast(Helper, MI);
452+
case TargetOpcode::G_INTRINSIC:
453+
case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
454+
return legalizeIntrinsic(Helper, MI);
377455
case TargetOpcode::G_IS_FPCLASS:
378456
return legalizeIsFPClass(Helper, MI, LocObserver);
379457
case TargetOpcode::G_ICMP: {
@@ -400,6 +478,76 @@ bool SPIRVLegalizerInfo::legalizeCustom(
400478
}
401479
}
402480

481+
/// Custom legalization hook for intrinsic calls.
///
/// Only `spv_bitcast` is inspected. When its result or source register has a
/// vector type that needs legalization, the call is replaced with a G_BITCAST
/// so that the generic G_BITCAST rules can be applied to it. Every other
/// intrinsic, and every spv_bitcast whose types need no work, is left as is.
///
/// \returns true in all cases.
bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
                                           MachineInstr &MI) const {
  LLVM_DEBUG(dbgs() << "legalizeIntrinsic: " << MI);

  // Nothing to do for intrinsics other than spv_bitcast.
  if (cast<GIntrinsic>(MI).getIntrinsicID() != Intrinsic::spv_bitcast)
    return true;

  LLVM_DEBUG(dbgs() << "Found a bitcast instruction\n");

  MachineIRBuilder &Builder = Helper.MIRBuilder;
  MachineRegisterInfo &RegInfo = *Builder.getMRI();
  const SPIRVSubtarget &Subtarget = MI.getMF()->getSubtarget<SPIRVSubtarget>();

  // Operand 0 is the result; operand 2 is the value being cast.
  const Register Dst = MI.getOperand(0).getReg();
  const Register Src = MI.getOperand(2).getReg();

  const int32_t MaxElts = Subtarget.isShader() ? 4 : 16;

  // A vector type needs legalization when it has more than 4 elements and a
  // non-power-of-two length, or when it is longer than the subtarget's limit.
  auto NeedsLegalization = [MaxElts](LLT Ty) {
    if (!Ty.isVector())
      return false;
    const auto NumElts = Ty.getNumElements();
    const bool OddLongVector = NumElts > 4 && !isPowerOf2_32(NumElts);
    return OddLongVector || NumElts > MaxElts;
  };

  const bool DstIllegal = NeedsLegalization(RegInfo.getType(Dst));
  const bool SrcIllegal = NeedsLegalization(RegInfo.getType(Src));

  // If an spv_bitcast needs to be legalized, we convert it to G_BITCAST to
  // allow using the generic legalization rules.
  if (DstIllegal || SrcIllegal) {
    LLVM_DEBUG(dbgs() << "Replacing with a G_BITCAST\n");
    Builder.buildBitcast(Dst, Src);
    MI.eraseFromParent();
  }
  return true;
}
535+
536+
/// Custom legalization for G_BITCAST: rewrites the instruction as a call to
/// the `spv_bitcast` intrinsic and erases the original.
///
/// By the time this runs the G_BITCAST operates on allowed vector types.
/// Going back to the intrinsic avoids verifier problems when the source and
/// result have the same register type; the SPIR-V types attached to the two
/// registers may still differ in that situation.
///
/// \returns true in all cases.
bool SPIRVLegalizerInfo::legalizeBitcast(LegalizerHelper &Helper,
                                         MachineInstr &MI) const {
  const Register Result = MI.getOperand(0).getReg();
  const Register Source = MI.getOperand(1).getReg();

  // The intrinsic defines the same result register and takes the source as
  // its only use operand.
  SmallVector<Register, 1> Results = {Result};
  Helper.MIRBuilder.buildIntrinsic(Intrinsic::spv_bitcast, Results)
      .addUse(Source);

  MI.eraseFromParent();
  return true;
}
550+
403551
// Note this code was copied from LegalizerHelper::lowerISFPCLASS and adjusted
404552
// to ensure that all instructions created during the lowering have SPIR-V types
405553
// assigned to them.

llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,15 @@ class SPIRVLegalizerInfo : public LegalizerInfo {
2929
public:
3030
bool legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI,
3131
LostDebugLocObserver &LocObserver) const override;
32+
bool legalizeIntrinsic(LegalizerHelper &Helper,
33+
MachineInstr &MI) const override;
34+
3235
SPIRVLegalizerInfo(const SPIRVSubtarget &ST);
3336

3437
private:
3538
bool legalizeIsFPClass(LegalizerHelper &Helper, MachineInstr &MI,
3639
LostDebugLocObserver &LocObserver) const;
40+
bool legalizeBitcast(LegalizerHelper &Helper, MachineInstr &MI) const;
3741
};
3842
} // namespace llvm
3943
#endif // LLVM_LIB_TARGET_SPIRV_SPIRVMACHINELEGALIZER_H

0 commit comments

Comments
 (0)