@@ -498,6 +498,31 @@ static bool isTrivialFiller(Expr *E) {
498498 return false ;
499499}
500500
501+ static void EmitHLSLAggregateSplatCast (CodeGenFunction &CGF, Address DestVal,
502+ QualType DestTy, llvm::Value *SrcVal,
503+ QualType SrcTy, SourceLocation Loc) {
504+ // Flatten our destination
505+ SmallVector<QualType> DestTypes; // Flattened type
506+ SmallVector<std::pair<Address, llvm::Value *>, 16 > StoreGEPList;
507+ // ^^ Flattened accesses to DestVal we want to store into
508+ CGF.FlattenAccessAndType (DestVal, DestTy, StoreGEPList, DestTypes);
509+
510+ assert (SrcTy->isScalarType () && " Invalid HLSL Aggregate splat cast." );
511+ for (unsigned I = 0 , Size = StoreGEPList.size (); I < Size; ++I) {
512+ llvm::Value *Cast =
513+ CGF.EmitScalarConversion (SrcVal, SrcTy, DestTypes[I], Loc);
514+
515+ // store back
516+ llvm::Value *Idx = StoreGEPList[I].second ;
517+ if (Idx) {
518+ llvm::Value *V =
519+ CGF.Builder .CreateLoad (StoreGEPList[I].first , " load.for.insert" );
520+ Cast = CGF.Builder .CreateInsertElement (V, Cast, Idx);
521+ }
522+ CGF.Builder .CreateStore (Cast, StoreGEPList[I].first );
523+ }
524+ }
525+
501526// emit a flat cast where the RHS is a scalar, including vector
502527static void EmitHLSLScalarFlatCast (CodeGenFunction &CGF, Address DestVal,
503528 QualType DestTy, llvm::Value *SrcVal,
@@ -970,6 +995,19 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
970995 case CK_HLSLArrayRValue:
971996 Visit (E->getSubExpr ());
972997 break ;
998+ case CK_HLSLAggregateSplatCast: {
999+ Expr *Src = E->getSubExpr ();
1000+ QualType SrcTy = Src->getType ();
1001+ RValue RV = CGF.EmitAnyExpr (Src);
1002+ QualType DestTy = E->getType ();
1003+ Address DestVal = Dest.getAddress ();
1004+ SourceLocation Loc = E->getExprLoc ();
1005+
1006+ assert (RV.isScalar () && " RHS of HLSL splat cast must be a scalar." );
1007+ llvm::Value *SrcVal = RV.getScalarVal ();
1008+ EmitHLSLAggregateSplatCast (CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
1009+ break ;
1010+ }
9731011 case CK_HLSLElementwiseCast: {
9741012 Expr *Src = E->getSubExpr ();
9751013 QualType SrcTy = Src->getType ();
@@ -1560,6 +1598,7 @@ static bool castPreservesZero(const CastExpr *CE) {
15601598 case CK_AtomicToNonAtomic:
15611599 case CK_HLSLVectorTruncation:
15621600 case CK_HLSLElementwiseCast:
1601+ case CK_HLSLAggregateSplatCast:
15631602 return true ;
15641603
15651604 case CK_BaseToDerivedMemberPointer:
0 commit comments