Skip to content

Commit 646a8de

Browse files
committed
added component-wise access for Eval operation
1 parent 741abe6 commit 646a8de

File tree

5 files changed

+212
-0
lines changed

5 files changed

+212
-0
lines changed

include/ad/a2dmat.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,18 @@ class SymMat {
173173
T A[MAT_SIZE];
174174
};
175175

176+
template <typename T>
177+
struct is_a2d_matrix : std::false_type {};
178+
179+
template <typename U, int N, int M>
180+
struct is_a2d_matrix<Mat<U, N, M>> : std::true_type {};
181+
182+
template <typename T>
183+
struct is_a2d_sym_matrix : std::false_type {};
184+
185+
template <typename U, int N>
186+
struct is_a2d_sym_matrix<SymMat<U, N>> : std::true_type {};
187+
176188
} // namespace A2D
177189

178190
#endif // A2D_MAT_H

include/ad/a2dobj.h

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@ class ADObj : public ADExpr<ADObj<T>, T> {
5353
}
5454

5555
// Initialize with both values
56+
template <typename U = T,
57+
std::enable_if_t<std::is_reference<U>::value, bool> = true>
58+
A2D_FUNCTION ADObj(T& A, T& Ab) : A(A), Ab(Ab) {}
59+
60+
template <typename U = T,
61+
std::enable_if_t<!std::is_reference<U>::value, bool> = true>
5662
A2D_FUNCTION ADObj(const T& A, const T& Ab) : A(A), Ab(Ab) {}
5763

5864
// Evaluation and derivatives
@@ -73,6 +79,39 @@ class ADObj : public ADExpr<ADObj<T>, T> {
7379
A2D_FUNCTION T& bvalue() { return Ab; }
7480
A2D_FUNCTION const T& bvalue() const { return Ab; }
7581

82+
template <typename I, typename U = T,
83+
std::enable_if_t<
84+
is_a2d_vector<typename remove_const_and_refs<U>::type>::value,
85+
bool> = true>
86+
A2D_FUNCTION ADObj<type&> operator[](const I i) {
87+
return ADObj<type&>(A[i], Ab[i]);
88+
}
89+
90+
template <typename I, typename U = T,
91+
std::enable_if_t<
92+
is_a2d_vector<typename remove_const_and_refs<U>::type>::value,
93+
bool> = true>
94+
A2D_FUNCTION ADObj<type&> operator()(const I i) {
95+
return ADObj<type&>(A[i], Ab[i]);
96+
}
97+
98+
template <typename I, typename U = T,
99+
std::enable_if_t<
100+
is_a2d_matrix<typename remove_const_and_refs<U>::type>::value,
101+
bool> = true>
102+
A2D_FUNCTION ADObj<type&> operator()(const I i, const I j) {
103+
return ADObj<type&>(A(i, j), Ab(i, j));
104+
}
105+
106+
template <
107+
typename I, typename U = T,
108+
std::enable_if_t<
109+
is_a2d_sym_matrix<typename remove_const_and_refs<U>::type>::value,
110+
bool> = true>
111+
A2D_FUNCTION ADObj<type&> operator()(const I i, const I j) {
112+
return ADObj<type&>(A(i, j), Ab(i, j));
113+
}
114+
76115
private:
77116
T A; // Object
78117
T Ab; // Reverse mode derivative value
@@ -136,6 +175,13 @@ class A2DObj : public A2DExpr<A2DObj<T>, T> {
136175
Ah = type(0.0);
137176
}
138177
}
178+
template <typename U = T,
179+
std::enable_if_t<std::is_reference<U>::value, bool> = true>
180+
A2D_FUNCTION A2DObj(T& A, T& Ab, T& Ap, T& Ah)
181+
: A(A), Ab(Ab), Ap(Ap), Ah(Ah) {}
182+
183+
template <typename U = T,
184+
std::enable_if_t<!std::is_reference<U>::value, bool> = true>
139185
A2D_FUNCTION A2DObj(const T& A, const T& Ab, const T& Ap, const T& Ah)
140186
: A(A), Ab(Ab), Ap(Ap), Ah(Ah) {}
141187

@@ -168,6 +214,39 @@ class A2DObj : public A2DExpr<A2DObj<T>, T> {
168214
A2D_FUNCTION T& hvalue() { return Ah; }
169215
A2D_FUNCTION const T& hvalue() const { return Ah; }
170216

217+
template <typename I, typename U = T,
218+
std::enable_if_t<
219+
is_a2d_vector<typename remove_const_and_refs<U>::type>::value,
220+
bool> = true>
221+
A2D_FUNCTION A2DObj<type&> operator[](const I i) {
222+
return A2DObj<type&>(A[i], Ab[i], Ap[i], Ah[i]);
223+
}
224+
225+
template <typename I, typename U = T,
226+
std::enable_if_t<
227+
is_a2d_vector<typename remove_const_and_refs<U>::type>::value,
228+
bool> = true>
229+
A2D_FUNCTION A2DObj<type&> operator()(const I i) {
230+
return A2DObj<type&>(A[i], Ab[i], Ap[i], Ah[i]);
231+
}
232+
233+
template <typename I, typename U = T,
234+
std::enable_if_t<
235+
is_a2d_matrix<typename remove_const_and_refs<U>::type>::value,
236+
bool> = true>
237+
A2D_FUNCTION A2DObj<type&> operator()(const I i, const I j) {
238+
return A2DObj<type&>(A(i, j), Ab(i, j), Ap(i, j), Ah(i, j));
239+
}
240+
241+
template <
242+
typename I, typename U = T,
243+
std::enable_if_t<
244+
is_a2d_sym_matrix<typename remove_const_and_refs<U>::type>::value,
245+
bool> = true>
246+
A2D_FUNCTION ADObj<type&> operator()(const I i, const I j) {
247+
return A2DObj<type&>(A(i, j), Ab(i, j), Ap(i, j), Ah(i, j));
248+
}
249+
171250
private:
172251
T A; // Object
173252
T Ab; // Reverse mode derivative value

include/ad/a2dscalarops.h

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,48 @@ class EvalExpr {
4040
};
4141

4242
template <class Expr, class T>
43+
class EvalExprRef {
44+
public:
45+
A2D_FUNCTION EvalExprRef(Expr&& expr, ADObj<T&> out)
46+
: expr(a2d_forward<Expr>(expr)), out(out) {}
47+
48+
A2D_FUNCTION void eval() {
49+
expr.eval();
50+
out.value() = expr.value();
51+
}
52+
A2D_FUNCTION void bzero() {
53+
out.bzero();
54+
expr.bzero();
55+
}
56+
template <ADorder forder>
57+
A2D_FUNCTION void forward() {
58+
static_assert(forder == ADorder::FIRST,
59+
"EvalExprRef only works for first-order AD");
60+
expr.forward();
61+
out.bvalue() = expr.bvalue();
62+
}
63+
A2D_FUNCTION void reverse() {
64+
expr.bvalue() = out.bvalue();
65+
expr.reverse();
66+
}
67+
68+
private:
69+
Expr expr;
70+
ADObj<T&> out;
71+
};
72+
73+
template <class Expr, class T,
74+
std::enable_if_t<!std::is_reference<T>::value, bool> = true>
4375
A2D_FUNCTION auto Eval(Expr&& expr, ADObj<T>& out) {
4476
return EvalExpr<Expr, T>(a2d_forward<Expr>(expr), out);
4577
}
4678

79+
template <class Expr, class T,
80+
std::enable_if_t<!std::is_reference<T>::value, bool> = true>
81+
A2D_FUNCTION auto Eval(Expr&& expr, ADObj<T&> out) {
82+
return EvalExprRef<Expr, T>(a2d_forward<Expr>(expr), out);
83+
}
84+
4785
template <class Expr, class T>
4886
class EvalExpr2 {
4987
public:
@@ -84,10 +122,56 @@ class EvalExpr2 {
84122
};
85123

86124
template <class Expr, class T>
125+
class EvalExprRef2 {
126+
public:
127+
A2D_FUNCTION EvalExprRef2(Expr&& expr, A2DObj<T&> out)
128+
: expr(a2d_forward<Expr>(expr)), out(out) {}
129+
130+
A2D_FUNCTION void eval() {
131+
expr.eval();
132+
out.value() = expr.value();
133+
}
134+
A2D_FUNCTION void bzero() {
135+
out.bzero();
136+
expr.bzero();
137+
}
138+
A2D_FUNCTION void reverse() {
139+
expr.bvalue() += out.bvalue();
140+
expr.reverse();
141+
}
142+
template <ADorder forder>
143+
A2D_FUNCTION void forward() {
144+
static_assert(forder == ADorder::SECOND,
145+
"EvalExprRef2 only works for second-order AD");
146+
expr.hforward();
147+
out.pvalue() = expr.pvalue();
148+
}
149+
A2D_FUNCTION void hzero() {
150+
out.hzero();
151+
expr.hzero();
152+
}
153+
A2D_FUNCTION void hreverse() {
154+
expr.hvalue() += out.hvalue();
155+
expr.hreverse();
156+
}
157+
158+
private:
159+
Expr expr;
160+
A2DObj<T&> out;
161+
};
162+
163+
template <class Expr, class T,
164+
std::enable_if_t<!std::is_reference<T>::value, bool> = true>
87165
A2D_FUNCTION auto Eval(Expr&& expr, A2DObj<T>& out) {
88166
return EvalExpr2<Expr, T>(a2d_forward<Expr>(expr), out);
89167
}
90168

169+
template <class Expr, class T,
170+
std::enable_if_t<!std::is_reference<T>::value, bool> = true>
171+
A2D_FUNCTION auto Eval(Expr&& expr, A2DObj<T&> out) {
172+
return EvalExprRef2<Expr, T>(a2d_forward<Expr>(expr), out);
173+
}
174+
91175
namespace Test {
92176
template <typename T>
93177
class ScalarTest : public A2DTest<T, T, T, T> {

include/ad/a2dvartuple.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,8 +233,23 @@ class VarTuple : public VarTupleBase<T, Vars...> {
233233

234234
private:
235235
VarTupleObj var;
236+
237+
template <int index, typename T1, class... Vars1>
238+
friend auto& get(VarTuple<T1, Vars1...>&);
239+
template <int index, typename T1, class... Vars1>
240+
friend auto& get(const VarTuple<T1, Vars1...>&);
236241
};
237242

243+
template <int index, typename T, class... Vars>
244+
A2D_FUNCTION auto& get(VarTuple<T, Vars...>& t) {
245+
return a2d_get<index>(t.var);
246+
}
247+
248+
template <int index, typename T, class... Vars>
249+
A2D_FUNCTION auto& get(const VarTuple<T, Vars...>& t) {
250+
return a2d_get<index>(t.var);
251+
}
252+
238253
template <typename T, class... Vars>
239254
A2D_FUNCTION auto MakeVarTuple(Vars&... s) {
240255
return VarTuple<T, Vars...>(s...);
@@ -293,8 +308,23 @@ class TieTuple : public VarTupleBase<T, Vars...> {
293308

294309
private:
295310
VarTupleObj var;
311+
312+
template <int index, typename T1, class... Vars1>
313+
friend auto get(TieTuple<T1, Vars1...>&);
314+
template <int index, typename T1, class... Vars1>
315+
friend auto get(const TieTuple<T1, Vars1...>&);
296316
};
297317

318+
template <int index, typename T, class... Vars>
319+
A2D_FUNCTION auto get(TieTuple<T, Vars...>& t) {
320+
return a2d_get<index>(t.var);
321+
}
322+
323+
template <int index, typename T, class... Vars>
324+
A2D_FUNCTION auto get(const TieTuple<T, Vars...>& t) {
325+
return a2d_get<index>(t.var);
326+
}
327+
298328
template <typename T, class... Vars>
299329
A2D_FUNCTION auto MakeTieTuple(Vars&... s) {
300330
return TieTuple<T, Vars...>(s...);

include/ad/a2dvec.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,5 +65,12 @@ class Vec {
6565
T V[N];
6666
};
6767

68+
template <typename T>
69+
struct is_a2d_vector : std::false_type {};
70+
71+
template <typename U, int N>
72+
struct is_a2d_vector<Vec<U, N>> : std::true_type {};
73+
6874
} // namespace A2D
75+
6976
#endif // A2D_VEC_H

0 commit comments

Comments
 (0)