Skip to content

Commit 2028122

Browse files
authored
Fix #1451: Support more complex subscript operator use (#1452)
* Support more complex subscript operator use When the result of the subscript operator may affect the derivative, such as with `list[i].modify(x)`, clad currently marks it as nondifferentiable and skips creating the derivative/adjoint. This makes it follow the usual differentiation path. Closes #1451.
1 parent 89407d1 commit 2028122

File tree

2 files changed

+89
-2
lines changed

2 files changed

+89
-2
lines changed

lib/Differentiator/ReverseModeVisitor.cpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "clang/AST/TemplateBase.h"
3030
#include "clang/AST/Type.h"
3131
#include "clang/Basic/LLVM.h" // for clang::isa
32+
#include "clang/Basic/OperatorKinds.h"
3233
#include "clang/Basic/SourceLocation.h"
3334
#include "clang/Basic/TargetInfo.h"
3435
#include "clang/Basic/TokenKinds.h"
@@ -1591,8 +1592,16 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
15911592
// derived function. In the case of member functions, `implicit`
15921593
// this object is always passed by reference.
15931594
if (!nonDiff && !dfdx() && !utils::HasAnyReferenceOrPointerArgument(FD) &&
1594-
(!baseOriginalE || MD->isConst()))
1595-
nonDiff = true;
1595+
(!baseOriginalE || MD->isConst())) {
1596+
// The result of the subscript operator may affect the derivative, such as
1597+
// in a case like `list[i].modify(x)`. This makes clad handle those
1598+
// normally.
1599+
if (const auto* OCE = dyn_cast<CXXOperatorCallExpr>(CE)) {
1600+
if (OCE->getOperator() != clang::OverloadedOperatorKind::OO_Subscript)
1601+
nonDiff = true;
1602+
} else
1603+
nonDiff = true;
1604+
}
15961605

15971606
// If all arguments are constant literals, then this does not contribute to
15981607
// the gradient.
@@ -2073,6 +2082,15 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
20732082
} // Recreate the original call expression.
20742083

20752084
if (const auto* OCE = dyn_cast<CXXOperatorCallExpr>(CE)) {
2085+
if (OCE->getOperator() == clang::OverloadedOperatorKind::OO_Subscript) {
2086+
// If the operator is subscript, we should return the adjoint expression
2087+
auto AdjointCallArgs = CallArgs;
2088+
CallArgs.insert(CallArgs.begin(), baseDiff.getExpr());
2089+
AdjointCallArgs.insert(AdjointCallArgs.begin(), baseDiff.getExpr_dx());
2090+
call = BuildOperatorCall(OCE->getOperator(), CallArgs);
2091+
Expr* call_dx = BuildOperatorCall(OCE->getOperator(), AdjointCallArgs);
2092+
return StmtDiff(call, call_dx);
2093+
}
20762094
if (isMethodOperatorCall)
20772095
CallArgs.insert(CallArgs.begin(), baseDiff.getExpr());
20782096
call = BuildOperatorCall(OCE->getOperator(), CallArgs);

test/Gradient/Loops.C

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
#include "clad/Differentiator/Differentiator.h"
77
#include <cmath>
8+
#include <vector>
89

910
#include "../TestUtils.h"
1011

@@ -3260,6 +3261,67 @@ double fn41(double u, double v) {
32603261
//CHECK-NEXT: }
32613262
//CHECK-NEXT:}
32623263

3264+
struct tmp {
3265+
float z = 0;
3266+
tmp(float val) : z(val) {}
3267+
tmp() = default;
3268+
void operator+=(const tmp &other) {
3269+
z += other.z;
3270+
}
3271+
float forward(const float &x) const {
3272+
return x + z;
3273+
}
3274+
};
3275+
struct layer {
3276+
std::vector<tmp> w;
3277+
float forward(const float &inp) const {
3278+
float x = inp;
3279+
for (int i=0;i<w.size();i++) {
3280+
x = w[i].forward(x);
3281+
}
3282+
return x;
3283+
}
3284+
};
3285+
float fn42(const layer &l, float x) {
3286+
return l.forward(x);
3287+
}
3288+
//CHECK: void forward_pullback(const float &inp, float _d_y, layer *_d_this, float *_d_inp) const {
3289+
//CHECK-NEXT: int _d_i = 0;
3290+
//CHECK-NEXT: int i = 0;
3291+
//CHECK-NEXT: clad::tape<float> _t1 = {};
3292+
//CHECK-NEXT: float _d_x = 0.F;
3293+
//CHECK-NEXT: float x = inp;
3294+
//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL|0ULL}};
3295+
//CHECK-NEXT: for (i = 0; ; i++) {
3296+
//CHECK-NEXT: {
3297+
//CHECK-NEXT: if (!(i < this->w.size()))
3298+
//CHECK-NEXT: break;
3299+
//CHECK-NEXT: }
3300+
//CHECK-NEXT: _t0++;
3301+
//CHECK-NEXT: clad::push(_t1, x);
3302+
//CHECK-NEXT: x = this->w[i].forward(x);
3303+
//CHECK-NEXT: }
3304+
//CHECK-NEXT: _d_x += _d_y;
3305+
//CHECK-NEXT: for (;; _t0--) {
3306+
//CHECK-NEXT: {
3307+
//CHECK-NEXT: if (!_t0)
3308+
//CHECK-NEXT: break;
3309+
//CHECK-NEXT: }
3310+
//CHECK-NEXT: i--;
3311+
//CHECK-NEXT: {
3312+
//CHECK-NEXT: x = clad::pop(_t1);
3313+
//CHECK-NEXT: float _r_d0 = _d_x;
3314+
//CHECK-NEXT: _d_x = 0.F;
3315+
//CHECK-NEXT: this->w[i].forward_pullback(x, _r_d0, &_d_this->w[i], &_d_x);
3316+
//CHECK-NEXT: size_type _r0 = {{0U|0UL}};
3317+
//CHECK-NEXT: this->w.operator_subscript_pullback(i, {}, &_d_this->w, &_r0);
3318+
//CHECK-NEXT: _d_i += _r0;
3319+
//CHECK-NEXT: }
3320+
//CHECK-NEXT: }
3321+
//CHECK-NEXT: *_d_inp += _d_x;
3322+
//CHECK-NEXT:}
3323+
3324+
32633325
#define TEST(F, x) { \
32643326
result[0] = 0; \
32653327
auto F##grad = clad::gradient(F);\
@@ -3352,4 +3414,11 @@ int main() {
33523414
TEST(fn39, 9); // CHECK-EXEC: {6.00}
33533415
TEST_2(fn40, 2, 3); // CHECK-EXEC: {14.00, 0.00}
33543416
TEST_2(fn41, 2, 3); // CHECK-EXEC: {1.00, 0.00}
3417+
3418+
auto d_fn42 = clad::gradient(fn42, "0");
3419+
float x_ = 2.0f;
3420+
layer l{ .w = {{3}, {4}, {5}, {6}}};
3421+
layer d_l{ .w = {{0}, {0}, {0}, {0}}};
3422+
d_fn42.execute(l, x_, &d_l);
3423+
printf("{%.2f, %.2f, %.2f, %.2f}", d_l.w[0].z, d_l.w[1].z, d_l.w[2].z, d_l.w[3].z); // CHECK-EXEC: {1.00, 1.00, 1.00, 1.00}
33553424
}

0 commit comments

Comments
 (0)