Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions include/ad/a2dmat.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,18 @@ class SymMat {
T A[MAT_SIZE];
};

template <typename T>
struct is_a2d_matrix : std::false_type {};

template <typename U, int N, int M>
struct is_a2d_matrix<Mat<U, N, M>> : std::true_type {};

template <typename T>
struct is_a2d_sym_matrix : std::false_type {};

template <typename U, int N>
struct is_a2d_sym_matrix<SymMat<U, N>> : std::true_type {};

} // namespace A2D

#endif // A2D_MAT_H
79 changes: 79 additions & 0 deletions include/ad/a2dobj.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ class ADObj : public ADExpr<ADObj<T>, T> {
}

// Initialize with both values
template <typename U = T,
std::enable_if_t<std::is_reference<U>::value, bool> = true>
A2D_FUNCTION ADObj(T& A, T& Ab) : A(A), Ab(Ab) {}

template <typename U = T,
std::enable_if_t<!std::is_reference<U>::value, bool> = true>
A2D_FUNCTION ADObj(const T& A, const T& Ab) : A(A), Ab(Ab) {}

// Evaluation and derivatives
Expand All @@ -73,6 +79,39 @@ class ADObj : public ADExpr<ADObj<T>, T> {
A2D_FUNCTION T& bvalue() { return Ab; }
A2D_FUNCTION const T& bvalue() const { return Ab; }

template <typename I, typename U = T,
std::enable_if_t<
is_a2d_vector<typename remove_const_and_refs<U>::type>::value,
bool> = true>
A2D_FUNCTION ADObj<type&> operator[](const I i) {
return ADObj<type&>(A[i], Ab[i]);
}

template <typename I, typename U = T,
std::enable_if_t<
is_a2d_vector<typename remove_const_and_refs<U>::type>::value,
bool> = true>
A2D_FUNCTION ADObj<type&> operator()(const I i) {
return ADObj<type&>(A[i], Ab[i]);
}

template <typename I, typename U = T,
std::enable_if_t<
is_a2d_matrix<typename remove_const_and_refs<U>::type>::value,
bool> = true>
A2D_FUNCTION ADObj<type&> operator()(const I i, const I j) {
return ADObj<type&>(A(i, j), Ab(i, j));
}

template <
typename I, typename U = T,
std::enable_if_t<
is_a2d_sym_matrix<typename remove_const_and_refs<U>::type>::value,
bool> = true>
A2D_FUNCTION ADObj<type&> operator()(const I i, const I j) {
return ADObj<type&>(A(i, j), Ab(i, j));
}

private:
T A; // Object
T Ab; // Reverse mode derivative value
Expand Down Expand Up @@ -136,6 +175,13 @@ class A2DObj : public A2DExpr<A2DObj<T>, T> {
Ah = type(0.0);
}
}
template <typename U = T,
std::enable_if_t<std::is_reference<U>::value, bool> = true>
A2D_FUNCTION A2DObj(T& A, T& Ab, T& Ap, T& Ah)
: A(A), Ab(Ab), Ap(Ap), Ah(Ah) {}

template <typename U = T,
std::enable_if_t<!std::is_reference<U>::value, bool> = true>
A2D_FUNCTION A2DObj(const T& A, const T& Ab, const T& Ap, const T& Ah)
: A(A), Ab(Ab), Ap(Ap), Ah(Ah) {}

Expand Down Expand Up @@ -168,6 +214,39 @@ class A2DObj : public A2DExpr<A2DObj<T>, T> {
A2D_FUNCTION T& hvalue() { return Ah; }
A2D_FUNCTION const T& hvalue() const { return Ah; }

template <typename I, typename U = T,
std::enable_if_t<
is_a2d_vector<typename remove_const_and_refs<U>::type>::value,
bool> = true>
A2D_FUNCTION A2DObj<type&> operator[](const I i) {
return A2DObj<type&>(A[i], Ab[i], Ap[i], Ah[i]);
}

template <typename I, typename U = T,
std::enable_if_t<
is_a2d_vector<typename remove_const_and_refs<U>::type>::value,
bool> = true>
A2D_FUNCTION A2DObj<type&> operator()(const I i) {
return A2DObj<type&>(A[i], Ab[i], Ap[i], Ah[i]);
}

template <typename I, typename U = T,
std::enable_if_t<
is_a2d_matrix<typename remove_const_and_refs<U>::type>::value,
bool> = true>
A2D_FUNCTION A2DObj<type&> operator()(const I i, const I j) {
return A2DObj<type&>(A(i, j), Ab(i, j), Ap(i, j), Ah(i, j));
}

template <
typename I, typename U = T,
std::enable_if_t<
is_a2d_sym_matrix<typename remove_const_and_refs<U>::type>::value,
bool> = true>
A2D_FUNCTION ADObj<type&> operator()(const I i, const I j) {
return A2DObj<type&>(A(i, j), Ab(i, j), Ap(i, j), Ah(i, j));
}

private:
T A; // Object
T Ab; // Reverse mode derivative value
Expand Down
84 changes: 84 additions & 0 deletions include/ad/a2dscalarops.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,48 @@ class EvalExpr {
};

template <class Expr, class T>
class EvalExprRef {
public:
A2D_FUNCTION EvalExprRef(Expr&& expr, ADObj<T&> out)
: expr(a2d_forward<Expr>(expr)), out(out) {}

A2D_FUNCTION void eval() {
expr.eval();
out.value() = expr.value();
}
A2D_FUNCTION void bzero() {
out.bzero();
expr.bzero();
}
template <ADorder forder>
A2D_FUNCTION void forward() {
static_assert(forder == ADorder::FIRST,
"EvalExprRef only works for first-order AD");
expr.forward();
out.bvalue() = expr.bvalue();
}
A2D_FUNCTION void reverse() {
expr.bvalue() = out.bvalue();
expr.reverse();
}

private:
Expr expr;
ADObj<T&> out;
};

template <class Expr, class T,
std::enable_if_t<!std::is_reference<T>::value, bool> = true>
A2D_FUNCTION auto Eval(Expr&& expr, ADObj<T>& out) {
return EvalExpr<Expr, T>(a2d_forward<Expr>(expr), out);
}

template <class Expr, class T,
std::enable_if_t<!std::is_reference<T>::value, bool> = true>
A2D_FUNCTION auto Eval(Expr&& expr, ADObj<T&> out) {
return EvalExprRef<Expr, T>(a2d_forward<Expr>(expr), out);
}

template <class Expr, class T>
class EvalExpr2 {
public:
Expand Down Expand Up @@ -84,10 +122,56 @@ class EvalExpr2 {
};

template <class Expr, class T>
class EvalExprRef2 {
public:
A2D_FUNCTION EvalExprRef2(Expr&& expr, A2DObj<T&> out)
: expr(a2d_forward<Expr>(expr)), out(out) {}

A2D_FUNCTION void eval() {
expr.eval();
out.value() = expr.value();
}
A2D_FUNCTION void bzero() {
out.bzero();
expr.bzero();
}
A2D_FUNCTION void reverse() {
expr.bvalue() += out.bvalue();
expr.reverse();
}
template <ADorder forder>
A2D_FUNCTION void forward() {
static_assert(forder == ADorder::SECOND,
"EvalExprRef2 only works for second-order AD");
expr.hforward();
out.pvalue() = expr.pvalue();
}
A2D_FUNCTION void hzero() {
out.hzero();
expr.hzero();
}
A2D_FUNCTION void hreverse() {
expr.hvalue() += out.hvalue();
expr.hreverse();
}

private:
Expr expr;
A2DObj<T&> out;
};

template <class Expr, class T,
std::enable_if_t<!std::is_reference<T>::value, bool> = true>
A2D_FUNCTION auto Eval(Expr&& expr, A2DObj<T>& out) {
return EvalExpr2<Expr, T>(a2d_forward<Expr>(expr), out);
}

template <class Expr, class T,
std::enable_if_t<!std::is_reference<T>::value, bool> = true>
A2D_FUNCTION auto Eval(Expr&& expr, A2DObj<T&> out) {
return EvalExprRef2<Expr, T>(a2d_forward<Expr>(expr), out);
}

namespace Test {
template <typename T>
class ScalarTest : public A2DTest<T, T, T, T> {
Expand Down
30 changes: 30 additions & 0 deletions include/ad/a2dvartuple.h
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,23 @@ class VarTuple : public VarTupleBase<T, Vars...> {

private:
VarTupleObj var;

template <int index, typename T1, class... Vars1>
friend auto& get(VarTuple<T1, Vars1...>&);
template <int index, typename T1, class... Vars1>
friend auto& get(const VarTuple<T1, Vars1...>&);
};

template <int index, typename T, class... Vars>
A2D_FUNCTION auto& get(VarTuple<T, Vars...>& t) {
return a2d_get<index>(t.var);
}

template <int index, typename T, class... Vars>
A2D_FUNCTION auto& get(const VarTuple<T, Vars...>& t) {
return a2d_get<index>(t.var);
}

template <typename T, class... Vars>
A2D_FUNCTION auto MakeVarTuple(Vars&... s) {
return VarTuple<T, Vars...>(s...);
Expand Down Expand Up @@ -293,8 +308,23 @@ class TieTuple : public VarTupleBase<T, Vars...> {

private:
VarTupleObj var;

template <int index, typename T1, class... Vars1>
friend auto get(TieTuple<T1, Vars1...>&);
template <int index, typename T1, class... Vars1>
friend auto get(const TieTuple<T1, Vars1...>&);
};

template <int index, typename T, class... Vars>
A2D_FUNCTION auto get(TieTuple<T, Vars...>& t) {
return a2d_get<index>(t.var);
}

template <int index, typename T, class... Vars>
A2D_FUNCTION auto get(const TieTuple<T, Vars...>& t) {
return a2d_get<index>(t.var);
}

template <typename T, class... Vars>
A2D_FUNCTION auto MakeTieTuple(Vars&... s) {
return TieTuple<T, Vars...>(s...);
Expand Down
7 changes: 7 additions & 0 deletions include/ad/a2dvec.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,5 +65,12 @@ class Vec {
T V[N];
};

template <typename T>
struct is_a2d_vector : std::false_type {};

template <typename U, int N>
struct is_a2d_vector<Vec<U, N>> : std::true_type {};

} // namespace A2D

#endif // A2D_VEC_H