diff --git a/include/ad/a2dmat.h b/include/ad/a2dmat.h index 9c07bb3..be91920 100644 --- a/include/ad/a2dmat.h +++ b/include/ad/a2dmat.h @@ -173,6 +173,18 @@ class SymMat { T A[MAT_SIZE]; }; +template +struct is_a2d_matrix : std::false_type {}; + +template +struct is_a2d_matrix> : std::true_type {}; + +template +struct is_a2d_sym_matrix : std::false_type {}; + +template +struct is_a2d_sym_matrix> : std::true_type {}; + } // namespace A2D #endif // A2D_MAT_H diff --git a/include/ad/a2dobj.h b/include/ad/a2dobj.h index ff48139..28159dd 100644 --- a/include/ad/a2dobj.h +++ b/include/ad/a2dobj.h @@ -53,6 +53,12 @@ class ADObj : public ADExpr, T> { } // Initialize with both values + template ::value, bool> = true> + A2D_FUNCTION ADObj(T& A, T& Ab) : A(A), Ab(Ab) {} + + template ::value, bool> = true> A2D_FUNCTION ADObj(const T& A, const T& Ab) : A(A), Ab(Ab) {} // Evaluation and derivatives @@ -73,6 +79,39 @@ class ADObj : public ADExpr, T> { A2D_FUNCTION T& bvalue() { return Ab; } A2D_FUNCTION const T& bvalue() const { return Ab; } + template ::type>::value, + bool> = true> + A2D_FUNCTION ADObj operator[](const I i) { + return ADObj(A[i], Ab[i]); + } + + template ::type>::value, + bool> = true> + A2D_FUNCTION ADObj operator()(const I i) { + return ADObj(A[i], Ab[i]); + } + + template ::type>::value, + bool> = true> + A2D_FUNCTION ADObj operator()(const I i, const I j) { + return ADObj(A(i, j), Ab(i, j)); + } + + template < + typename I, typename U = T, + std::enable_if_t< + is_a2d_sym_matrix::type>::value, + bool> = true> + A2D_FUNCTION ADObj operator()(const I i, const I j) { + return ADObj(A(i, j), Ab(i, j)); + } + private: T A; // Object T Ab; // Reverse mode derivative value @@ -136,6 +175,13 @@ class A2DObj : public A2DExpr, T> { Ah = type(0.0); } } + template ::value, bool> = true> + A2D_FUNCTION A2DObj(T& A, T& Ab, T& Ap, T& Ah) + : A(A), Ab(Ab), Ap(Ap), Ah(Ah) {} + + template ::value, bool> = true> A2D_FUNCTION A2DObj(const T& A, const T& Ab, const T& Ap, const T& Ah) : A(A), Ab(Ab), Ap(Ap), Ah(Ah) {} @@ -168,6 +214,39 @@ class A2DObj : public A2DExpr, T> { A2D_FUNCTION T& hvalue() { return Ah; } A2D_FUNCTION const T& hvalue() const { return Ah; } + template ::type>::value, + bool> = true> + A2D_FUNCTION A2DObj operator[](const I i) { + return A2DObj(A[i], Ab[i], Ap[i], Ah[i]); + } + + template ::type>::value, + bool> = true> + A2D_FUNCTION A2DObj operator()(const I i) { + return A2DObj(A[i], Ab[i], Ap[i], Ah[i]); + } + + template ::type>::value, + bool> = true> + A2D_FUNCTION A2DObj operator()(const I i, const I j) { + return A2DObj(A(i, j), Ab(i, j), Ap(i, j), Ah(i, j)); + } + + template < + typename I, typename U = T, + std::enable_if_t< + is_a2d_sym_matrix::type>::value, + bool> = true> + A2D_FUNCTION ADObj operator()(const I i, const I j) { + return A2DObj(A(i, j), Ab(i, j), Ap(i, j), Ah(i, j)); + } + private: T A; // Object T Ab; // Reverse mode derivative value diff --git a/include/ad/a2dscalarops.h b/include/ad/a2dscalarops.h index 3271257..4542796 100644 --- a/include/ad/a2dscalarops.h +++ b/include/ad/a2dscalarops.h @@ -40,10 +40,48 @@ class EvalExpr { }; template +class EvalExprRef { + public: + A2D_FUNCTION EvalExprRef(Expr&& expr, ADObj out) + : expr(a2d_forward(expr)), out(out) {} + + A2D_FUNCTION void eval() { + expr.eval(); + out.value() = expr.value(); + } + A2D_FUNCTION void bzero() { + out.bzero(); + expr.bzero(); + } + template + A2D_FUNCTION void forward() { + static_assert(forder == ADorder::FIRST, + "EvalExprRef only works for first-order AD"); + expr.forward(); + out.bvalue() = expr.bvalue(); + } + A2D_FUNCTION void reverse() { + expr.bvalue() = out.bvalue(); + expr.reverse(); + } + + private: + Expr expr; + ADObj out; +}; + +template ::value, bool> = true> A2D_FUNCTION auto Eval(Expr&& expr, ADObj& out) { return EvalExpr(a2d_forward(expr), out); } +template ::value, bool> = true> +A2D_FUNCTION auto Eval(Expr&& expr, ADObj out) { + return EvalExprRef(a2d_forward(expr), out); +} + template class EvalExpr2 { public: @@ -84,10 +122,56 @@ class EvalExpr2 { }; template +class EvalExprRef2 { + public: + A2D_FUNCTION EvalExprRef2(Expr&& expr, A2DObj out) + : expr(a2d_forward(expr)), out(out) {} + + A2D_FUNCTION void eval() { + expr.eval(); + out.value() = expr.value(); + } + A2D_FUNCTION void bzero() { + out.bzero(); + expr.bzero(); + } + A2D_FUNCTION void reverse() { + expr.bvalue() += out.bvalue(); + expr.reverse(); + } + template + A2D_FUNCTION void forward() { + static_assert(forder == ADorder::SECOND, + "EvalExprRef2 only works for second-order AD"); + expr.hforward(); + out.pvalue() = expr.pvalue(); + } + A2D_FUNCTION void hzero() { + out.hzero(); + expr.hzero(); + } + A2D_FUNCTION void hreverse() { + expr.hvalue() += out.hvalue(); + expr.hreverse(); + } + + private: + Expr expr; + A2DObj out; +}; + +template ::value, bool> = true> A2D_FUNCTION auto Eval(Expr&& expr, A2DObj& out) { return EvalExpr2(a2d_forward(expr), out); } +template ::value, bool> = true> +A2D_FUNCTION auto Eval(Expr&& expr, A2DObj out) { + return EvalExprRef2(a2d_forward(expr), out); +} + namespace Test { template class ScalarTest : public A2DTest { diff --git a/include/ad/a2dvartuple.h b/include/ad/a2dvartuple.h index 877cb5c..21d4357 100644 --- a/include/ad/a2dvartuple.h +++ b/include/ad/a2dvartuple.h @@ -233,8 +233,23 @@ class VarTuple : public VarTupleBase { private: VarTupleObj var; + + template + friend auto& get(VarTuple&); + template + friend auto& get(const VarTuple&); }; +template +A2D_FUNCTION auto& get(VarTuple& t) { + return a2d_get(t.var); +} + +template +A2D_FUNCTION auto& get(const VarTuple& t) { + return a2d_get(t.var); +} + template A2D_FUNCTION auto MakeVarTuple(Vars&... s) { return VarTuple(s...); @@ -293,8 +308,23 @@ class TieTuple : public VarTupleBase { private: VarTupleObj var; + + template + friend auto get(TieTuple&); + template + friend auto get(const TieTuple&); }; +template +A2D_FUNCTION auto get(TieTuple& t) { + return a2d_get(t.var); +} + +template +A2D_FUNCTION auto get(const TieTuple& t) { + return a2d_get(t.var); +} + template A2D_FUNCTION auto MakeTieTuple(Vars&... s) { return TieTuple(s...); diff --git a/include/ad/a2dvec.h b/include/ad/a2dvec.h index 81a5ade..f3fed61 100644 --- a/include/ad/a2dvec.h +++ b/include/ad/a2dvec.h @@ -65,5 +65,12 @@ class Vec { T V[N]; }; +template +struct is_a2d_vector : std::false_type {}; + +template +struct is_a2d_vector> : std::true_type {}; + } // namespace A2D + #endif // A2D_VEC_H