1111#include " clang/AST/Decl.h"
1212#include " clang/AST/DeclCXX.h"
1313#include " clang/AST/Expr.h"
14+ #include " clang/AST/OperationKinds.h"
1415#include " clang/Basic/LLVM.h"
1516
1617#include < algorithm>
@@ -67,15 +68,43 @@ DerivativeAndOverload ReverseModeForwPassVisitor::Derive() {
6768 beginBlock ();
6869 beginBlock (direction::reverse);
6970
71+ // If we the differentiated function is a constructor, generate `this`
72+ // object and differentiate its inits.
73+ Stmt* ctorReturnStmt = nullptr ;
74+ if (const auto * CD = dyn_cast<CXXConstructorDecl>(m_DiffReq.Function )) {
75+ QualType thisTy = CD->getThisType ();
76+ StmtDiff thisObj = BuildThisExpr (thisTy);
77+ StmtDiff dthisObj = BuildThisExpr (thisTy, /* isDerivedThis=*/ true );
78+ m_ThisExprDerivative = dthisObj.getExpr ();
79+ addToCurrentBlock (dthisObj.getStmt_dx ());
80+ for (CXXCtorInitializer* CI : CD->inits ()) {
81+ StmtDiff CI_diff = DifferentiateCtorInit (CI, thisObj.getExpr ());
82+ addToCurrentBlock (CI_diff.getStmt (), direction::forward);
83+ addToCurrentBlock (CI_diff.getStmt_dx (), direction::forward);
84+ }
85+ // Build `return {*_this, *_d_this};`
86+ SourceLocation validLoc{CD->getBeginLoc ()};
87+ llvm::SmallVector<Expr*, 2 > returnArgs = {
88+ BuildOp (UO_Deref, thisObj.getExpr ()),
89+ BuildOp (UO_Deref, dthisObj.getExpr ())};
90+ Expr* returnInitList =
91+ m_Sema.ActOnInitList (validLoc, returnArgs, validLoc).get ();
92+ ctorReturnStmt = m_Sema.BuildReturnStmt (validLoc, returnInitList).get ();
93+ }
94+
7095 StmtDiff bodyDiff = Visit (m_DiffReq->getBody ());
7196 Stmt* forward = bodyDiff.getStmt ();
7297
7398 for (Stmt* S : ReverseModeVisitor::m_Globals)
7499 addToCurrentBlock (S);
75100
76- if (auto * CS = dyn_cast <CompoundStmt>(forward))
101+ if (auto * CS = dyn_cast_or_null <CompoundStmt>(forward))
77102 for (Stmt* S : CS->body ())
78103 addToCurrentBlock (S);
104+ else
105+ addToCurrentBlock (forward);
106+
107+ addToCurrentBlock (ctorReturnStmt);
79108
80109 Stmt* fnBody = endBlock ();
81110 m_Derivative->setBody (fnBody);
@@ -97,8 +126,13 @@ ReverseModeForwPassVisitor::BuildParams(DiffParams& diffParams) {
97126
98127 std::size_t dParamTypesIdx = m_DiffReq->getNumParams ();
99128
100- if (const auto * CD = dyn_cast<CXXConversionDecl>(m_DiffReq.Function )) {
101- QualType typeTag = utils::GetCladTagOfType (m_Sema, CD->getConversionType ());
129+ QualType tagType;
130+ if (const auto * CD = dyn_cast<CXXConversionDecl>(m_DiffReq.Function ))
131+ tagType = CD->getConversionType ();
132+ else if (const auto * CD = dyn_cast<CXXConstructorDecl>(m_DiffReq.Function ))
133+ tagType = CD->getThisType ()->getPointeeType ();
134+ if (!tagType.isNull ()) {
135+ QualType typeTag = utils::GetCladTagOfType (m_Sema, tagType);
102136 IdentifierInfo* emptyII = &m_Context.Idents .get (" " );
103137 ParmVarDecl* typeTagPVD =
104138 utils::BuildParmVarDecl (m_Sema, m_Derivative, emptyII, typeTag);
@@ -110,7 +144,7 @@ ReverseModeForwPassVisitor::BuildParams(DiffParams& diffParams) {
110144
111145 if (const auto * MD = dyn_cast<CXXMethodDecl>(m_DiffReq.Function )) {
112146 const CXXRecordDecl* RD = MD->getParent ();
113- if (MD->isInstance () && !RD->isLambda ()) {
147+ if (MD->isInstance () && !RD->isLambda () && !isa<CXXConstructorDecl>(MD) ) {
114148 auto * thisDerivativePVD = utils::BuildParmVarDecl (
115149 m_Sema, m_Derivative, CreateUniqueIdentifier (" _d_this" ),
116150 derivativeFnType->getParamType (dParamTypesIdx));
0 commit comments