1+ #include " clad/Differentiator/BaseForwardModeVisitor.h"
2+ #include " clad/Differentiator/Compatibility.h"
3+
4+ #include " clang/AST/DeclarationName.h"
5+ #include " clang/AST/OpenMPClause.h"
6+ #include " clang/AST/Stmt.h"
7+ #include " clang/AST/StmtOpenMP.h"
8+ #include " clang/Basic/LLVM.h"
9+ #include " clang/Basic/OpenMPKinds.h"
10+ #include " clang/Sema/DeclSpec.h"
11+ #include " clang/Sema/Sema.h"
12+ #include " llvm/ADT/SmallVector.h"
13+ #include " llvm/Frontend/OpenMP/OMP.h.inc"
14+
15+ #include < cassert>
16+
17+ using namespace clang ;
18+ using namespace llvm ::omp;
19+
20+ namespace clad {
21+
22+ OMPClause*
23+ BaseForwardModeVisitor::VisitOMPPrivateClause (const OMPPrivateClause* C) {
24+ llvm::SmallVector<Expr*, 16 > Vars;
25+ Vars.reserve (C->varlist_size ());
26+ for (const auto * Var : CLAD_COMPAT_CLANG20_getvarlist (C)) {
27+ Vars.push_back (Visit (Var).getExpr_dx ());
28+ Vars.push_back (Clone (Var));
29+ }
30+ return CLAD_COMPAT_CLANG19_SemaOpenMP (m_Sema).ActOnOpenMPPrivateClause (
31+ Vars, C->getBeginLoc (), C->getLParenLoc (), C->getEndLoc ());
32+ }
33+
34+ OMPClause* BaseForwardModeVisitor::VisitOMPFirstprivateClause (
35+ const OMPFirstprivateClause* C) {
36+ llvm::SmallVector<Expr*, 16 > Vars;
37+ Vars.reserve (C->varlist_size ());
38+ for (const auto * Var : CLAD_COMPAT_CLANG20_getvarlist (C)) {
39+ Vars.push_back (Visit (Var).getExpr_dx ());
40+ Vars.push_back (Clone (Var));
41+ }
42+ return CLAD_COMPAT_CLANG19_SemaOpenMP (m_Sema).ActOnOpenMPFirstprivateClause (
43+ Vars, C->getBeginLoc (), C->getLParenLoc (), C->getEndLoc ());
44+ }
45+
46+ OMPClause* BaseForwardModeVisitor::VisitOMPLastprivateClause (
47+ const OMPLastprivateClause* C) {
48+ llvm::SmallVector<Expr*, 16 > Vars;
49+ Vars.reserve (C->varlist_size ());
50+ for (const auto * Var : CLAD_COMPAT_CLANG20_getvarlist (C)) {
51+ Vars.push_back (Visit (Var).getExpr_dx ());
52+ Vars.push_back (Clone (Var));
53+ }
54+ return CLAD_COMPAT_CLANG19_SemaOpenMP (m_Sema).ActOnOpenMPLastprivateClause (
55+ Vars, C->getKind (), C->getKindLoc (), C->getColonLoc (), C->getBeginLoc (),
56+ C->getLParenLoc (), C->getEndLoc ());
57+ }
58+
59+ OMPClause*
60+ BaseForwardModeVisitor::VisitOMPSharedClause (const OMPSharedClause* C) {
61+ llvm::SmallVector<Expr*, 16 > Vars;
62+ Vars.reserve (C->varlist_size ());
63+ for (const auto * Var : CLAD_COMPAT_CLANG20_getvarlist (C)) {
64+ Vars.push_back (Visit (Var).getExpr_dx ());
65+ Vars.push_back (Clone (Var));
66+ }
67+ return CLAD_COMPAT_CLANG19_SemaOpenMP (m_Sema).ActOnOpenMPSharedClause (
68+ Vars, C->getBeginLoc (), C->getLParenLoc (), C->getEndLoc ());
69+ }
70+
71+ OMPClause*
72+ BaseForwardModeVisitor::VisitOMPReductionClause (const OMPReductionClause* C) {
73+ llvm::SmallVector<Expr*, 16 > Vars;
74+ Vars.reserve (C->varlist_size ());
75+ for (const auto * Var : CLAD_COMPAT_CLANG20_getvarlist (C)) {
76+ Vars.push_back (Visit (Var).getExpr_dx ());
77+ Vars.push_back (Clone (Var));
78+ }
79+ CXXScopeSpec ReductionIdScopeSpec;
80+ ReductionIdScopeSpec.Adopt (C->getQualifierLoc ());
81+ DeclarationNameInfo NameInfo = C->getNameInfo ();
82+ return CLAD_COMPAT_CLANG19_SemaOpenMP (m_Sema).ActOnOpenMPReductionClause (
83+ Vars, CLAD_COMPAT_CLANG21_getModifier (C), C->getBeginLoc (),
84+ C->getLParenLoc (), C->getModifierLoc (), C->getColonLoc (), C->getEndLoc (),
85+ ReductionIdScopeSpec, NameInfo);
86+ }
87+
88+ StmtDiff BaseForwardModeVisitor::VisitOMPExecutableDirective (
89+ const OMPExecutableDirective* D) {
90+ // Transform the clauses
91+ llvm::SmallVector<OMPClause*, 16 > TClauses;
92+ ArrayRef<OMPClause*> Clauses = D->clauses ();
93+ TClauses.reserve (Clauses.size ());
94+ for (auto * I : Clauses) {
95+ assert (I);
96+ CLAD_COMPAT_CLANG19_SemaOpenMP (m_Sema).StartOpenMPClause (
97+ I->getClauseKind ());
98+ OMPClause* Clause = Visit (I);
99+ CLAD_COMPAT_CLANG19_SemaOpenMP (m_Sema).EndOpenMPClause ();
100+ assert (Clause);
101+ TClauses.push_back (Clause);
102+ }
103+ StmtDiff AssociatedStmt;
104+ if (D->hasAssociatedStmt () && D->getAssociatedStmt ()) {
105+ CLAD_COMPAT_CLANG19_SemaOpenMP (m_Sema).ActOnOpenMPRegionStart (
106+ D->getDirectiveKind (), getCurrentScope ());
107+ StmtDiff Body;
108+ {
109+ Sema::CompoundScopeRAII CompoundScope (m_Sema);
110+ const auto * CS = D->getInnermostCapturedStmt ()->getCapturedStmt ();
111+ Body = Visit (CS);
112+ if (isOpenMPLoopDirective (D->getDirectiveKind ())) {
113+ // Ignore the differential variable generated by loop variables to
114+ // obtain a standard ForStmt.
115+ for (auto * SubStmt : Body.getStmt ()->children ()) {
116+ if (isa<ForStmt>(SubStmt)) {
117+ Body = SubStmt;
118+ break ;
119+ }
120+ }
121+ }
122+ }
123+ AssociatedStmt = CLAD_COMPAT_CLANG19_SemaOpenMP (m_Sema)
124+ .ActOnOpenMPRegionEnd (Body.getStmt (), TClauses)
125+ .get ();
126+ }
127+ DeclarationNameInfo DirName;
128+ OpenMPDirectiveKind CancelRegion = OMPD_unknown;
129+ return CLAD_COMPAT_CLANG19_SemaOpenMP (m_Sema)
130+ .ActOnOpenMPExecutableDirective (
131+ D->getDirectiveKind (), DirName, CancelRegion, TClauses,
132+ AssociatedStmt.getStmt (), D->getBeginLoc (), D->getEndLoc ())
133+ .get ();
134+ }
135+
136+ StmtDiff BaseForwardModeVisitor::VisitOMPParallelDirective (
137+ const OMPParallelDirective* D) {
138+ DeclarationNameInfo DirName;
139+ CLAD_COMPAT_CLANG19_SemaOpenMP (m_Sema).StartOpenMPDSABlock (
140+ OMPD_parallel, DirName, getCurrentScope (), D->getBeginLoc ());
141+ StmtDiff Res = VisitOMPExecutableDirective (D);
142+ CLAD_COMPAT_CLANG19_SemaOpenMP (m_Sema).EndOpenMPDSABlock (Res.getStmt ());
143+ return Res;
144+ }
145+
146+ StmtDiff BaseForwardModeVisitor::VisitOMPParallelForDirective (
147+ const OMPParallelForDirective* D) {
148+ DeclarationNameInfo DirName;
149+ CLAD_COMPAT_CLANG19_SemaOpenMP (m_Sema).StartOpenMPDSABlock (
150+ OMPD_parallel_for, DirName, getCurrentScope (), D->getBeginLoc ());
151+ StmtDiff Res = VisitOMPExecutableDirective (D);
152+ CLAD_COMPAT_CLANG19_SemaOpenMP (m_Sema).EndOpenMPDSABlock (Res.getStmt ());
153+ return Res;
154+ }
155+ } // namespace clad
0 commit comments