4040#include < utility>
4141
4242using namespace clang ;
43+ using llvm::dxil::ResourceClass;
44+
45+ enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
46+
47+ static RegisterType getRegisterType (ResourceClass RC) {
48+ switch (RC) {
49+ case ResourceClass::SRV:
50+ return RegisterType::SRV;
51+ case ResourceClass::UAV:
52+ return RegisterType::UAV;
53+ case ResourceClass::CBuffer:
54+ return RegisterType::CBuffer;
55+ case ResourceClass::Sampler:
56+ return RegisterType::Sampler;
57+ }
58+ llvm_unreachable (" unexpected ResourceClass value" );
59+ }
60+
61+ static RegisterType getRegisterType (StringRef Slot) {
62+ switch (Slot[0 ]) {
63+ case ' t' :
64+ case ' T' :
65+ return RegisterType::SRV;
66+ case ' u' :
67+ case ' U' :
68+ return RegisterType::UAV;
69+ case ' b' :
70+ case ' B' :
71+ return RegisterType::CBuffer;
72+ case ' s' :
73+ case ' S' :
74+ return RegisterType::Sampler;
75+ case ' c' :
76+ case ' C' :
77+ return RegisterType::C;
78+ case ' i' :
79+ case ' I' :
80+ return RegisterType::I;
81+ default :
82+ return RegisterType::Invalid;
83+ }
84+ }
4385
4486SemaHLSL::SemaHLSL (Sema &S) : SemaBase(S) {}
4587
@@ -586,8 +628,7 @@ bool clang::CreateHLSLAttributedResourceType(
586628 LocEnd = A->getRange ().getEnd ();
587629 switch (A->getKind ()) {
588630 case attr::HLSLResourceClass: {
589- llvm::dxil::ResourceClass RC =
590- cast<HLSLResourceClassAttr>(A)->getResourceClass ();
631+ ResourceClass RC = cast<HLSLResourceClassAttr>(A)->getResourceClass ();
591632 if (HasResourceClass) {
592633 S.Diag (A->getLocation (), ResAttrs.ResourceClass == RC
593634 ? diag::warn_duplicate_attribute_exact
@@ -672,7 +713,7 @@ bool SemaHLSL::handleResourceTypeAttr(const ParsedAttr &AL) {
672713 SourceLocation ArgLoc = Loc->Loc ;
673714
674715 // Validate resource class value
675- llvm::dxil:: ResourceClass RC;
716+ ResourceClass RC;
676717 if (!HLSLResourceClassAttr::ConvertStrToResourceClass (Identifier, RC)) {
677718 Diag (ArgLoc, diag::warn_attribute_type_not_supported)
678719 << " ResourceClass" << Identifier;
@@ -750,28 +791,6 @@ SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
750791 return LocInfo;
751792}
752793
753- struct RegisterBindingFlags {
754- bool Resource = false ;
755- bool UDT = false ;
756- bool Other = false ;
757- bool Basic = false ;
758-
759- bool SRV = false ;
760- bool UAV = false ;
761- bool CBV = false ;
762- bool Sampler = false ;
763-
764- bool ContainsNumeric = false ;
765- bool DefaultGlobals = false ;
766-
767- // used only when Resource == true
768- std::optional<llvm::dxil::ResourceClass> ResourceClass;
769- };
770-
771- static bool isDeclaredWithinCOrTBuffer (const Decl *TheDecl) {
772- return TheDecl && isa<HLSLBufferDecl>(TheDecl->getDeclContext ());
773- }
774-
775794// get the record decl from a var decl that we expect
776795// represents a resource
777796static CXXRecordDecl *getRecordDeclFromVarDecl (VarDecl *VD) {
@@ -786,24 +805,6 @@ static CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) {
786805 return TheRecordDecl;
787806}
788807
789- static void updateResourceClassFlagsFromDeclResourceClass (
790- RegisterBindingFlags &Flags, llvm::hlsl::ResourceClass DeclResourceClass) {
791- switch (DeclResourceClass) {
792- case llvm::hlsl::ResourceClass::SRV:
793- Flags.SRV = true ;
794- break ;
795- case llvm::hlsl::ResourceClass::UAV:
796- Flags.UAV = true ;
797- break ;
798- case llvm::hlsl::ResourceClass::CBuffer:
799- Flags.CBV = true ;
800- break ;
801- case llvm::hlsl::ResourceClass::Sampler:
802- Flags.Sampler = true ;
803- break ;
804- }
805- }
806-
807808const HLSLAttributedResourceType *
808809findAttributedResourceTypeOnField (VarDecl *VD) {
809810 assert (VD != nullptr && " expected VarDecl" );
@@ -817,8 +818,10 @@ findAttributedResourceTypeOnField(VarDecl *VD) {
817818 return nullptr ;
818819}
819820
820- static void updateResourceClassFlagsFromRecordType (RegisterBindingFlags &Flags,
821- const RecordType *RT) {
821+ // Iterate over RecordType fields and return true if any of them matched the
822+ // register type
823+ static bool ContainsResourceForRegisterType (Sema &S, const RecordType *RT,
824+ RegisterType RegType) {
822825 llvm::SmallVector<const Type *> TypesToScan;
823826 TypesToScan.emplace_back (RT);
824827
@@ -827,8 +830,8 @@ static void updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags,
827830 while (T->isArrayType ())
828831 T = T->getArrayElementTypeNoTypeQual ();
829832 if (T->isIntegralOrEnumerationType () || T->isFloatingType ()) {
830- Flags. ContainsNumeric = true ;
831- continue ;
833+ if (RegType == RegisterType::C)
834+ return true ;
832835 }
833836 const RecordType *RT = T->getAs <RecordType>();
834837 if (!RT)
@@ -839,100 +842,84 @@ static void updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags,
839842 const Type *FieldTy = FD->getType ().getTypePtr ();
840843 if (const HLSLAttributedResourceType *AttrResType =
841844 dyn_cast<HLSLAttributedResourceType>(FieldTy)) {
842- updateResourceClassFlagsFromDeclResourceClass (
843- Flags, AttrResType->getAttrs ().ResourceClass );
844- continue ;
845+ ResourceClass RC = AttrResType->getAttrs ().ResourceClass ;
846+ if (getRegisterType (RC) == RegType)
847+ return true ;
848+ } else {
849+ TypesToScan.emplace_back (FD->getType ().getTypePtr ());
845850 }
846- TypesToScan.emplace_back (FD->getType ().getTypePtr ());
847851 }
848852 }
853+ return false ;
849854}
850855
851- static RegisterBindingFlags HLSLFillRegisterBindingFlags (Sema &S,
852- Decl *TheDecl) {
853- RegisterBindingFlags Flags;
856+ static void CheckContainsResourceForRegisterType (Sema &S,
857+ SourceLocation &ArgLoc,
858+ Decl *D, RegisterType RegType,
859+ bool SpecifiedSpace) {
860+ int RegTypeNum = static_cast <int >(RegType);
854861
855862 // check if the decl type is groupshared
856- if (TheDecl ->hasAttr <HLSLGroupSharedAddressSpaceAttr>()) {
857- Flags. Other = true ;
858- return Flags ;
863+ if (D ->hasAttr <HLSLGroupSharedAddressSpaceAttr>()) {
864+ S. Diag (ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum ;
865+ return ;
859866 }
860867
861868 // Cbuffers and Tbuffers are HLSLBufferDecl types
862- if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl)) {
863- Flags.Resource = true ;
864- Flags.ResourceClass = CBufferOrTBuffer->isCBuffer ()
865- ? llvm::dxil::ResourceClass::CBuffer
866- : llvm::dxil::ResourceClass::SRV;
869+ if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) {
870+ ResourceClass RC = CBufferOrTBuffer->isCBuffer () ? ResourceClass::CBuffer
871+ : ResourceClass::SRV;
872+ if (RegType != getRegisterType (RC))
873+ S.Diag (D->getLocation (), diag::err_hlsl_binding_type_mismatch)
874+ << RegTypeNum;
875+ return ;
867876 }
877+
868878 // Samplers, UAVs, and SRVs are VarDecl types
869- else if (VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl)) {
870- if (const HLSLAttributedResourceType *AttrResType =
871- findAttributedResourceTypeOnField (TheVarDecl)) {
872- Flags.Resource = true ;
873- Flags.ResourceClass = AttrResType->getAttrs ().ResourceClass ;
874- } else {
875- const clang::Type *TheBaseType = TheVarDecl->getType ().getTypePtr ();
876- while (TheBaseType->isArrayType ())
877- TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual ();
878-
879- if (TheBaseType->isArithmeticType ()) {
880- Flags.Basic = true ;
881- if (!isDeclaredWithinCOrTBuffer (TheDecl) &&
882- (TheBaseType->isIntegralType (S.getASTContext ()) ||
883- TheBaseType->isFloatingType ()))
884- Flags.DefaultGlobals = true ;
885- } else if (TheBaseType->isRecordType ()) {
886- Flags.UDT = true ;
887- const RecordType *TheRecordTy = TheBaseType->getAs <RecordType>();
888- updateResourceClassFlagsFromRecordType (Flags, TheRecordTy);
889- } else
890- Flags.Other = true ;
891- }
892- } else {
893- llvm_unreachable (" expected be VarDecl or HLSLBufferDecl" );
879+ assert (isa<VarDecl>(D) && " D is expected to be VarDecl or HLSLBufferDecl" );
880+ VarDecl *VD = cast<VarDecl>(D);
881+
882+ // Resource
883+ if (const HLSLAttributedResourceType *AttrResType =
884+ findAttributedResourceTypeOnField (VD)) {
885+ if (RegType != getRegisterType (AttrResType->getAttrs ().ResourceClass ))
886+ S.Diag (D->getLocation (), diag::err_hlsl_binding_type_mismatch)
887+ << RegTypeNum;
888+ return ;
894889 }
895- return Flags;
896- }
897890
898- enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
891+ const clang::Type *Ty = VD->getType ().getTypePtr ();
892+ while (Ty->isArrayType ())
893+ Ty = Ty->getArrayElementTypeNoTypeQual ();
899894
900- static RegisterType getRegisterType (llvm::dxil::ResourceClass RC) {
901- switch (RC) {
902- case llvm::dxil::ResourceClass::SRV:
903- return RegisterType::SRV;
904- case llvm::dxil::ResourceClass::UAV:
905- return RegisterType::UAV;
906- case llvm::dxil::ResourceClass::CBuffer:
907- return RegisterType::CBuffer;
908- case llvm::dxil::ResourceClass::Sampler:
909- return RegisterType::Sampler;
910- }
911- llvm_unreachable (" unexpected ResourceClass value" );
912- }
895+ // Basic types
896+ if (Ty->isArithmeticType ()) {
897+ bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(D->getDeclContext ());
898+ if (SpecifiedSpace && !DeclaredInCOrTBuffer)
899+ S.Diag (ArgLoc, diag::err_hlsl_space_on_global_constant);
913900
914- static RegisterType getRegisterType (StringRef Slot) {
915- switch (Slot[ 0 ] ) {
916- case ' t ' :
917- case ' T ' :
918- return RegisterType::SRV ;
919- case ' u ' :
920- case ' U ' :
921- return RegisterType::UAV;
922- case ' b ' :
923- case ' B ' :
924- return RegisterType::CBuffer;
925- case ' s ' :
926- case ' S ' :
927- return RegisterType::Sampler;
928- case ' c ' :
929- case ' C ' :
930- return RegisterType::C;
931- case ' i ' :
932- case ' I ' :
933- return RegisterType::I;
934- default :
935- return RegisterType::Invalid ;
901+ if (!DeclaredInCOrTBuffer &&
902+ (Ty-> isIntegralType (S. getASTContext ()) || Ty-> isFloatingType ()) ) {
903+ // Default Globals
904+ if (RegType == RegisterType::CBuffer)
905+ S. Diag (ArgLoc, diag::warn_hlsl_deprecated_register_type_b) ;
906+ else if (RegType != RegisterType::C)
907+ S. Diag (ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
908+ } else {
909+ if (RegType == RegisterType::C)
910+ S. Diag (ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
911+ else
912+ S. Diag (ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
913+ }
914+ } else if (Ty-> isRecordType ()) {
915+ // Class/struct types - walk the declaration and check each field and
916+ // subclass
917+ if (! ContainsResourceForRegisterType (S, Ty-> getAs <RecordType>(), RegType))
918+ S. Diag (D-> getLocation (), diag::warn_hlsl_user_defined_type_missing_member)
919+ << RegTypeNum;
920+ } else {
921+ // Anything else is an error
922+ S. Diag (ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum ;
936923 }
937924}
938925
@@ -969,76 +956,19 @@ static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
969956}
970957
971958static void DiagnoseHLSLRegisterAttribute (Sema &S, SourceLocation &ArgLoc,
972- Decl *TheDecl , RegisterType RegType,
973- const bool SpecifiedSpace) {
959+ Decl *D , RegisterType RegType,
960+ bool SpecifiedSpace) {
974961
975962 // exactly one of these two types should be set
976- assert (((isa<VarDecl>(TheDecl ) && !isa<HLSLBufferDecl>(TheDecl )) ||
977- (!isa<VarDecl>(TheDecl ) && isa<HLSLBufferDecl>(TheDecl ))) &&
963+ assert (((isa<VarDecl>(D ) && !isa<HLSLBufferDecl>(D )) ||
964+ (!isa<VarDecl>(D ) && isa<HLSLBufferDecl>(D ))) &&
978965 " expecting VarDecl or HLSLBufferDecl" );
979966
980- RegisterBindingFlags Flags = HLSLFillRegisterBindingFlags (S, TheDecl);
981- assert ((int )Flags.Other + (int )Flags.Resource + (int )Flags.Basic +
982- (int )Flags.UDT ==
983- 1 &&
984- " only one resource analysis result should be expected" );
985-
986- int RegTypeNum = static_cast <int >(RegType);
987-
988- // first, if "other" is set, emit an error
989- if (Flags.Other ) {
990- S.Diag (ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
991- return ;
992- }
967+ // check if the declaration contains resource matching the register type
968+ CheckContainsResourceForRegisterType (S, ArgLoc, D, RegType, SpecifiedSpace);
993969
994970 // next, if multiple register annotations exist, check that none conflict.
995- ValidateMultipleRegisterAnnotations (S, TheDecl, RegType);
996-
997- // next, if resource is set, make sure the register type in the register
998- // annotation is compatible with the variable's resource type.
999- if (Flags.Resource ) {
1000- RegisterType ExpRegType = getRegisterType (Flags.ResourceClass .value ());
1001- if (RegType != ExpRegType) {
1002- S.Diag (TheDecl->getLocation (), diag::err_hlsl_binding_type_mismatch)
1003- << RegTypeNum;
1004- }
1005-
1006- return ;
1007- }
1008-
1009- // next, handle diagnostics for when the "basic" flag is set
1010- if (Flags.Basic ) {
1011- if (SpecifiedSpace && !isDeclaredWithinCOrTBuffer (TheDecl))
1012- S.Diag (ArgLoc, diag::err_hlsl_space_on_global_constant);
1013-
1014- if (Flags.DefaultGlobals ) {
1015- if (RegType == RegisterType::CBuffer)
1016- S.Diag (ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
1017- else if (RegType != RegisterType::C)
1018- S.Diag (ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
1019- return ;
1020- }
1021-
1022- if (RegType == RegisterType::C)
1023- S.Diag (ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
1024- else
1025- S.Diag (ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
1026-
1027- return ;
1028- }
1029-
1030- // finally, we handle the udt case
1031- if (Flags.UDT ) {
1032- const bool ExpectedRegisterTypesForUDT[] = {
1033- Flags.SRV , Flags.UAV , Flags.CBV , Flags.Sampler , Flags.ContainsNumeric };
1034- assert ((size_t )RegTypeNum < std::size (ExpectedRegisterTypesForUDT) &&
1035- " regType has unexpected value" );
1036-
1037- if (!ExpectedRegisterTypesForUDT[RegTypeNum])
1038- S.Diag (TheDecl->getLocation (),
1039- diag::warn_hlsl_user_defined_type_missing_member)
1040- << RegTypeNum;
1041- }
971+ ValidateMultipleRegisterAnnotations (S, D, RegType);
1042972}
1043973
1044974void SemaHLSL::handleResourceBindingAttr (Decl *TheDecl, const ParsedAttr &AL) {
0 commit comments