Skip to content

Commit 274aaef

Browse files
committed
[Clang][HIP] Target-dependent overload resolution in declarators and specifiers
So far, the resolution of host/device overloads for functions in HIP/CUDA operates as if in a host-device context for code outside of function bodies, e.g., in expressions that are part of template arguments in top-level declarations. This means that, if separate host and device overloads are declared, the device overload is used in the device compilation phase and the host overload is used in the host compilation phase. This patch changes overload resolution in such cases to prefer overloads that match the target of the declaration in which they occur. For example: __device__ constexpr int get_n() { return 64; } __host__ constexpr int get_n() { return -1; } __device__ std::enable_if<(get_n() > 32)>::type foo() { } Before, this code would not compile, because get_n resolved to the host overload during host compilation, causing an error. With this patch, the call to get_n in the declaration of the device function foo resolves to the device overload in host and device compilation. If attributes that affect the declaration's target occur after a call with target-dependent overload resolution, a warning is issued. This is realized by registering the Kinds of relevant attributes in the CUDATargetContext when they are parsed. This is an alternative to PR llvm#93546, which is required for PR llvm#91478.
1 parent 2913e71 commit 274aaef

File tree

9 files changed

+988
-21
lines changed

9 files changed

+988
-21
lines changed

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9017,6 +9017,10 @@ def err_global_call_not_config : Error<
90179017
def err_ref_bad_target : Error<
90189018
"reference to %select{__device__|__global__|__host__|__host__ __device__}0 "
90199019
"%select{function|variable}1 %2 in %select{__device__|__global__|__host__|__host__ __device__}3 function">;
9020+
def warn_target_specfier_ignored : Warning<
9021+
"target specifier has been ignored for overload resolution; "
9022+
"move the target specifier to the beginning of the declaration to use it for overload resolution">,
9023+
InGroup<IgnoredAttributes>;
90209024
def note_cuda_const_var_unpromoted : Note<
90219025
"const variable cannot be emitted on device side due to dynamic initialization">;
90229026
def note_cuda_host_var : Note<

clang/include/clang/Sema/SemaCUDA.h

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ class SemaCUDA : public SemaBase {
104104
CUDAFunctionTarget IdentifyTarget(const FunctionDecl *D,
105105
bool IgnoreImplicitHDAttr = false);
106106
CUDAFunctionTarget IdentifyTarget(const ParsedAttributesView &Attrs);
107+
CUDAFunctionTarget IdentifyTarget(
108+
const SmallVectorImpl<clang::AttributeCommonInfo::Kind> &AttrKinds);
107109

108110
enum CUDAVariableTarget {
109111
CVT_Device, /// Emitted on device side with a shadow variable on host side
@@ -120,21 +122,43 @@ class SemaCUDA : public SemaBase {
120122
CTCK_Unknown, /// Unknown context
121123
CTCK_InitGlobalVar, /// Function called during global variable
122124
/// initialization
125+
CTCK_Declaration, /// Function called in a declaration specifier or
126+
/// declarator outside of other contexts, usually in
127+
/// template arguments.
123128
};
124129

125130
/// Define the current global CUDA host/device context where a function may be
126131
/// called. Only used when a function is called outside of any functions.
127-
struct CUDATargetContext {
128-
CUDAFunctionTarget Target = CUDAFunctionTarget::HostDevice;
132+
class CUDATargetContext {
133+
public:
129134
CUDATargetContextKind Kind = CTCK_Unknown;
130-
Decl *D = nullptr;
135+
136+
CUDATargetContext() = default;
137+
138+
CUDATargetContext(SemaCUDA *S, CUDATargetContextKind Kind,
139+
CUDAFunctionTarget Target);
140+
141+
CUDAFunctionTarget getTarget();
142+
143+
/// If this is a CTCK_Declaration context, update the Target based on Attrs.
144+
/// No-op otherwise.
145+
/// Issues a diagnostic if the target changes after it has been queried
146+
/// before.
147+
void tryRegisterTargetAttrs(const ParsedAttributesView &Attrs);
148+
149+
private:
150+
SemaCUDA *S = nullptr;
151+
CUDAFunctionTarget Target = CUDAFunctionTarget::HostDevice;
152+
SmallVector<clang::AttributeCommonInfo::Kind, 0> AttrKinds;
153+
bool TargetQueried = false;
154+
131155
} CurCUDATargetCtx;
132156

133157
struct CUDATargetContextRAII {
134158
SemaCUDA &S;
135159
SemaCUDA::CUDATargetContext SavedCtx;
136160
CUDATargetContextRAII(SemaCUDA &S_, SemaCUDA::CUDATargetContextKind K,
137-
Decl *D);
161+
Decl *D = nullptr);
138162
~CUDATargetContextRAII() { S.CurCUDATargetCtx = SavedCtx; }
139163
};
140164

clang/lib/Parse/ParseDecl.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,9 @@ void Parser::ParseGNUAttributes(ParsedAttributes &Attrs,
311311
}
312312

313313
Attrs.Range = SourceRange(StartLoc, EndLoc);
314+
315+
if (Actions.getLangOpts().CUDA)
316+
Actions.CUDA().CurCUDATargetCtx.tryRegisterTargetAttrs(Attrs);
314317
}
315318

316319
/// Determine whether the given attribute has an identifier argument.
@@ -1003,6 +1006,9 @@ void Parser::ParseMicrosoftDeclSpecs(ParsedAttributes &Attrs) {
10031006
}
10041007

10051008
Attrs.Range = SourceRange(StartLoc, EndLoc);
1009+
1010+
if (Actions.getLangOpts().CUDA)
1011+
Actions.CUDA().CurCUDATargetCtx.tryRegisterTargetAttrs(Attrs);
10061012
}
10071013

10081014
void Parser::ParseMicrosoftTypeAttributes(ParsedAttributes &attrs) {

clang/lib/Parse/ParseDeclCXX.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "clang/Sema/EnterExpressionEvaluationContext.h"
2828
#include "clang/Sema/ParsedTemplate.h"
2929
#include "clang/Sema/Scope.h"
30+
#include "clang/Sema/SemaCUDA.h"
3031
#include "clang/Sema/SemaCodeCompletion.h"
3132
#include "llvm/ADT/SmallString.h"
3233
#include "llvm/Support/TimeProfiler.h"
@@ -2852,6 +2853,11 @@ Parser::DeclGroupPtrTy Parser::ParseCXXClassMemberDeclaration(
28522853
ParsedTemplateInfo &TemplateInfo, ParsingDeclRAIIObject *TemplateDiags) {
28532854
assert(getLangOpts().CPlusPlus &&
28542855
"ParseCXXClassMemberDeclaration should only be called in C++ mode");
2856+
SemaCUDA::CUDATargetContextRAII CTCRAII(Actions.CUDA(),
2857+
SemaCUDA::CTCK_Declaration);
2858+
if (Actions.getLangOpts().CUDA)
2859+
Actions.CUDA().CurCUDATargetCtx.tryRegisterTargetAttrs(AccessAttrs);
2860+
28552861
if (Tok.is(tok::at)) {
28562862
if (getLangOpts().ObjC && NextToken().isObjCAtKeyword(tok::objc_defs))
28572863
Diag(Tok, diag::err_at_defs_cxx);

clang/lib/Parse/Parser.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "clang/Sema/DeclSpec.h"
2222
#include "clang/Sema/ParsedTemplate.h"
2323
#include "clang/Sema/Scope.h"
24+
#include "clang/Sema/SemaCUDA.h"
2425
#include "clang/Sema/SemaCodeCompletion.h"
2526
#include "llvm/Support/Path.h"
2627
#include "llvm/Support/TimeProfiler.h"
@@ -1133,6 +1134,13 @@ bool Parser::isStartOfFunctionDefinition(const ParsingDeclarator &Declarator) {
11331134
Parser::DeclGroupPtrTy Parser::ParseDeclOrFunctionDefInternal(
11341135
ParsedAttributes &Attrs, ParsedAttributes &DeclSpecAttrs,
11351136
ParsingDeclSpec &DS, AccessSpecifier AS) {
1137+
SemaCUDA::CUDATargetContextRAII CTCRAII(Actions.CUDA(),
1138+
SemaCUDA::CTCK_Declaration);
1139+
if (Actions.getLangOpts().CUDA) {
1140+
Actions.CUDA().CurCUDATargetCtx.tryRegisterTargetAttrs(Attrs);
1141+
Actions.CUDA().CurCUDATargetCtx.tryRegisterTargetAttrs(DeclSpecAttrs);
1142+
}
1143+
11361144
// Because we assume that the DeclSpec has not yet been initialised, we simply
11371145
// overwrite the source range and attribute the provided leading declspec
11381146
// attributes.

clang/lib/Sema/SemaCUDA.cpp

Lines changed: 97 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "clang/Basic/TargetInfo.h"
1919
#include "clang/Lex/Preprocessor.h"
2020
#include "clang/Sema/Lookup.h"
21+
#include "clang/Sema/ParsedAttr.h"
2122
#include "clang/Sema/ScopeInfo.h"
2223
#include "clang/Sema/Sema.h"
2324
#include "clang/Sema/SemaDiagnostic.h"
@@ -68,13 +69,28 @@ ExprResult SemaCUDA::ActOnExecConfigExpr(Scope *S, SourceLocation LLLLoc,
6869
/*IsExecConfig=*/true);
6970
}
7071

71-
CUDAFunctionTarget SemaCUDA::IdentifyTarget(const ParsedAttributesView &Attrs) {
72+
namespace {
73+
74+
// This iterator adaptor enables sharing a IdentifyTarget implementation for
75+
// ParsedAttributesView and for vectors of AttributeCommonInfo::Kind.
76+
struct AttrKindIterator
77+
: llvm::iterator_adaptor_base<
78+
AttrKindIterator, ParsedAttributesView::const_iterator,
79+
std::random_access_iterator_tag, clang::AttributeCommonInfo::Kind> {
80+
AttrKindIterator() : iterator_adaptor_base(nullptr) {}
81+
AttrKindIterator(ParsedAttributesView::const_iterator I)
82+
: iterator_adaptor_base(I) {}
83+
clang::AttributeCommonInfo::Kind operator*() const { return I->getKind(); }
84+
};
85+
86+
template <typename AKIterRange>
87+
CUDAFunctionTarget IdentifyTargetImpl(const AKIterRange &AttrKinds) {
7288
bool HasHostAttr = false;
7389
bool HasDeviceAttr = false;
7490
bool HasGlobalAttr = false;
7591
bool HasInvalidTargetAttr = false;
76-
for (const ParsedAttr &AL : Attrs) {
77-
switch (AL.getKind()) {
92+
for (const auto &AK : AttrKinds) {
93+
switch (AK) {
7894
case ParsedAttr::AT_CUDAGlobal:
7995
HasGlobalAttr = true;
8096
break;
@@ -107,6 +123,18 @@ CUDAFunctionTarget SemaCUDA::IdentifyTarget(const ParsedAttributesView &Attrs) {
107123
return CUDAFunctionTarget::Host;
108124
}
109125

126+
} // namespace
127+
128+
CUDAFunctionTarget SemaCUDA::IdentifyTarget(const ParsedAttributesView &Attrs) {
129+
return IdentifyTargetImpl(make_range(AttrKindIterator(Attrs.begin()),
130+
AttrKindIterator(Attrs.end())));
131+
}
132+
133+
CUDAFunctionTarget SemaCUDA::IdentifyTarget(
134+
const SmallVectorImpl<clang::AttributeCommonInfo::Kind> &AttrKinds) {
135+
return IdentifyTargetImpl(AttrKinds);
136+
}
137+
110138
template <typename A>
111139
static bool hasAttr(const Decl *D, bool IgnoreImplicitAttr) {
112140
return D->hasAttrs() && llvm::any_of(D->getAttrs(), [&](Attr *Attribute) {
@@ -115,20 +143,65 @@ static bool hasAttr(const Decl *D, bool IgnoreImplicitAttr) {
115143
});
116144
}
117145

146+
SemaCUDA::CUDATargetContext::CUDATargetContext(SemaCUDA *S,
147+
CUDATargetContextKind Kind,
148+
CUDAFunctionTarget Target)
149+
: Kind(Kind), S(S), Target(Target) {}
150+
151+
CUDAFunctionTarget SemaCUDA::CUDATargetContext::getTarget() {
152+
TargetQueried = true;
153+
return Target;
154+
}
155+
156+
void SemaCUDA::CUDATargetContext::tryRegisterTargetAttrs(
157+
const ParsedAttributesView &Attrs) {
158+
if (Kind != CTCK_Declaration)
159+
return;
160+
for (const auto &A : Attrs) {
161+
auto AK = A.getKind();
162+
switch (AK) {
163+
case ParsedAttr::AT_CUDAGlobal:
164+
case ParsedAttr::AT_CUDAHost:
165+
case ParsedAttr::AT_CUDADevice:
166+
case ParsedAttr::AT_CUDAInvalidTarget:
167+
break;
168+
default:
169+
continue;
170+
}
171+
AttrKinds.push_back(AK);
172+
CUDAFunctionTarget NewTarget = S->IdentifyTarget(AttrKinds);
173+
if (TargetQueried && (NewTarget != Target))
174+
S->Diag(A.getLoc(), diag::warn_target_specfier_ignored);
175+
Target = NewTarget;
176+
}
177+
}
178+
118179
SemaCUDA::CUDATargetContextRAII::CUDATargetContextRAII(
119180
SemaCUDA &S_, SemaCUDA::CUDATargetContextKind K, Decl *D)
120181
: S(S_) {
121182
SavedCtx = S.CurCUDATargetCtx;
122-
assert(K == SemaCUDA::CTCK_InitGlobalVar);
123-
auto *VD = dyn_cast_or_null<VarDecl>(D);
124-
if (VD && VD->hasGlobalStorage() && !VD->isStaticLocal()) {
125-
auto Target = CUDAFunctionTarget::Host;
126-
if ((hasAttr<CUDADeviceAttr>(VD, /*IgnoreImplicit=*/true) &&
127-
!hasAttr<CUDAHostAttr>(VD, /*IgnoreImplicit=*/true)) ||
128-
hasAttr<CUDASharedAttr>(VD, /*IgnoreImplicit=*/true) ||
129-
hasAttr<CUDAConstantAttr>(VD, /*IgnoreImplicit=*/true))
130-
Target = CUDAFunctionTarget::Device;
131-
S.CurCUDATargetCtx = {Target, K, VD};
183+
184+
switch (K) {
185+
case SemaCUDA::CTCK_InitGlobalVar: {
186+
auto *VD = dyn_cast_or_null<VarDecl>(D);
187+
if (VD && VD->hasGlobalStorage() && !VD->isStaticLocal()) {
188+
auto Target = CUDAFunctionTarget::Host;
189+
if ((hasAttr<CUDADeviceAttr>(VD, /*IgnoreImplicit=*/true) &&
190+
!hasAttr<CUDAHostAttr>(VD, /*IgnoreImplicit=*/true)) ||
191+
hasAttr<CUDASharedAttr>(VD, /*IgnoreImplicit=*/true) ||
192+
hasAttr<CUDAConstantAttr>(VD, /*IgnoreImplicit=*/true))
193+
Target = CUDAFunctionTarget::Device;
194+
S.CurCUDATargetCtx = CUDATargetContext(&S, K, Target);
195+
}
196+
break;
197+
}
198+
case SemaCUDA::CTCK_Declaration:
199+
// The target is updated once relevant attributes are parsed. Initialize
200+
// with the target used if no attributes are present: Host.
201+
S.CurCUDATargetCtx = CUDATargetContext(&S, K, CUDAFunctionTarget::Host);
202+
break;
203+
default:
204+
llvm_unreachable("unexpected context kind");
132205
}
133206
}
134207

@@ -137,7 +210,7 @@ CUDAFunctionTarget SemaCUDA::IdentifyTarget(const FunctionDecl *D,
137210
bool IgnoreImplicitHDAttr) {
138211
// Code that lives outside a function gets the target from CurCUDATargetCtx.
139212
if (D == nullptr)
140-
return CurCUDATargetCtx.Target;
213+
return CurCUDATargetCtx.getTarget();
141214

142215
if (D->hasAttr<CUDAInvalidTargetAttr>())
143216
return CUDAFunctionTarget::InvalidTarget;
@@ -232,7 +305,7 @@ SemaCUDA::IdentifyPreference(const FunctionDecl *Caller,
232305
// trivial ctor/dtor without device attr to be used. Non-trivial ctor/dtor
233306
// will be diagnosed by checkAllowedInitializer.
234307
if (Caller == nullptr && CurCUDATargetCtx.Kind == CTCK_InitGlobalVar &&
235-
CurCUDATargetCtx.Target == CUDAFunctionTarget::Device &&
308+
CurCUDATargetCtx.getTarget() == CUDAFunctionTarget::Device &&
236309
(isa<CXXConstructorDecl>(Callee) || isa<CXXDestructorDecl>(Callee)))
237310
return CFP_HostDevice;
238311

@@ -297,8 +370,16 @@ SemaCUDA::IdentifyPreference(const FunctionDecl *Caller,
297370
(CallerTarget == CUDAFunctionTarget::Device &&
298371
CalleeTarget == CUDAFunctionTarget::Host) ||
299372
(CallerTarget == CUDAFunctionTarget::Global &&
300-
CalleeTarget == CUDAFunctionTarget::Host))
373+
CalleeTarget == CUDAFunctionTarget::Host)) {
374+
// In declaration contexts outside of function bodies and variable
375+
// initializers, tolerate mismatched function targets as long as they are
376+
// not codegened.
377+
if (CurCUDATargetCtx.Kind == CTCK_Declaration &&
378+
!this->SemaRef.getCurFunctionDecl(/*AllowLambda=*/true))
379+
return CFP_WrongSide;
380+
301381
return CFP_Never;
382+
}
302383

303384
llvm_unreachable("All cases should've been handled by now.");
304385
}

clang/lib/Sema/SemaOverload.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10747,7 +10747,7 @@ OverloadCandidateSet::BestViableFunction(Sema &S, SourceLocation Loc,
1074710747
llvm::any_of(Candidates, [&](OverloadCandidate *Cand) {
1074810748
// Check viable function only.
1074910749
return Cand->Viable && Cand->Function &&
10750-
S.CUDA().IdentifyPreference(Caller, Cand->Function) ==
10750+
S.CUDA().IdentifyPreference(Caller, Cand->Function) >=
1075110751
SemaCUDA::CFP_SameSide;
1075210752
});
1075310753
if (ContainsSameSideCandidate) {

0 commit comments

Comments
 (0)