Skip to content

Commit a5f8d06

Browse files
authored
Add basic support for OpenMP in forward mode (#1491)
1 parent 526d98e commit a5f8d06

File tree

8 files changed

+305
-3
lines changed

8 files changed

+305
-3
lines changed

include/clad/Differentiator/BaseForwardModeVisitor.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
#include "Compatibility.h"
55
#include "VisitorBase.h"
66
#include "clang/AST/ExprCXX.h"
7+
#include "clang/AST/OpenMPClause.h"
78
#include "clang/AST/RecursiveASTVisitor.h"
9+
#include "clang/AST/StmtOpenMP.h"
810
#include "clang/AST/StmtVisitor.h"
911
#include "clang/Sema/Sema.h"
1012

@@ -25,6 +27,8 @@ namespace clad {
2527
/// Used to compute derivatives by clad::differentiate.
2628
class BaseForwardModeVisitor
2729
: public clang::ConstStmtVisitor<BaseForwardModeVisitor, StmtDiff>,
30+
public clang::ConstOMPClauseVisitor<BaseForwardModeVisitor,
31+
clang::OMPClause*>,
2832
public VisitorBase {
2933
unsigned m_IndependentVarIndex = ~0;
3034

@@ -53,6 +57,11 @@ class BaseForwardModeVisitor
5357
return clang::ConstStmtVisitor<BaseForwardModeVisitor, StmtDiff>::Visit(S);
5458
}
5559

60+
clang::OMPClause* Visit(const clang::OMPClause* C) {
61+
return clang::ConstOMPClauseVisitor<BaseForwardModeVisitor,
62+
clang::OMPClause*>::Visit(C);
63+
}
64+
5665
virtual void ExecuteInsidePushforwardFunctionBlock() {}
5766

5867
virtual StmtDiff
@@ -123,6 +132,20 @@ class BaseForwardModeVisitor
123132
StmtDiff VisitNullStmt(const clang::NullStmt* NS) { return StmtDiff{}; };
124133
StmtDiff
125134
VisitCXXStdInitializerListExpr(const clang::CXXStdInitializerListExpr* ILE);
135+
136+
StmtDiff VisitOMPExecutableDirective(const clang::OMPExecutableDirective* D);
137+
StmtDiff VisitOMPParallelDirective(const clang::OMPParallelDirective* D);
138+
StmtDiff
139+
VisitOMPParallelForDirective(const clang::OMPParallelForDirective* D);
140+
clang::OMPClause* VisitOMPClause(const clang::OMPClause* C);
141+
clang::OMPClause* VisitOMPPrivateClause(const clang::OMPPrivateClause* C);
142+
clang::OMPClause*
143+
VisitOMPFirstprivateClause(const clang::OMPFirstprivateClause* C);
144+
clang::OMPClause*
145+
VisitOMPLastprivateClause(const clang::OMPLastprivateClause* C);
146+
clang::OMPClause* VisitOMPSharedClause(const clang::OMPSharedClause* C);
147+
clang::OMPClause* VisitOMPReductionClause(const clang::OMPReductionClause* C);
148+
126149
static DeclDiff<clang::StaticAssertDecl>
127150
DifferentiateStaticAssertDecl(const clang::StaticAssertDecl* SAD);
128151

include/clad/Differentiator/Compatibility.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
#include "clang/Basic/Version.h"
1616
#include "clang/Basic/SourceManager.h"
1717
#include "clang/Sema/Sema.h"
18+
#if CLANG_VERSION_MAJOR > 18
19+
#include "clang/Sema/SemaOpenMP.h"
20+
#endif
1821

1922
namespace clad_compat {
2023

@@ -48,13 +51,35 @@ using namespace llvm;
4851
#define CLAD_COMPAT_CLANG21_TemplateKeywordParam
4952
#endif
5053

54+
// clang-21 OpenMPReductionClauseModifiers got extra argument
55+
#if CLANG_VERSION_MAJOR < 21
56+
#define CLAD_COMPAT_CLANG21_getModifier(Clause) (Clause)->getModifier()
57+
#else
58+
#define CLAD_COMPAT_CLANG21_getModifier(Clause) \
59+
{(Clause)->getModifier(), (Clause)->getOriginalSharingModifier()}
60+
#endif
61+
62+
// clang-20 Clause varlist typo
63+
#if CLANG_VERSION_MAJOR < 20
64+
#define CLAD_COMPAT_CLANG20_getvarlist(Clause) (Clause)->varlists()
65+
#else
66+
#define CLAD_COMPAT_CLANG20_getvarlist(Clause) (Clause)->varlist()
67+
#endif
68+
5169
// clang-20 clang::Sema::AA_Casting became scoped
5270
#if CLANG_VERSION_MAJOR < 20
5371
#define CLAD_COMPAT_CLANG20_SemaAACasting clang::Sema::AA_Casting
5472
#else
5573
#define CLAD_COMPAT_CLANG20_SemaAACasting clang::AssignmentAction::Casting
5674
#endif
5775

76+
// clang-19 SemaOpenMP was introduced
77+
#if CLANG_VERSION_MAJOR < 19
78+
#define CLAD_COMPAT_CLANG19_SemaOpenMP(Sema) (Sema)
79+
#else
80+
#define CLAD_COMPAT_CLANG19_SemaOpenMP(Sema) ((Sema).OpenMP())
81+
#endif
82+
5883
// clang-18 CXXThisExpr got extra argument
5984

6085
#if CLANG_VERSION_MAJOR > 17

lib/Differentiator/BaseForwardModeVisitor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2197,4 +2197,4 @@ clang::Expr* BaseForwardModeVisitor::BuildCustomDerivativeConstructorPFCall(
21972197
customPushforwardName, customPushforwardArgs, getCurrentScope(), CE);
21982198
return pushforwardCall;
21992199
}
2200-
} // end namespace clad
2200+
} // end namespace clad
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
#include "clad/Differentiator/BaseForwardModeVisitor.h"
2+
#include "clad/Differentiator/Compatibility.h"
3+
4+
#include "clang/AST/DeclarationName.h"
5+
#include "clang/AST/OpenMPClause.h"
6+
#include "clang/AST/Stmt.h"
7+
#include "clang/AST/StmtOpenMP.h"
8+
#include "clang/Basic/LLVM.h"
9+
#include "clang/Basic/OpenMPKinds.h"
10+
#include "clang/Sema/DeclSpec.h"
11+
#include "clang/Sema/Sema.h"
12+
#include "llvm/ADT/SmallVector.h"
13+
#include "llvm/Frontend/OpenMP/OMP.h.inc"
14+
15+
#include <cassert>
16+
17+
using namespace clang;
18+
using namespace llvm::omp;
19+
20+
namespace clad {
21+
22+
OMPClause*
23+
BaseForwardModeVisitor::VisitOMPPrivateClause(const OMPPrivateClause* C) {
24+
llvm::SmallVector<Expr*, 16> Vars;
25+
Vars.reserve(C->varlist_size());
26+
for (const auto* Var : CLAD_COMPAT_CLANG20_getvarlist(C)) {
27+
Vars.push_back(Visit(Var).getExpr_dx());
28+
Vars.push_back(Clone(Var));
29+
}
30+
return CLAD_COMPAT_CLANG19_SemaOpenMP(m_Sema).ActOnOpenMPPrivateClause(
31+
Vars, C->getBeginLoc(), C->getLParenLoc(), C->getEndLoc());
32+
}
33+
34+
OMPClause* BaseForwardModeVisitor::VisitOMPFirstprivateClause(
35+
const OMPFirstprivateClause* C) {
36+
llvm::SmallVector<Expr*, 16> Vars;
37+
Vars.reserve(C->varlist_size());
38+
for (const auto* Var : CLAD_COMPAT_CLANG20_getvarlist(C)) {
39+
Vars.push_back(Visit(Var).getExpr_dx());
40+
Vars.push_back(Clone(Var));
41+
}
42+
return CLAD_COMPAT_CLANG19_SemaOpenMP(m_Sema).ActOnOpenMPFirstprivateClause(
43+
Vars, C->getBeginLoc(), C->getLParenLoc(), C->getEndLoc());
44+
}
45+
46+
OMPClause* BaseForwardModeVisitor::VisitOMPLastprivateClause(
47+
const OMPLastprivateClause* C) {
48+
llvm::SmallVector<Expr*, 16> Vars;
49+
Vars.reserve(C->varlist_size());
50+
for (const auto* Var : CLAD_COMPAT_CLANG20_getvarlist(C)) {
51+
Vars.push_back(Visit(Var).getExpr_dx());
52+
Vars.push_back(Clone(Var));
53+
}
54+
return CLAD_COMPAT_CLANG19_SemaOpenMP(m_Sema).ActOnOpenMPLastprivateClause(
55+
Vars, C->getKind(), C->getKindLoc(), C->getColonLoc(), C->getBeginLoc(),
56+
C->getLParenLoc(), C->getEndLoc());
57+
}
58+
59+
OMPClause*
60+
BaseForwardModeVisitor::VisitOMPSharedClause(const OMPSharedClause* C) {
61+
llvm::SmallVector<Expr*, 16> Vars;
62+
Vars.reserve(C->varlist_size());
63+
for (const auto* Var : CLAD_COMPAT_CLANG20_getvarlist(C)) {
64+
Vars.push_back(Visit(Var).getExpr_dx());
65+
Vars.push_back(Clone(Var));
66+
}
67+
return CLAD_COMPAT_CLANG19_SemaOpenMP(m_Sema).ActOnOpenMPSharedClause(
68+
Vars, C->getBeginLoc(), C->getLParenLoc(), C->getEndLoc());
69+
}
70+
71+
OMPClause*
72+
BaseForwardModeVisitor::VisitOMPReductionClause(const OMPReductionClause* C) {
73+
llvm::SmallVector<Expr*, 16> Vars;
74+
Vars.reserve(C->varlist_size());
75+
for (const auto* Var : CLAD_COMPAT_CLANG20_getvarlist(C)) {
76+
Vars.push_back(Visit(Var).getExpr_dx());
77+
Vars.push_back(Clone(Var));
78+
}
79+
CXXScopeSpec ReductionIdScopeSpec;
80+
ReductionIdScopeSpec.Adopt(C->getQualifierLoc());
81+
DeclarationNameInfo NameInfo = C->getNameInfo();
82+
return CLAD_COMPAT_CLANG19_SemaOpenMP(m_Sema).ActOnOpenMPReductionClause(
83+
Vars, CLAD_COMPAT_CLANG21_getModifier(C), C->getBeginLoc(),
84+
C->getLParenLoc(), C->getModifierLoc(), C->getColonLoc(), C->getEndLoc(),
85+
ReductionIdScopeSpec, NameInfo);
86+
}
87+
88+
StmtDiff BaseForwardModeVisitor::VisitOMPExecutableDirective(
89+
const OMPExecutableDirective* D) {
90+
// Transform the clauses
91+
llvm::SmallVector<OMPClause*, 16> TClauses;
92+
ArrayRef<OMPClause*> Clauses = D->clauses();
93+
TClauses.reserve(Clauses.size());
94+
for (auto* I : Clauses) {
95+
assert(I);
96+
CLAD_COMPAT_CLANG19_SemaOpenMP(m_Sema).StartOpenMPClause(
97+
I->getClauseKind());
98+
OMPClause* Clause = Visit(I);
99+
CLAD_COMPAT_CLANG19_SemaOpenMP(m_Sema).EndOpenMPClause();
100+
assert(Clause);
101+
TClauses.push_back(Clause);
102+
}
103+
StmtDiff AssociatedStmt;
104+
if (D->hasAssociatedStmt() && D->getAssociatedStmt()) {
105+
CLAD_COMPAT_CLANG19_SemaOpenMP(m_Sema).ActOnOpenMPRegionStart(
106+
D->getDirectiveKind(), getCurrentScope());
107+
StmtDiff Body;
108+
{
109+
Sema::CompoundScopeRAII CompoundScope(m_Sema);
110+
const auto* CS = D->getInnermostCapturedStmt()->getCapturedStmt();
111+
Body = Visit(CS);
112+
if (isOpenMPLoopDirective(D->getDirectiveKind())) {
113+
// Ignore the differential variable generated by loop variables to
114+
// obtain a standard ForStmt.
115+
for (auto* SubStmt : Body.getStmt()->children()) {
116+
if (isa<ForStmt>(SubStmt)) {
117+
Body = SubStmt;
118+
break;
119+
}
120+
}
121+
}
122+
}
123+
AssociatedStmt = CLAD_COMPAT_CLANG19_SemaOpenMP(m_Sema)
124+
.ActOnOpenMPRegionEnd(Body.getStmt(), TClauses)
125+
.get();
126+
}
127+
DeclarationNameInfo DirName;
128+
OpenMPDirectiveKind CancelRegion = OMPD_unknown;
129+
return CLAD_COMPAT_CLANG19_SemaOpenMP(m_Sema)
130+
.ActOnOpenMPExecutableDirective(
131+
D->getDirectiveKind(), DirName, CancelRegion, TClauses,
132+
AssociatedStmt.getStmt(), D->getBeginLoc(), D->getEndLoc())
133+
.get();
134+
}
135+
136+
StmtDiff BaseForwardModeVisitor::VisitOMPParallelDirective(
137+
const OMPParallelDirective* D) {
138+
DeclarationNameInfo DirName;
139+
CLAD_COMPAT_CLANG19_SemaOpenMP(m_Sema).StartOpenMPDSABlock(
140+
OMPD_parallel, DirName, getCurrentScope(), D->getBeginLoc());
141+
StmtDiff Res = VisitOMPExecutableDirective(D);
142+
CLAD_COMPAT_CLANG19_SemaOpenMP(m_Sema).EndOpenMPDSABlock(Res.getStmt());
143+
return Res;
144+
}
145+
146+
StmtDiff BaseForwardModeVisitor::VisitOMPParallelForDirective(
147+
const OMPParallelForDirective* D) {
148+
DeclarationNameInfo DirName;
149+
CLAD_COMPAT_CLANG19_SemaOpenMP(m_Sema).StartOpenMPDSABlock(
150+
OMPD_parallel_for, DirName, getCurrentScope(), D->getBeginLoc());
151+
StmtDiff Res = VisitOMPExecutableDirective(D);
152+
CLAD_COMPAT_CLANG19_SemaOpenMP(m_Sema).EndOpenMPDSABlock(Res.getStmt());
153+
return Res;
154+
}
155+
} // namespace clad

lib/Differentiator/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ llvm_add_library(cladDifferentiator
5151
ActivityAnalyzer.cpp
5252
AnalysisBase.cpp
5353
BaseForwardModeVisitor.cpp
54+
BaseForwardModeVisitorOpenMP.cpp
5455
CladUtils.cpp
5556
ConstantFolder.cpp
5657
DerivativeBuilder.cpp

lib/Differentiator/StmtClone.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "clang/Sema/Lookup.h"
1515

1616
#include "llvm/ADT/SmallVector.h"
17+
#include "llvm/Support/Casting.h"
1718

1819
using namespace clang;
1920

@@ -549,7 +550,11 @@ bool ReferencesUpdater::VisitDeclRefExpr(DeclRefExpr* DRE) {
549550
// We should only update references of the declarations that were inside
550551
// the original function declaration context.
551552
// Original function = function that we are currently differentiating.
552-
if (!DRE->getDecl()->getDeclContext()->Encloses(m_Function))
553+
auto* Ctx = DRE->getDecl()->getDeclContext();
554+
// Skip the synthetic capture scope (e.g. OpenMP)
555+
while (isa<CapturedDecl>(Ctx))
556+
Ctx = Ctx->getParent();
557+
if (!Ctx->Encloses(m_Function))
553558
return true;
554559

555560
// Replace the declaration if it is present in `m_DeclReplacements`.
@@ -586,6 +591,7 @@ bool ReferencesUpdater::VisitDeclRefExpr(DeclRefExpr* DRE) {
586591
DRE->setDecl(VD);
587592
VD->setReferenced();
588593
VD->setIsUsed();
594+
m_Sema.MarkDeclarationsReferencedInExpr(DRE);
589595
}
590596
updateType(DRE->getType());
591597
return true;

test/ForwardMode/OpenMP.C

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
// RUN: %cladclang %s -I%S/../../include -fopenmp -fsyntax-only -oOpenMP.out 2>&1 | %filecheck %s
2+
3+
#include "clad/Differentiator/Differentiator.h"
4+
5+
double sum_scaled_parallel_for(const double* x, int n, double scale) {
6+
double total = 0.0; // reduction target
7+
double tmp = 0.0; // private
8+
double last = 0.0; // lastprivate
9+
10+
#pragma omp parallel for shared(n) firstprivate(scale) private(tmp) \
11+
lastprivate(last) reduction(+:total)
12+
for (int i = 0; i < n; ++i) {
13+
tmp = x[i] * scale;
14+
total += tmp;
15+
if (i == n - 1)
16+
last = tmp; // lastprivate variable is assigned on the last iteration
17+
}
18+
19+
// Touch 'last' to avoid unused warnings without affecting results
20+
total += 0 * last;
21+
// Use only total in the return to keep math simple for derivatives
22+
return total;
23+
}
24+
25+
// CHECK: double sum_scaled_parallel_for_darg0_1(const double *x, int n, double scale) {
26+
// CHECK-NEXT: int _d_n = 0;
27+
// CHECK-NEXT: double _d_scale = 0;
28+
// CHECK-NEXT: double _d_total = 0.;
29+
// CHECK-NEXT: double total = 0.;
30+
// CHECK-NEXT: double _d_tmp = 0.;
31+
// CHECK-NEXT: double tmp = 0.;
32+
// CHECK-NEXT: double _d_last = 0.;
33+
// CHECK-NEXT: double last = 0.;
34+
// CHECK-NEXT: #pragma omp parallel for shared(_d_n,n) firstprivate(_d_scale,scale) private(_d_tmp,tmp) lastprivate(_d_last,last) reduction(+: _d_total,total)
35+
// CHECK-NEXT: for (int i = 0; i < n; ++i) {
36+
// CHECK-NEXT: _d_tmp = (i == 1.) * scale + x[i] * _d_scale;
37+
// CHECK-NEXT: tmp = x[i] * scale;
38+
// CHECK-NEXT: _d_total += _d_tmp;
39+
// CHECK-NEXT: total += tmp;
40+
// CHECK-NEXT: if (i == n - 1) {
41+
// CHECK-NEXT: _d_last = _d_tmp;
42+
// CHECK-NEXT: last = tmp;
43+
// CHECK-NEXT: }
44+
// CHECK-NEXT: }
45+
// CHECK-NEXT: _d_total += 0 * last + 0 * _d_last;
46+
// CHECK-NEXT: total += 0 * last;
47+
// CHECK-NEXT: return _d_total;
48+
// CHECK-NEXT: }
49+
50+
double accumulate_scale_parallel(double scale, int n) {
51+
52+
double sum = 0.0; // reduction target
53+
double tmp = 0.0; // private
54+
55+
#pragma omp parallel shared(n) firstprivate(scale) private(tmp) reduction(+:sum)
56+
{
57+
// Each thread contributes exactly once; with 1 thread this is sum += scale
58+
tmp = scale;
59+
sum += tmp;
60+
}
61+
62+
return sum;
63+
}
64+
65+
// CHECK: double accumulate_scale_parallel_darg0(double scale, int n) {
66+
// CHECK-NEXT: double _d_scale = 1;
67+
// CHECK-NEXT: int _d_n = 0;
68+
// CHECK-NEXT: double _d_sum = 0.;
69+
// CHECK-NEXT: double sum = 0.;
70+
// CHECK-NEXT: double _d_tmp = 0.;
71+
// CHECK-NEXT: double tmp = 0.;
72+
// CHECK-NEXT: #pragma omp parallel shared(_d_n,n) firstprivate(_d_scale,scale) private(_d_tmp,tmp) reduction(+: _d_sum,sum)
73+
// CHECK-NEXT: {
74+
// CHECK-NEXT: _d_tmp = _d_scale;
75+
// CHECK-NEXT: tmp = scale;
76+
// CHECK-NEXT: _d_sum += _d_tmp;
77+
// CHECK-NEXT: sum += tmp;
78+
// CHECK-NEXT: }
79+
// CHECK-NEXT: return _d_sum;
80+
// CHECK-NEXT: }
81+
82+
int main() {
83+
double x[5] = {1, 2, 3, 4, 5};
84+
int n = 5;
85+
double scale = 2.0;
86+
87+
auto d_sum_wrt_x1 = clad::differentiate(sum_scaled_parallel_for, "x[1]");
88+
double d2 = d_sum_wrt_x1.execute(x, n, scale);
89+
90+
auto d_accum_wrt_scale = clad::differentiate(accumulate_scale_parallel, "scale");
91+
double d3 = d_accum_wrt_scale.execute(3.5, 8);
92+
return 0;
93+
}

0 commit comments

Comments
 (0)