Skip to content

Commit 9ed0b9e

Browse files
PetroZarytskyivgvassilev
authored andcommitted
Generate constructor_reverse_forw automatically
1 parent 922a9d2 commit 9ed0b9e

File tree

15 files changed

+390
-201
lines changed

15 files changed

+390
-201
lines changed

include/clad/Differentiator/CladUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,9 @@ namespace clad {
429429
bool isLinearConstructor(const clang::CXXConstructorDecl* CD,
430430
const clang::ASTContext& C);
431431

432+
bool isElidableConstructor(const clang::CXXConstructorDecl* CD,
433+
const clang::ASTContext& C);
434+
432435
/// Returns true if T allows to edit any memory.
433436
bool isMemoryType(clang::QualType T);
434437

include/clad/Differentiator/ReverseModeVisitor.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,8 +447,10 @@ namespace clad {
447447
///
448448
/// \param[in] thisTy `this` type.
449449
///
450+
/// \param[in] isDerivedThis if true, will build `_d_this`.
451+
///
450452
/// \returns {_this, free(_this)}
451-
StmtDiff BuildThisExpr(clang::QualType thisTy);
453+
StmtDiff BuildThisExpr(clang::QualType thisTy, bool isDerivedThis = false);
452454

453455
/// Helper function that checks whether the function to be derived
454456
/// is meant to be executed only by the GPU

include/clad/Differentiator/STLBuiltins.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,21 @@ size_pushforward(const ::std::array<T, N>* a,
466466
// vector reverse mode
467467
// more can be found in tests: test/Gradient/STLCustomDerivatives.C
468468

469+
template <typename T, typename... Args>
470+
clad::ValueAndAdjoint<::std::vector<T>, ::std::vector<T>>
471+
constructor_reverse_forw(clad::Tag<::std::vector<T>>,
472+
Args...) elidable_reverse_forw;
473+
474+
template <typename T, typename... Args>
475+
clad::ValueAndAdjoint<::std::allocator<T>, ::std::allocator<T>>
476+
constructor_reverse_forw(clad::Tag<::std::allocator<T>>,
477+
Args...) elidable_reverse_forw;
478+
479+
template <typename T, typename U, typename... Args>
480+
clad::ValueAndAdjoint<::std::pair<T, U>, ::std::pair<T, U>>
481+
constructor_reverse_forw(clad::Tag<::std::pair<T, U>>,
482+
Args...) elidable_reverse_forw;
483+
469484
template <typename T, typename U, typename pU>
470485
void push_back_reverse_forw(::std::vector<T>* v, U val, ::std::vector<T>* d_v,
471486
pU /*d_val*/) {

include/clad/Differentiator/ThrustBuiltins.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
namespace clad {
1010

11+
#define elidable_reverse_forw __attribute__((annotate("elidable_reverse_forw")))
12+
1113
// zero_init for thrust::device_vector
1214
template <class T> inline void zero_init(::thrust::device_vector<T>& v) {
1315
::thrust::fill(v.begin(), v.end(), T(0)); // NOLINT(misc-include-cleaner)
@@ -17,6 +19,14 @@ namespace custom_derivatives::class_functions {
1719

1820
// Constructors (reverse mode pullbacks)
1921

22+
template <typename T>
23+
clad::ValueAndAdjoint<::thrust::device_vector<T>, ::thrust::device_vector<T>>
24+
constructor_reverse_forw(clad::Tag<::thrust::device_vector<T>>,
25+
typename ::thrust::device_vector<T>::size_type count,
26+
const T& val,
27+
typename ::thrust::device_vector<T>::size_type d_count,
28+
const T& d_val) elidable_reverse_forw;
29+
2030
template <typename T>
2131
inline void
2232
constructor_pullback(typename ::thrust::device_vector<T>::size_type count,
@@ -29,6 +39,13 @@ constructor_pullback(typename ::thrust::device_vector<T>::size_type count,
2939
}
3040
}
3141

42+
template <typename T>
43+
clad::ValueAndAdjoint<::thrust::device_vector<T>, ::thrust::device_vector<T>>
44+
constructor_reverse_forw(clad::Tag<::thrust::device_vector<T>>,
45+
typename ::thrust::device_vector<T>::size_type count,
46+
typename ::thrust::device_vector<T>::size_type d_count)
47+
elidable_reverse_forw;
48+
3249
template <typename T>
3350
inline void
3451
constructor_pullback(typename ::thrust::device_vector<T>::size_type count,

lib/Differentiator/CladUtils.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,44 @@ namespace clad {
358358
return true;
359359
}
360360

361+
bool isElidableConstructor(const clang::CXXConstructorDecl* CD,
362+
const clang::ASTContext& C) {
363+
const CXXRecordDecl* RD = CD->getParent();
364+
if (CD->isCopyOrMoveConstructor() && CD->isTrivial())
365+
return true;
366+
if (RD->isAggregate())
367+
return true;
368+
if (utils::unwrapIfSingleStmt(CD->getBody()))
369+
return false;
370+
bool isZeroOrCopyCtor = true;
371+
for (CXXCtorInitializer* CI : CD->inits()) {
372+
if (CI->isMemberInitializer()) {
373+
QualType memberTy = CI->getMember()->getType();
374+
if (memberTy->isPointerType() || memberTy->isLValueReferenceType())
375+
continue;
376+
}
377+
Expr* init = CI->getInit()->IgnoreImplicit();
378+
Expr::EvalResult value;
379+
bool isZero = false;
380+
if (clad_compat::Expr_EvaluateAsConstantExpr(init, value, C)) {
381+
const auto* val = &value.Val;
382+
if (val->isArray() && val->hasArrayFiller())
383+
val = &val->getArrayFiller();
384+
if (val->isInt())
385+
isZero = val->getInt() == 0;
386+
else if (val->isFloat())
387+
isZero = val->getFloat().isZero();
388+
}
389+
390+
if (!(isa<DeclRefExpr>(init) || isa<CXXConstructExpr>(init) ||
391+
isZero)) {
392+
isZeroOrCopyCtor = false;
393+
break;
394+
}
395+
}
396+
return isZeroOrCopyCtor;
397+
}
398+
361399
void getRecordDeclFields(
362400
const clang::RecordDecl* RD,
363401
llvm::SmallVectorImpl<const clang::FieldDecl*>& fields) {

lib/Differentiator/DiffPlanner.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1456,7 +1456,10 @@ static QualType GetDerivedFunctionType(const CallExpr* CE) {
14561456
forwPassRequest.BaseFunctionName = "constructor";
14571457
forwPassRequest.Mode = DiffMode::reverse_mode_forward_pass;
14581458
forwPassRequest.CallContext = E;
1459-
if (LookupCustomDerivativeDecl(forwPassRequest))
1459+
QualType recordTy = CD->getThisType()->getPointeeType();
1460+
bool elideRevForw =
1461+
utils::isElidableConstructor(CD, m_Sema.getASTContext());
1462+
if (LookupCustomDerivativeDecl(forwPassRequest) || !elideRevForw)
14601463
m_DiffRequestGraph.addNode(forwPassRequest, /*isSource=*/true);
14611464

14621465
// Don't build propagators for calls that do not contribute in
@@ -1488,7 +1491,6 @@ static QualType GetDerivedFunctionType(const CallExpr* CE) {
14881491
if (request.Function->getDefinition())
14891492
request.Function = request.Function->getDefinition();
14901493

1491-
QualType recordTy = CD->getThisType()->getPointeeType();
14921494
if (m_Sema.isStdInitializerList(recordTy, /*elemType=*/nullptr))
14931495
return true;
14941496

lib/Differentiator/ReverseModeForwPassVisitor.cpp

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "clang/AST/Decl.h"
1212
#include "clang/AST/DeclCXX.h"
1313
#include "clang/AST/Expr.h"
14+
#include "clang/AST/OperationKinds.h"
1415
#include "clang/Basic/LLVM.h"
1516

1617
#include <algorithm>
@@ -67,15 +68,43 @@ DerivativeAndOverload ReverseModeForwPassVisitor::Derive() {
6768
beginBlock();
6869
beginBlock(direction::reverse);
6970

71+
// If we the differentiated function is a constructor, generate `this`
72+
// object and differentiate its inits.
73+
Stmt* ctorReturnStmt = nullptr;
74+
if (const auto* CD = dyn_cast<CXXConstructorDecl>(m_DiffReq.Function)) {
75+
QualType thisTy = CD->getThisType();
76+
StmtDiff thisObj = BuildThisExpr(thisTy);
77+
StmtDiff dthisObj = BuildThisExpr(thisTy, /*isDerivedThis=*/true);
78+
m_ThisExprDerivative = dthisObj.getExpr();
79+
addToCurrentBlock(dthisObj.getStmt_dx());
80+
for (CXXCtorInitializer* CI : CD->inits()) {
81+
StmtDiff CI_diff = DifferentiateCtorInit(CI, thisObj.getExpr());
82+
addToCurrentBlock(CI_diff.getStmt(), direction::forward);
83+
addToCurrentBlock(CI_diff.getStmt_dx(), direction::forward);
84+
}
85+
// Build `return {*_this, *_d_this};`
86+
SourceLocation validLoc{CD->getBeginLoc()};
87+
llvm::SmallVector<Expr*, 2> returnArgs = {
88+
BuildOp(UO_Deref, thisObj.getExpr()),
89+
BuildOp(UO_Deref, dthisObj.getExpr())};
90+
Expr* returnInitList =
91+
m_Sema.ActOnInitList(validLoc, returnArgs, validLoc).get();
92+
ctorReturnStmt = m_Sema.BuildReturnStmt(validLoc, returnInitList).get();
93+
}
94+
7095
StmtDiff bodyDiff = Visit(m_DiffReq->getBody());
7196
Stmt* forward = bodyDiff.getStmt();
7297

7398
for (Stmt* S : ReverseModeVisitor::m_Globals)
7499
addToCurrentBlock(S);
75100

76-
if (auto* CS = dyn_cast<CompoundStmt>(forward))
101+
if (auto* CS = dyn_cast_or_null<CompoundStmt>(forward))
77102
for (Stmt* S : CS->body())
78103
addToCurrentBlock(S);
104+
else
105+
addToCurrentBlock(forward);
106+
107+
addToCurrentBlock(ctorReturnStmt);
79108

80109
Stmt* fnBody = endBlock();
81110
m_Derivative->setBody(fnBody);
@@ -97,8 +126,13 @@ ReverseModeForwPassVisitor::BuildParams(DiffParams& diffParams) {
97126

98127
std::size_t dParamTypesIdx = m_DiffReq->getNumParams();
99128

100-
if (const auto* CD = dyn_cast<CXXConversionDecl>(m_DiffReq.Function)) {
101-
QualType typeTag = utils::GetCladTagOfType(m_Sema, CD->getConversionType());
129+
QualType tagType;
130+
if (const auto* CD = dyn_cast<CXXConversionDecl>(m_DiffReq.Function))
131+
tagType = CD->getConversionType();
132+
else if (const auto* CD = dyn_cast<CXXConstructorDecl>(m_DiffReq.Function))
133+
tagType = CD->getThisType()->getPointeeType();
134+
if (!tagType.isNull()) {
135+
QualType typeTag = utils::GetCladTagOfType(m_Sema, tagType);
102136
IdentifierInfo* emptyII = &m_Context.Idents.get("");
103137
ParmVarDecl* typeTagPVD =
104138
utils::BuildParmVarDecl(m_Sema, m_Derivative, emptyII, typeTag);
@@ -110,7 +144,7 @@ ReverseModeForwPassVisitor::BuildParams(DiffParams& diffParams) {
110144

111145
if (const auto* MD = dyn_cast<CXXMethodDecl>(m_DiffReq.Function)) {
112146
const CXXRecordDecl* RD = MD->getParent();
113-
if (MD->isInstance() && !RD->isLambda()) {
147+
if (MD->isInstance() && !RD->isLambda() && !isa<CXXConstructorDecl>(MD)) {
114148
auto* thisDerivativePVD = utils::BuildParmVarDecl(
115149
m_Sema, m_Derivative, CreateUniqueIdentifier("_d_this"),
116150
derivativeFnType->getParamType(dParamTypesIdx));

0 commit comments

Comments
 (0)