Skip to content

Commit 4d98db7

Browse files
committed
Revert "Introduce AnalysisDeclContext to avoid repetetive CFG construction (#…"
This reverts commit a011819.
1 parent 2028122 commit 4d98db7

File tree

10 files changed

+68
-99
lines changed

10 files changed

+68
-99
lines changed

include/clad/Differentiator/DiffPlanner.h

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,15 @@
88
#include "clad/Differentiator/Timers.h"
99

1010
#include "clang/AST/Decl.h"
11-
#include "clang/AST/DeclBase.h"
1211
#include "clang/AST/ExprCXX.h"
1312
#include "clang/AST/RecursiveASTVisitor.h"
14-
#include "clang/Analysis/AnalysisDeclContext.h"
1513

1614
#include "llvm/ADT/DenseSet.h"
17-
#include "llvm/ADT/SmallVector.h"
1815
#include "llvm/Support/Compiler.h"
1916
#include "llvm/Support/SaveAndRestore.h"
2017
#include "llvm/Support/raw_ostream.h"
2118

2219
#include <iterator>
23-
#include <memory>
2420
#include <set>
2521

2622
namespace clang {
@@ -35,8 +31,7 @@ class Type;
3531
} // namespace clang
3632

3733
namespace clad {
38-
using OwnedAnalysisContexts =
39-
llvm::SmallVector<std::unique_ptr<clang::AnalysisDeclContext>, 4>;
34+
4035
/// A struct containing information about request to differentiate a function.
4136
struct DiffRequest {
4237
private:
@@ -128,8 +123,6 @@ struct DiffRequest {
128123
/// This will be particularly useful for pushforward and pullback functions.
129124
bool DeclarationOnly = false;
130125

131-
clang::AnalysisDeclContext* m_AnalysisDC;
132-
133126
/// Recomputes `DiffInputVarsInfo` using the current values of data members.
134127
///
135128
/// Differentiation parameters info is computed by parsing the argument
@@ -153,8 +146,6 @@ struct DiffRequest {
153146
// Note that CallContext is always different and we should ignore it.
154147
// CustomDerivative is an Expr* and is not always equal even if
155148
// the set of overloads is the same.
156-
// Including AnalysisDC would complicate constructing requests to find the
157-
// existing once.
158149
return Function == other.Function &&
159150
BaseFunctionName == other.BaseFunctionName &&
160151
CurrentDerivativeOrder == other.CurrentDerivativeOrder &&
@@ -187,10 +178,6 @@ struct DiffRequest {
187178
std::string ComputeDerivativeName() const;
188179
bool HasIndependentParameter(const clang::ParmVarDecl* PVD) const;
189180

190-
std::set<clang::SourceLocation>& getToBeRecorded() const {
191-
m_TbrRunInfo.HasAnalysisRun = true;
192-
return m_TbrRunInfo.ToBeRecorded;
193-
}
194181
void addVariedDecl(const clang::VarDecl* init) {
195182
m_ActivityRunInfo.VariedDecls.insert(init);
196183
}
@@ -203,7 +190,6 @@ struct DiffRequest {
203190
std::set<const clang::VarDecl*>& getUsefulDecls() const {
204191
return m_UsefulRunInfo.UsefulDecls;
205192
}
206-
bool HasTbrAnalysisRun() const { return m_TbrRunInfo.HasAnalysisRun; }
207193
};
208194

209195
using DiffInterval = std::vector<clang::SourceRange>;
@@ -224,10 +210,7 @@ struct DiffRequest {
224210
/// Graph to store the dependencies between different requests.
225211
///
226212
clad::DynamicGraph<DiffRequest>& m_DiffRequestGraph;
227-
/// Map that contains all AnalysisDeclContext for all declrations.
228-
/// Essentially needed for prolonging the lifetime of
229-
/// unique_ptr<clang::AnalysisDeclContext>.
230-
OwnedAnalysisContexts& m_AllAnalysisDC;
213+
231214
/// If set it means that we need to find the called functions and
232215
/// add them for implicit diff.
233216
///
@@ -245,7 +228,7 @@ struct DiffRequest {
245228
public:
246229
DiffCollector(clang::DeclGroupRef DGR, DiffInterval& Interval,
247230
clad::DynamicGraph<DiffRequest>& requestGraph, clang::Sema& S,
248-
RequestOptions& opts, OwnedAnalysisContexts& AllAnalysisDC);
231+
RequestOptions& opts);
249232
bool VisitCallExpr(clang::CallExpr* E);
250233
bool VisitDeclRefExpr(clang::DeclRefExpr* DRE);
251234
bool VisitCXXConstructExpr(clang::CXXConstructExpr* e);

lib/Differentiator/ActivityAnalyzer.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,14 @@ using namespace clang;
66
namespace clad {
77

88
void VariedAnalyzer::Analyze(const FunctionDecl* FD) {
9-
m_BlockData.resize(m_AnalysisDC->getCFG()->size());
9+
// Build the CFG (control-flow graph) of FD.
10+
clang::CFG::BuildOptions Options;
11+
m_CFG = clang::CFG::buildCFG(FD, FD->getBody(), &m_Context, Options);
12+
13+
m_BlockData.resize(m_CFG->size());
1014
// Set current block ID to the ID of entry the block.
11-
CFGBlock& entry = m_AnalysisDC->getCFG()->getEntry();
12-
m_CurBlockID = entry.getBlockID();
15+
CFGBlock* entry = &m_CFG->getEntry();
16+
m_CurBlockID = entry->getBlockID();
1317
m_BlockData[m_CurBlockID] = createNewVarsData({});
1418
for (const VarDecl* i : m_VariedDecls)
1519
m_BlockData[m_CurBlockID]->insert(i);
@@ -34,7 +38,7 @@ void mergeVarsData(std::set<const clang::VarDecl*>* targetData,
3438
}
3539

3640
CFGBlock* VariedAnalyzer::getCFGBlockByID(unsigned ID) {
37-
return *(m_AnalysisDC->getCFG()->begin() + ID);
41+
return *(m_CFG->begin() + ID);
3842
}
3943

4044
void VariedAnalyzer::AnalyzeCFGBlock(const CFGBlock& block) {

lib/Differentiator/ActivityAnalyzer.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
#define CLAD_DIFFERENTIATOR_ACTIVITYANALYZER_H
33

44
#include "clang/AST/RecursiveASTVisitor.h"
5-
#include "clang/Analysis/AnalysisDeclContext.h"
65
#include "clang/Analysis/CFG.h"
76

87
#include "clad/Differentiator/CladUtils.h"
@@ -36,7 +35,8 @@ class VariedAnalyzer : public clang::RecursiveASTVisitor<VariedAnalyzer> {
3635

3736
clang::CFGBlock* getCFGBlockByID(unsigned ID);
3837

39-
clang::AnalysisDeclContext* m_AnalysisDC;
38+
clang::ASTContext& m_Context;
39+
std::unique_ptr<clang::CFG> m_CFG;
4040
std::vector<std::unique_ptr<VarsData>> m_BlockData;
4141
unsigned m_CurBlockID{};
4242
std::set<unsigned> m_CFGQueue;
@@ -55,9 +55,9 @@ class VariedAnalyzer : public clang::RecursiveASTVisitor<VariedAnalyzer> {
5555

5656
public:
5757
/// Constructor
58-
VariedAnalyzer(clang::AnalysisDeclContext* AnalysisDC,
58+
VariedAnalyzer(clang::ASTContext& Context,
5959
std::set<const clang::VarDecl*>& Decls)
60-
: m_VariedDecls(Decls), m_AnalysisDC(AnalysisDC) {}
60+
: m_VariedDecls(Decls), m_Context(Context) {}
6161

6262
/// Destructor
6363
~VariedAnalyzer() = default;

lib/Differentiator/DiffPlanner.cpp

Lines changed: 18 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
#include "clang/AST/ExprCXX.h"
2222
#include "clang/AST/ExprObjC.h"
2323
#include "clang/AST/RecursiveASTVisitor.h"
24-
#include "clang/Analysis/AnalysisDeclContext.h"
24+
#include "clang/Analysis/CallGraph.h"
2525
#include "clang/Basic/IdentifierTable.h"
2626
#include "clang/Basic/LLVM.h" // isa, dyn_cast
2727
#include "clang/Basic/SourceLocation.h"
@@ -32,7 +32,6 @@
3232
#include "clang/Sema/TemplateDeduction.h"
3333

3434
#include <algorithm>
35-
#include <memory>
3635
#include <string>
3736
#include <utility>
3837

@@ -274,10 +273,9 @@ DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) {
274273

275274
DiffCollector::DiffCollector(DeclGroupRef DGR, DiffInterval& Interval,
276275
clad::DynamicGraph<DiffRequest>& requestGraph,
277-
clang::Sema& S, RequestOptions& opts,
278-
OwnedAnalysisContexts& AllAnalysisDC)
279-
: m_Interval(Interval), m_DiffRequestGraph(requestGraph),
280-
m_AllAnalysisDC(AllAnalysisDC), m_Sema(S), m_Options(opts) {
276+
clang::Sema& S, RequestOptions& opts)
277+
: m_Interval(Interval), m_DiffRequestGraph(requestGraph), m_Sema(S),
278+
m_Options(opts) {
281279

282280
if (Interval.empty())
283281
return;
@@ -650,11 +648,13 @@ DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) {
650648
if (E->getType()->isPointerType())
651649
return true;
652650

653-
if (!m_TbrRunInfo.HasAnalysisRun && !isLambdaCallOperator(Function) &&
654-
Function->isDefined()) {
651+
if (!m_TbrRunInfo.HasAnalysisRun && !isLambdaCallOperator(Function)) {
655652
TimedAnalysisRegion R("TBR " + BaseFunctionName);
656-
TBRAnalyzer analyzer(m_AnalysisDC, getToBeRecorded());
657-
analyzer.Analyze(*this);
653+
654+
TBRAnalyzer analyzer(Function->getASTContext(),
655+
m_TbrRunInfo.ToBeRecorded);
656+
analyzer.Analyze(Function);
657+
m_TbrRunInfo.HasAnalysisRun = true;
658658
}
659659
auto found = m_TbrRunInfo.ToBeRecorded.find(E->getBeginLoc());
660660
return found != m_TbrRunInfo.ToBeRecorded.end();
@@ -1263,27 +1263,22 @@ DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) {
12631263
}
12641264

12651265
if (!LookupCustomDerivativeDecl(request)) {
1266-
clang::CFG::BuildOptions Options;
1267-
std::unique_ptr<AnalysisDeclContext> AnalysisDC =
1268-
std::make_unique<AnalysisDeclContext>(
1269-
/*AnalysisDeclContextManager=*/nullptr, request.Function,
1270-
Options);
1271-
1272-
if (m_TopMostReq->EnableVariedAnalysis) {
1266+
if (m_TopMostReq->EnableVariedAnalysis &&
1267+
m_TopMostReq->Mode == DiffMode::reverse) {
12731268
TimedAnalysisRegion R("VA " + request.BaseFunctionName);
1274-
VariedAnalyzer analyzer(AnalysisDC.get(), request.getVariedDecls());
1269+
VariedAnalyzer analyzer(request.Function->getASTContext(),
1270+
request.getVariedDecls());
12751271
analyzer.Analyze(request.Function);
12761272
}
12771273

12781274
if (m_TopMostReq->EnableUsefulAnalysis) {
12791275
TimedAnalysisRegion R("UA " + request.BaseFunctionName);
1280-
UsefulAnalyzer analyzer(AnalysisDC.get(), request.getUsefulDecls());
1276+
1277+
UsefulAnalyzer analyzer(request.Function->getASTContext(),
1278+
request.getUsefulDecls());
12811279
analyzer.Analyze(request.Function);
12821280
}
12831281

1284-
m_AllAnalysisDC.push_back(std::move(AnalysisDC));
1285-
request.m_AnalysisDC = m_AllAnalysisDC.back().get();
1286-
12871282
// Recurse into call graph.
12881283
TraverseFunctionDeclOnce(request.Function);
12891284
}
@@ -1367,19 +1362,9 @@ DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) {
13671362
if (m_Sema.isStdInitializerList(recordTy, /*elemType=*/nullptr))
13681363
return true;
13691364

1370-
if (!LookupCustomDerivativeDecl(request)) {
1371-
clang::CFG::BuildOptions Options;
1372-
std::unique_ptr<AnalysisDeclContext> AnalysisDC =
1373-
std::make_unique<AnalysisDeclContext>(
1374-
/*AnalysisDeclContextManager=*/nullptr, request.Function,
1375-
Options);
1376-
// FIXME: Add proper support for objects in VA and UA.
1377-
m_AllAnalysisDC.push_back(std::move(AnalysisDC));
1378-
request.m_AnalysisDC = m_AllAnalysisDC.back().get();
1379-
1365+
if (!LookupCustomDerivativeDecl(request))
13801366
// Recurse into call graph.
13811367
TraverseFunctionDeclOnce(request.Function);
1382-
}
13831368
m_DiffRequestGraph.addNode(request, /*isSource=*/true);
13841369

13851370
return true;

lib/Differentiator/TBRAnalyzer.cpp

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,8 @@
88
#include "clang/AST/Expr.h"
99
#include "clang/AST/ExprCXX.h"
1010
#include "clang/AST/OperationKinds.h"
11-
#include "clang/Analysis/CFG.h"
1211
#include "clang/Basic/LLVM.h"
1312

14-
#include "clad/Differentiator/Compatibility.h"
15-
#include "clad/Differentiator/DiffPlanner.h"
16-
1713
#include "llvm/ADT/SmallVector.h"
1814
#include "llvm/Support/Casting.h"
1915
#include "llvm/Support/Debug.h"
@@ -255,7 +251,6 @@ void TBRAnalyzer::addVar(const clang::VarDecl* VD, bool forceNonRefType) {
255251
varType = arrayParam->getOriginalType();
256252
else
257253
varType = VD->getType();
258-
259254
// If varType represents auto or auto*, get the type of init.
260255
if (utils::IsAutoOrAutoPtrType(varType))
261256
varType = VD->getInit()->getType();
@@ -299,16 +294,19 @@ TBRAnalyzer::getVarDataFromDecl(const clang::VarDecl* VD) {
299294
return nullptr;
300295
}
301296

302-
void TBRAnalyzer::Analyze(const DiffRequest& request) {
303-
m_BlockData.resize(request.m_AnalysisDC->getCFG()->size());
304-
m_BlockPassCounter.resize(request.m_AnalysisDC->getCFG()->size(), 0);
297+
void TBRAnalyzer::Analyze(const FunctionDecl* FD) {
298+
// Build the CFG (control-flow graph) of FD.
299+
clang::CFG::BuildOptions Options;
300+
m_CFG = clang::CFG::buildCFG(FD, FD->getBody(), &m_Context, Options);
301+
302+
m_BlockData.resize(m_CFG->size());
303+
m_BlockPassCounter.resize(m_CFG->size(), 0);
305304

306305
// Set current block ID to the ID of entry the block.
307-
CFGBlock& entry = request.m_AnalysisDC->getCFG()->getEntry();
308-
m_CurBlockID = entry.getBlockID();
306+
auto* entry = &m_CFG->getEntry();
307+
m_CurBlockID = entry->getBlockID();
309308
m_BlockData[m_CurBlockID] = std::unique_ptr<VarsData>(new VarsData());
310309

311-
const FunctionDecl* FD = request.Function;
312310
// If we are analysing a non-static method, add a VarData for 'this' pointer
313311
// (it is represented with nullptr).
314312
const auto* MD = dyn_cast<CXXMethodDecl>(FD);
@@ -333,13 +331,13 @@ void TBRAnalyzer::Analyze(const DiffRequest& request) {
333331
m_CurBlockID = *IDIter;
334332
m_CFGQueue.erase(IDIter);
335333

336-
CFGBlock& nextBlock = *getCFGBlockByID(request.m_AnalysisDC, m_CurBlockID);
334+
CFGBlock& nextBlock = *getCFGBlockByID(m_CurBlockID);
337335
VisitCFGBlock(nextBlock);
338336
}
339337
#ifndef NDEBUG
340338
for (int id = m_CurBlockID; id >= 0; --id) {
341339
LLVM_DEBUG(llvm::dbgs() << "\n-----BLOCK" << id << "-----\n\n");
342-
for (auto succ : getCFGBlockByID(request.m_AnalysisDC, id)->succs())
340+
for (auto succ : getCFGBlockByID(id)->succs())
343341
if (succ)
344342
LLVM_DEBUG(llvm::dbgs() << "successor: " << succ->getBlockID() << "\n");
345343
}
@@ -411,8 +409,8 @@ void TBRAnalyzer::VisitCFGBlock(const CFGBlock& block) {
411409
LLVM_DEBUG(llvm::dbgs() << "Leaving block " << block.getBlockID() << "\n");
412410
}
413411

414-
CFGBlock* TBRAnalyzer::getCFGBlockByID(AnalysisDeclContext* ADC, unsigned ID) {
415-
return *(ADC->getCFG()->begin() + ID);
412+
CFGBlock* TBRAnalyzer::getCFGBlockByID(unsigned ID) {
413+
return *(m_CFG->begin() + ID);
416414
}
417415

418416
TBRAnalyzer::VarsData*
@@ -568,11 +566,9 @@ bool TBRAnalyzer::TraverseDeclStmt(DeclStmt* DS) {
568566
if (auto* VD = dyn_cast<VarDecl>(D)) {
569567
addVar(VD);
570568
if (clang::Expr* init = VD->getInit()) {
571-
572569
setMode(Mode::kMarkingMode);
573570
TraverseStmt(init);
574571
resetMode();
575-
576572
auto& VDExpr = getCurBlockVarsData()[VD];
577573
// if the declared variable is ref type attach its VarData to the
578574
// VarData of the RHS variable.

0 commit comments

Comments
 (0)