Skip to content

Commit 733b7bd

Browse files
spall authored and sivan-shani committed
[HLSL] Implement HLSL Aggregate splatting (llvm#118992)
Implement HLSL Aggregate Splat casting that handles splatting for arrays and structs, and for vectors when splatting from a vec1. Closes llvm#100609 and closes llvm#100619. Depends on llvm#118842.
1 parent 8921ea1 commit 733b7bd

18 files changed

+305
-4
lines changed

clang/include/clang/AST/OperationKinds.def

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -370,6 +370,9 @@ CAST_OPERATION(HLSLArrayRValue)
370370
// Aggregate by Value cast (HLSL only).
371371
CAST_OPERATION(HLSLElementwiseCast)
372372

373+
// Splat cast for Aggregates (HLSL only).
374+
CAST_OPERATION(HLSLAggregateSplatCast)
375+
373376
//===- Binary Operations -------------------------------------------------===//
374377
// Operators listed in order of precedence.
375378
// Note that additions to this should also update the StmtVisitor class,

clang/include/clang/Sema/SemaHLSL.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -144,6 +144,7 @@ class SemaHLSL : public SemaBase {
144144
bool CanPerformScalarCast(QualType SrcTy, QualType DestTy);
145145
bool ContainsBitField(QualType BaseTy);
146146
bool CanPerformElementwiseCast(Expr *Src, QualType DestType);
147+
bool CanPerformAggregateSplatCast(Expr *Src, QualType DestType);
147148
ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg);
148149

149150
QualType getInoutParameterType(QualType Ty);

clang/lib/AST/Expr.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1968,6 +1968,7 @@ bool CastExpr::CastConsistency() const {
19681968
case CK_HLSLArrayRValue:
19691969
case CK_HLSLVectorTruncation:
19701970
case CK_HLSLElementwiseCast:
1971+
case CK_HLSLAggregateSplatCast:
19711972
CheckNoBasePath:
19721973
assert(path_empty() && "Cast kind should not have a base path!");
19731974
break;

clang/lib/AST/ExprConstant.cpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15025,6 +15025,7 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
1502515025
case CK_FixedPointCast:
1502615026
case CK_IntegralToFixedPoint:
1502715027
case CK_MatrixCast:
15028+
case CK_HLSLAggregateSplatCast:
1502815029
llvm_unreachable("invalid cast kind for integral value");
1502915030

1503015031
case CK_BitCast:
@@ -15903,6 +15904,7 @@ bool ComplexExprEvaluator::VisitCastExpr(const CastExpr *E) {
1590315904
case CK_MatrixCast:
1590415905
case CK_HLSLVectorTruncation:
1590515906
case CK_HLSLElementwiseCast:
15907+
case CK_HLSLAggregateSplatCast:
1590615908
llvm_unreachable("invalid cast kind for complex value");
1590715909

1590815910
case CK_LValueToRValue:

clang/lib/CodeGen/CGExpr.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5339,6 +5339,7 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) {
53395339
case CK_HLSLVectorTruncation:
53405340
case CK_HLSLArrayRValue:
53415341
case CK_HLSLElementwiseCast:
5342+
case CK_HLSLAggregateSplatCast:
53425343
return EmitUnsupportedLValue(E, "unexpected cast lvalue");
53435344

53445345
case CK_Dependent:

clang/lib/CodeGen/CGExprAgg.cpp

Lines changed: 39 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -498,6 +498,31 @@ static bool isTrivialFiller(Expr *E) {
498498
return false;
499499
}
500500

501+
// Emit an HLSL aggregate splat cast: convert the single scalar SrcVal (of
// type SrcTy) to the type of every flattened element of DestTy and write the
// converted value into the matching location of DestVal.
//
// CGF     - code generation context used to flatten the destination and to
//           emit the conversions, loads and stores.
// DestVal - address of the destination aggregate being filled.
// DestTy  - type of the destination aggregate.
// SrcVal  - the scalar value being splatted.
// SrcTy   - type of SrcVal; asserted to be a scalar type.
// Loc     - source location forwarded to EmitScalarConversion.
static void EmitHLSLAggregateSplatCast(CodeGenFunction &CGF, Address DestVal,
502+
QualType DestTy, llvm::Value *SrcVal,
503+
QualType SrcTy, SourceLocation Loc) {
504+
// Flatten our destination
505+
SmallVector<QualType> DestTypes; // Flattened type
506+
SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
507+
// ^^ Flattened accesses to DestVal we want to store into. DestTypes[I] is
// the element type used for the access in StoreGEPList[I].
508+
CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes);
509+
510+
assert(SrcTy->isScalarType() && "Invalid HLSL Aggregate splat cast.");
511+
for (unsigned I = 0, Size = StoreGEPList.size(); I < Size; ++I) {
512+
llvm::Value *Cast =
513+
CGF.EmitScalarConversion(SrcVal, SrcTy, DestTypes[I], Loc);
514+
515+
// store back. When the access carries a non-null index, the converted
// value is inserted at that index into the value currently loaded from the
// address, and the updated value is stored (read-modify-write); otherwise
// the converted value is stored to the address directly.
516+
llvm::Value *Idx = StoreGEPList[I].second;
517+
if (Idx) {
518+
llvm::Value *V =
519+
CGF.Builder.CreateLoad(StoreGEPList[I].first, "load.for.insert");
520+
Cast = CGF.Builder.CreateInsertElement(V, Cast, Idx);
521+
}
522+
CGF.Builder.CreateStore(Cast, StoreGEPList[I].first);
523+
}
524+
}
525+
501526
// emit a flat cast where the RHS is a scalar, including vector
502527
static void EmitHLSLScalarFlatCast(CodeGenFunction &CGF, Address DestVal,
503528
QualType DestTy, llvm::Value *SrcVal,
@@ -970,6 +995,19 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
970995
case CK_HLSLArrayRValue:
971996
Visit(E->getSubExpr());
972997
break;
998+
case CK_HLSLAggregateSplatCast: {
999+
Expr *Src = E->getSubExpr();
1000+
QualType SrcTy = Src->getType();
1001+
RValue RV = CGF.EmitAnyExpr(Src);
1002+
QualType DestTy = E->getType();
1003+
Address DestVal = Dest.getAddress();
1004+
SourceLocation Loc = E->getExprLoc();
1005+
1006+
assert(RV.isScalar() && "RHS of HLSL splat cast must be a scalar.");
1007+
llvm::Value *SrcVal = RV.getScalarVal();
1008+
EmitHLSLAggregateSplatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
1009+
break;
1010+
}
9731011
case CK_HLSLElementwiseCast: {
9741012
Expr *Src = E->getSubExpr();
9751013
QualType SrcTy = Src->getType();
@@ -1560,6 +1598,7 @@ static bool castPreservesZero(const CastExpr *CE) {
15601598
case CK_AtomicToNonAtomic:
15611599
case CK_HLSLVectorTruncation:
15621600
case CK_HLSLElementwiseCast:
1601+
case CK_HLSLAggregateSplatCast:
15631602
return true;
15641603

15651604
case CK_BaseToDerivedMemberPointer:

clang/lib/CodeGen/CGExprComplex.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -611,6 +611,7 @@ ComplexPairTy ComplexExprEmitter::EmitCast(CastKind CK, Expr *Op,
611611
case CK_HLSLVectorTruncation:
612612
case CK_HLSLArrayRValue:
613613
case CK_HLSLElementwiseCast:
614+
case CK_HLSLAggregateSplatCast:
614615
llvm_unreachable("invalid cast kind for complex value");
615616

616617
case CK_FloatingRealToComplex:

clang/lib/CodeGen/CGExprConstant.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1336,6 +1336,7 @@ class ConstExprEmitter
13361336
case CK_HLSLVectorTruncation:
13371337
case CK_HLSLArrayRValue:
13381338
case CK_HLSLElementwiseCast:
1339+
case CK_HLSLAggregateSplatCast:
13391340
return nullptr;
13401341
}
13411342
llvm_unreachable("Invalid CastKind");

clang/lib/CodeGen/CGExprScalar.cpp

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2643,6 +2643,11 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
26432643
return EmitScalarConversion(Visit(E), E->getType(), DestTy,
26442644
CE->getExprLoc());
26452645
}
2646+
// CK_HLSLAggregateSplatCast only handles splatting to vectors from a vec1
2647+
// Casts were inserted in Sema to Cast the Src Expr to a Scalar and
2648+
// To perform any necessary Scalar Cast, so this Cast can be handled
2649+
// by the regular Vector Splat cast code.
2650+
case CK_HLSLAggregateSplatCast:
26462651
case CK_VectorSplat: {
26472652
llvm::Type *DstTy = ConvertType(DestTy);
26482653
Value *Elt = Visit(const_cast<Expr *>(E));
@@ -2800,7 +2805,7 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
28002805
SourceLocation Loc = CE->getExprLoc();
28012806
QualType SrcTy = E->getType();
28022807

2803-
assert(RV.isAggregate() && "Not a valid HLSL Flat Cast.");
2808+
assert(RV.isAggregate() && "Not a valid HLSL Elementwise Cast.");
28042809
// RHS is an aggregate
28052810
Address SrcVal = RV.getAggregateAddress();
28062811
return EmitHLSLElementwiseCast(CGF, SrcVal, SrcTy, DestTy, Loc);

clang/lib/Edit/RewriteObjCFoundationAPI.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,6 +1086,7 @@ static bool rewriteToNumericBoxedExpression(const ObjCMessageExpr *Msg,
10861086

10871087
case CK_HLSLVectorTruncation:
10881088
case CK_HLSLElementwiseCast:
1089+
case CK_HLSLAggregateSplatCast:
10891090
llvm_unreachable("HLSL-specific cast in Objective-C?");
10901091
break;
10911092

0 commit comments

Comments
 (0)