Skip to content

Commit d414451

Browse files
committed
[CUDA][HIP] Fix hostness check with -fopenmp
CUDA/HIP determines whether a function can be called based on the device/host attributes of callee and caller. Clang assumes the caller is CurContext. This is correct in most cases, however, it is not correct in OpenMP parallel region when CUDA/HIP program is compiled with -fopenmp. This causes incorrect overloading resolution and missed diagnostics. To get the correct caller, clang needs to chase the parent chain of DeclContext starting from CurContext until a function decl or a lambda decl is reached. Sema API is adapted to achieve that and used to determine the caller in hostness check. Reviewed by: Artem Belevich, Richard Smith Differential Revision: https://reviews.llvm.org/D121765
1 parent 2e44b78 commit d414451

File tree

6 files changed

+79
-30
lines changed

6 files changed

+79
-30
lines changed

clang/include/clang/Sema/Sema.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3318,12 +3318,14 @@ class Sema final {
33183318
void ActOnReenterFunctionContext(Scope* S, Decl* D);
33193319
void ActOnExitFunctionContext();
33203320

3321-
DeclContext *getFunctionLevelDeclContext();
3322-
3323-
/// getCurFunctionDecl - If inside of a function body, this returns a pointer
3324-
/// to the function decl for the function being parsed. If we're currently
3325-
/// in a 'block', this returns the containing context.
3326-
FunctionDecl *getCurFunctionDecl();
3321+
/// If \p AllowLambda is true, treat lambda as function.
3322+
DeclContext *getFunctionLevelDeclContext(bool AllowLambda = false);
3323+
3324+
/// Returns a pointer to the innermost enclosing function, or nullptr if the
3325+
/// current context is not inside a function. If \p AllowLambda is true,
3326+
/// this can return the call operator of an enclosing lambda, otherwise
3327+
/// lambdas are skipped when looking for an enclosing function.
3328+
FunctionDecl *getCurFunctionDecl(bool AllowLambda = false);
33273329

33283330
/// getCurMethodDecl - If inside of a method body, this returns a pointer to
33293331
/// the method decl for the method being parsed. If we're currently

clang/lib/Sema/Sema.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1421,19 +1421,18 @@ void Sema::ActOnEndOfTranslationUnit() {
14211421
// Helper functions.
14221422
//===----------------------------------------------------------------------===//
14231423

1424-
DeclContext *Sema::getFunctionLevelDeclContext() {
1424+
DeclContext *Sema::getFunctionLevelDeclContext(bool AllowLambda) {
14251425
DeclContext *DC = CurContext;
14261426

14271427
while (true) {
14281428
if (isa<BlockDecl>(DC) || isa<EnumDecl>(DC) || isa<CapturedDecl>(DC) ||
14291429
isa<RequiresExprBodyDecl>(DC)) {
14301430
DC = DC->getParent();
1431-
} else if (isa<CXXMethodDecl>(DC) &&
1431+
} else if (!AllowLambda && isa<CXXMethodDecl>(DC) &&
14321432
cast<CXXMethodDecl>(DC)->getOverloadedOperator() == OO_Call &&
14331433
cast<CXXRecordDecl>(DC->getParent())->isLambda()) {
14341434
DC = DC->getParent()->getParent();
1435-
}
1436-
else break;
1435+
} else break;
14371436
}
14381437

14391438
return DC;
@@ -1442,8 +1441,8 @@ DeclContext *Sema::getFunctionLevelDeclContext() {
14421441
/// getCurFunctionDecl - If inside of a function body, this returns a pointer
14431442
/// to the function decl for the function being parsed. If we're currently
14441443
/// in a 'block', this returns the containing context.
1445-
FunctionDecl *Sema::getCurFunctionDecl() {
1446-
DeclContext *DC = getFunctionLevelDeclContext();
1444+
FunctionDecl *Sema::getCurFunctionDecl(bool AllowLambda) {
1445+
DeclContext *DC = getFunctionLevelDeclContext(AllowLambda);
14471446
return dyn_cast<FunctionDecl>(DC);
14481447
}
14491448

clang/lib/Sema/SemaCUDA.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -728,8 +728,9 @@ void Sema::MaybeAddCUDAConstantAttr(VarDecl *VD) {
728728
Sema::SemaDiagnosticBuilder Sema::CUDADiagIfDeviceCode(SourceLocation Loc,
729729
unsigned DiagID) {
730730
assert(getLangOpts().CUDA && "Should only be called during CUDA compilation");
731+
FunctionDecl *CurFunContext = getCurFunctionDecl(/*AllowLambda=*/true);
731732
SemaDiagnosticBuilder::Kind DiagKind = [&] {
732-
if (!isa<FunctionDecl>(CurContext))
733+
if (!CurFunContext)
733734
return SemaDiagnosticBuilder::K_Nop;
734735
switch (CurrentCUDATarget()) {
735736
case CFT_Global:
@@ -743,23 +744,23 @@ Sema::SemaDiagnosticBuilder Sema::CUDADiagIfDeviceCode(SourceLocation Loc,
743744
return SemaDiagnosticBuilder::K_Nop;
744745
if (IsLastErrorImmediate && Diags.getDiagnosticIDs()->isBuiltinNote(DiagID))
745746
return SemaDiagnosticBuilder::K_Immediate;
746-
return (getEmissionStatus(cast<FunctionDecl>(CurContext)) ==
747+
return (getEmissionStatus(CurFunContext) ==
747748
FunctionEmissionStatus::Emitted)
748749
? SemaDiagnosticBuilder::K_ImmediateWithCallStack
749750
: SemaDiagnosticBuilder::K_Deferred;
750751
default:
751752
return SemaDiagnosticBuilder::K_Nop;
752753
}
753754
}();
754-
return SemaDiagnosticBuilder(DiagKind, Loc, DiagID,
755-
dyn_cast<FunctionDecl>(CurContext), *this);
755+
return SemaDiagnosticBuilder(DiagKind, Loc, DiagID, CurFunContext, *this);
756756
}
757757

758758
Sema::SemaDiagnosticBuilder Sema::CUDADiagIfHostCode(SourceLocation Loc,
759759
unsigned DiagID) {
760760
assert(getLangOpts().CUDA && "Should only be called during CUDA compilation");
761+
FunctionDecl *CurFunContext = getCurFunctionDecl(/*AllowLambda=*/true);
761762
SemaDiagnosticBuilder::Kind DiagKind = [&] {
762-
if (!isa<FunctionDecl>(CurContext))
763+
if (!CurFunContext)
763764
return SemaDiagnosticBuilder::K_Nop;
764765
switch (CurrentCUDATarget()) {
765766
case CFT_Host:
@@ -772,16 +773,15 @@ Sema::SemaDiagnosticBuilder Sema::CUDADiagIfHostCode(SourceLocation Loc,
772773
return SemaDiagnosticBuilder::K_Nop;
773774
if (IsLastErrorImmediate && Diags.getDiagnosticIDs()->isBuiltinNote(DiagID))
774775
return SemaDiagnosticBuilder::K_Immediate;
775-
return (getEmissionStatus(cast<FunctionDecl>(CurContext)) ==
776+
return (getEmissionStatus(CurFunContext) ==
776777
FunctionEmissionStatus::Emitted)
777778
? SemaDiagnosticBuilder::K_ImmediateWithCallStack
778779
: SemaDiagnosticBuilder::K_Deferred;
779780
default:
780781
return SemaDiagnosticBuilder::K_Nop;
781782
}
782783
}();
783-
return SemaDiagnosticBuilder(DiagKind, Loc, DiagID,
784-
dyn_cast<FunctionDecl>(CurContext), *this);
784+
return SemaDiagnosticBuilder(DiagKind, Loc, DiagID, CurFunContext, *this);
785785
}
786786

787787
bool Sema::CheckCUDACall(SourceLocation Loc, FunctionDecl *Callee) {
@@ -794,7 +794,7 @@ bool Sema::CheckCUDACall(SourceLocation Loc, FunctionDecl *Callee) {
794794

795795
// FIXME: Is bailing out early correct here? Should we instead assume that
796796
// the caller is a global initializer?
797-
FunctionDecl *Caller = dyn_cast<FunctionDecl>(CurContext);
797+
FunctionDecl *Caller = getCurFunctionDecl(/*AllowLambda=*/true);
798798
if (!Caller)
799799
return true;
800800

@@ -860,7 +860,7 @@ void Sema::CUDACheckLambdaCapture(CXXMethodDecl *Callee,
860860

861861
// File-scope lambda can only do init captures for global variables, which
862862
// results in passing by value for these global variables.
863-
FunctionDecl *Caller = dyn_cast<FunctionDecl>(CurContext);
863+
FunctionDecl *Caller = getCurFunctionDecl(/*AllowLambda=*/true);
864864
if (!Caller)
865865
return;
866866

clang/lib/Sema/SemaOverload.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6473,7 +6473,7 @@ void Sema::AddOverloadCandidate(
64736473

64746474
// (CUDA B.1): Check for invalid calls between targets.
64756475
if (getLangOpts().CUDA)
6476-
if (const FunctionDecl *Caller = dyn_cast<FunctionDecl>(CurContext))
6476+
if (const FunctionDecl *Caller = getCurFunctionDecl(/*AllowLambda=*/true))
64776477
// Skip the check for callers that are implicit members, because in this
64786478
// case we may not yet know what the member's target is; the target is
64796479
// inferred for the member automatically, based on the bases and fields of
@@ -6983,7 +6983,7 @@ Sema::AddMethodCandidate(CXXMethodDecl *Method, DeclAccessPair FoundDecl,
69836983

69846984
// (CUDA B.1): Check for invalid calls between targets.
69856985
if (getLangOpts().CUDA)
6986-
if (const FunctionDecl *Caller = dyn_cast<FunctionDecl>(CurContext))
6986+
if (const FunctionDecl *Caller = getCurFunctionDecl(/*AllowLambda=*/true))
69876987
if (!IsAllowedCUDACall(Caller, Method)) {
69886988
Candidate.Viable = false;
69896989
Candidate.FailureKind = ovl_fail_bad_target;
@@ -9639,7 +9639,7 @@ bool clang::isBetterOverloadCandidate(
96399639
// overloading resolution diagnostics.
96409640
if (S.getLangOpts().CUDA && Cand1.Function && Cand2.Function &&
96419641
S.getLangOpts().GPUExcludeWrongSideOverloads) {
9642-
if (FunctionDecl *Caller = dyn_cast<FunctionDecl>(S.CurContext)) {
9642+
if (FunctionDecl *Caller = S.getCurFunctionDecl(/*AllowLambda=*/true)) {
96439643
bool IsCallerImplicitHD = Sema::isCUDAImplicitHostDeviceFunction(Caller);
96449644
bool IsCand1ImplicitHD =
96459645
Sema::isCUDAImplicitHostDeviceFunction(Cand1.Function);
@@ -9922,7 +9922,7 @@ bool clang::isBetterOverloadCandidate(
99229922
// If other rules cannot determine which is better, CUDA preference is used
99239923
// to determine which is better.
99249924
if (S.getLangOpts().CUDA && Cand1.Function && Cand2.Function) {
9925-
FunctionDecl *Caller = dyn_cast<FunctionDecl>(S.CurContext);
9925+
FunctionDecl *Caller = S.getCurFunctionDecl(/*AllowLambda=*/true);
99269926
return S.IdentifyCUDAPreference(Caller, Cand1.Function) >
99279927
S.IdentifyCUDAPreference(Caller, Cand2.Function);
99289928
}
@@ -10043,7 +10043,7 @@ OverloadCandidateSet::BestViableFunction(Sema &S, SourceLocation Loc,
1004310043
// -fgpu-exclude-wrong-side-overloads is on, all candidates are compared
1004410044
// uniformly in isBetterOverloadCandidate.
1004510045
if (S.getLangOpts().CUDA && !S.getLangOpts().GPUExcludeWrongSideOverloads) {
10046-
const FunctionDecl *Caller = dyn_cast<FunctionDecl>(S.CurContext);
10046+
const FunctionDecl *Caller = S.getCurFunctionDecl(/*AllowLambda=*/true);
1004710047
bool ContainsSameSideCandidate =
1004810048
llvm::any_of(Candidates, [&](OverloadCandidate *Cand) {
1004910049
// Check viable function only.
@@ -11077,7 +11077,7 @@ static void DiagnoseBadDeduction(Sema &S, OverloadCandidate *Cand,
1107711077

1107811078
/// CUDA: diagnose an invalid call across targets.
1107911079
static void DiagnoseBadTarget(Sema &S, OverloadCandidate *Cand) {
11080-
FunctionDecl *Caller = cast<FunctionDecl>(S.CurContext);
11080+
FunctionDecl *Caller = S.getCurFunctionDecl(/*AllowLambda=*/true);
1108111081
FunctionDecl *Callee = Cand->Function;
1108211082

1108311083
Sema::CUDAFunctionTarget CallerTarget = S.IdentifyCUDATarget(Caller),
@@ -12136,7 +12136,7 @@ class AddressOfFunctionResolver {
1213612136

1213712137
if (FunctionDecl *FunDecl = dyn_cast<FunctionDecl>(Fn)) {
1213812138
if (S.getLangOpts().CUDA)
12139-
if (FunctionDecl *Caller = dyn_cast<FunctionDecl>(S.CurContext))
12139+
if (FunctionDecl *Caller = S.getCurFunctionDecl(/*AllowLambda=*/true))
1214012140
if (!Caller->isImplicit() && !S.IsAllowedCUDACall(Caller, FunDecl))
1214112141
return false;
1214212142
if (FunDecl->isMultiVersion()) {
@@ -12253,7 +12253,8 @@ class AddressOfFunctionResolver {
1225312253
}
1225412254

1225512255
void EliminateSuboptimalCudaMatches() {
12256-
S.EraseUnwantedCUDAMatches(dyn_cast<FunctionDecl>(S.CurContext), Matches);
12256+
S.EraseUnwantedCUDAMatches(S.getCurFunctionDecl(/*AllowLambda=*/true),
12257+
Matches);
1225712258
}
1225812259

1225912260
public:
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu \
2+
// RUN: -fopenmp -emit-llvm -o - -x hip %s | FileCheck %s
3+
4+
#include "Inputs/cuda.h"
5+
6+
void foo(double) {}
7+
__device__ void foo(int) {}
8+
9+
// Check foo resolves to the host function.
10+
// CHECK-LABEL: define {{.*}}@_Z5test1v
11+
// CHECK: call void @_Z3food(double noundef 1.000000e+00)
12+
void test1() {
13+
#pragma omp parallel
14+
for (int i = 0; i < 100; i++)
15+
foo(1);
16+
}
17+
18+
// Check foo resolves to the host function.
19+
// CHECK-LABEL: define {{.*}}@_Z5test2v
20+
// CHECK: call void @_Z3food(double noundef 1.000000e+00)
21+
void test2() {
22+
auto Lambda = []() {
23+
#pragma omp parallel
24+
for (int i = 0; i < 100; i++)
25+
foo(1);
26+
};
27+
Lambda();
28+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// RUN: %clang_cc1 -fopenmp -fsyntax-only -verify %s
2+
3+
#include "Inputs/cuda.h"
4+
5+
__device__ void foo(int) {} // expected-note {{candidate function not viable: call to __device__ function from __host__ function}}
6+
// expected-note@-1 {{'foo' declared here}}
7+
8+
int main() {
9+
#pragma omp parallel
10+
for (int i = 0; i < 100; i++)
11+
foo(1); // expected-error {{no matching function for call to 'foo'}}
12+
13+
auto Lambda = []() {
14+
#pragma omp parallel
15+
for (int i = 0; i < 100; i++)
16+
foo(1); // expected-error {{reference to __device__ function 'foo' in __host__ __device__ function}}
17+
};
18+
Lambda(); // expected-note {{called by 'main'}}
19+
}

0 commit comments

Comments
 (0)