Skip to content

Commit 8943593

Browse files
authored
Merge pull request #133 from csmith763/main
Modifications to ADScalar Constructors
2 parents e7f4dcf + 73273ec commit 8943593

File tree

5 files changed

+206
-27
lines changed

5 files changed

+206
-27
lines changed

include/ad/a2dmatinv.h

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@ namespace A2D {
2222
*/
2323

2424
template <typename T, int N>
25-
A2D_FUNCTION void MatInv(const Mat<T, N, N>& A, Mat<T, N, N>& Ainv) {
25+
A2D_FUNCTION void MatInv(const Mat<T, N, N> &A, Mat<T, N, N> &Ainv) {
2626
MatInvCore<T, N>(get_data(A), get_data(Ainv));
2727
}
2828
template <typename T, int N>
29-
A2D_FUNCTION void MatInv(const SymMat<T, N>& S, SymMat<T, N>& Sinv) {
29+
A2D_FUNCTION void MatInv(const SymMat<T, N> &S, SymMat<T, N> &Sinv) {
3030
SymMatInvCore<T, N>(get_data(S), get_data(Sinv));
3131
}
3232

@@ -56,7 +56,7 @@ class MatInvExpr {
5656
static constexpr MatOp NORMAL = MatOp::NORMAL;
5757
static constexpr MatOp TRANSPOSE = MatOp::TRANSPOSE;
5858

59-
A2D_FUNCTION MatInvExpr(Atype& A, Btype& Ainv) : A(A), Ainv(Ainv) {}
59+
A2D_FUNCTION MatInvExpr(Atype &A, Btype &Ainv) : A(A), Ainv(Ainv) {}
6060

6161
A2D_FUNCTION void eval() { MatInvCore<T, N>(get_data(A), get_data(Ainv)); }
6262

@@ -67,8 +67,8 @@ class MatInvExpr {
6767
static_assert(
6868
!(order == ADorder::FIRST and forder == ADorder::SECOND),
6969
"Can't perform second order forward with first order objects");
70-
constexpr ADseed seed = conditional_value < ADseed,
71-
forder == ADorder::FIRST, ADseed::b, ADseed::p > ::value;
70+
constexpr ADseed seed = conditional_value<ADseed, forder == ADorder::FIRST,
71+
ADseed::b, ADseed::p>::value;
7272

7373
T temp[N * N];
7474
MatMatMultCore<T, N, N, N, N, N, N, NORMAL, NORMAL>(
@@ -120,17 +120,17 @@ class MatInvExpr {
120120
T(-1.0), temp, get_data(Ainv), GetSeed<ADseed::h>::get_data(A));
121121
}
122122

123-
Atype& A;
124-
Btype& Ainv;
123+
Atype &A;
124+
Btype &Ainv;
125125
};
126126

127127
template <class Atype, class Btype>
128-
A2D_FUNCTION auto MatInv(ADObj<Atype>& A, ADObj<Btype>& Ainv) {
128+
A2D_FUNCTION auto MatInv(ADObj<Atype> &A, ADObj<Btype> &Ainv) {
129129
return MatInvExpr<ADObj<Atype>, ADObj<Btype>>(A, Ainv);
130130
}
131131

132132
template <class Atype, class Btype>
133-
A2D_FUNCTION auto MatInv(A2DObj<Atype>& A, A2DObj<Btype>& Ainv) {
133+
A2D_FUNCTION auto MatInv(A2DObj<Atype> &A, A2DObj<Btype> &Ainv) {
134134
return MatInvExpr<A2DObj<Atype>, A2DObj<Btype>>(A, Ainv);
135135
}
136136

@@ -150,7 +150,7 @@ class MatInvTest : public A2DTest<T, Mat<T, N, N>, Mat<T, N, N>> {
150150
}
151151

152152
// Evaluate the matrix-matrix product
153-
Output eval(const Input& x) {
153+
Output eval(const Input &x) {
154154
Mat<T, N, N> A;
155155
Mat<T, N, N> B;
156156
x.get_values(A);
@@ -159,7 +159,7 @@ class MatInvTest : public A2DTest<T, Mat<T, N, N>, Mat<T, N, N>> {
159159
}
160160

161161
// Compute the derivative
162-
void deriv(const Output& seed, const Input& x, Input& g) {
162+
void deriv(const Output &seed, const Input &x, Input &g) {
163163
ADObj<Mat<T, N, N>> A;
164164
ADObj<Mat<T, N, N>> B;
165165

@@ -171,8 +171,8 @@ class MatInvTest : public A2DTest<T, Mat<T, N, N>, Mat<T, N, N>> {
171171
}
172172

173173
// Compute the second-derivative
174-
void hprod(const Output& seed, const Output& hval, const Input& x,
175-
const Input& p, Input& h) {
174+
void hprod(const Output &seed, const Output &hval, const Input &x,
175+
const Input &p, Input &h) {
176176
A2DObj<Mat<T, N, N>> A;
177177
A2DObj<Mat<T, N, N>> B;
178178

include/ad/a2dobj.h

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -863,39 +863,33 @@ struct __get_object_numeric_type<ADScalar<A2D_complex_t<double>, N>> {
863863
using type = ADScalar<A2D_complex_t<double>, N>;
864864
};
865865

866-
template <typename T, int N,
867-
std::enable_if_t<is_numeric_type<T>::value, bool> = true>
866+
template <typename T, int N>
868867
ADScalar<T, N>& get_data(ADScalar<T, N>& value) {
869868
return value;
870869
}
871870

872-
template <typename T, int N,
873-
std::enable_if_t<is_numeric_type<T>::value, bool> = true>
871+
template <typename T, int N>
874872
const ADScalar<T, N>& get_data(const ADScalar<T, N>& value) {
875873
return value;
876874
}
877875

878-
template <typename T, int N,
879-
std::enable_if_t<is_numeric_type<T>::value, bool> = true>
876+
template <typename T, int N>
880877
A2D_FUNCTION ADScalar<T, N>& get_data(ADObj<ADScalar<T, N>>& value) {
881878
return value.value();
882879
}
883880

884-
template <typename T, int N,
885-
std::enable_if_t<is_numeric_type<T>::value, bool> = true>
881+
template <typename T, int N>
886882
A2D_FUNCTION const ADScalar<T, N>& get_data(
887883
const ADObj<ADScalar<T, N>>& value) {
888884
return value.value();
889885
}
890886

891-
template <typename T, int N,
892-
std::enable_if_t<is_numeric_type<T>::value, bool> = true>
887+
template <typename T, int N>
893888
A2D_FUNCTION ADScalar<T, N>& get_data(A2DObj<ADScalar<T, N>>& value) {
894889
return value.value();
895890
}
896891

897-
template <typename T, int N,
898-
std::enable_if_t<is_numeric_type<T>::value, bool> = true>
892+
template <typename T, int N>
899893
A2D_FUNCTION const ADScalar<T, N>& get_data(
900894
const A2DObj<ADScalar<T, N>>& value) {
901895
return value.value();

include/adscalar.h

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,29 +8,69 @@
88

99
namespace A2D {
1010

11+
// Detections for scalar types
12+
template <class>
13+
struct is_adscalar : std::false_type {};
14+
template <class U, int M>
15+
struct is_adscalar<ADScalar<U, M>> : std::true_type {};
16+
template <class X>
17+
inline constexpr bool is_adscalar_v = is_adscalar<X>::value;
18+
1119
template <class T, int N>
1220
class ADScalar {
1321
public:
1422
using type = T;
1523

1624
A2D_FUNCTION ADScalar() {}
1725

26+
// Value constructor (sets a value, zeros derivatives)
1827
template <typename R, typename = std::enable_if_t<is_scalar_type<R>::value>>
1928
A2D_FUNCTION ADScalar(const R value) : value(value), deriv{0.0} {}
2029

21-
template <typename R, typename = std::enable_if_t<is_scalar_type<R>::value>>
22-
A2D_FUNCTION ADScalar(const R value, const T d[]) : value(value) {
30+
// Value and derivative constructor (sets both, and works regardless of the
31+
// type T)
32+
A2D_FUNCTION ADScalar(const T &value, const T d[]) : value(value) {
2333
for (int i = 0; i < N; i++) {
2434
deriv[i] = d[i];
2535
}
2636
}
2737

38+
// Scalar-only version for ADScalar constructor for a common occasion
39+
template <typename R, typename = std::enable_if_t<is_scalar_type<R>::value>>
40+
ADScalar(const R val, const T d[]) : ADScalar(T(val), d) {}
41+
42+
// Copy constructor (sets value and derivatives from another ADScalar)
2843
A2D_FUNCTION ADScalar(const ADScalar<T, N> &r) : value(r.value) {
2944
for (int i = 0; i < N; i++) {
3045
deriv[i] = r.deriv[i];
3146
}
3247
}
3348

49+
// Conversion constructor
50+
// - disabled when R == T, which forces copy constructor for same type copies
51+
// - enabled when conversion or lifting makes sense
52+
template <typename R, std::enable_if_t<!std::is_same_v<R, T> &&
53+
(std::is_convertible_v<R, T> ||
54+
is_adscalar_v<T>),
55+
int> = 0>
56+
explicit ADScalar(const ADScalar<R, N> &r) {
57+
if constexpr (is_adscalar_v<T>) {
58+
// Lifting: occurs if T is an ADScalar<...> type
59+
// copy the entire ADScalar r into 'value' for the current ADScalar
60+
value = r; // invokes inner ADScalar's conversion/copy
61+
for (int i = 0; i < N; ++i) {
62+
deriv[i] = T(0.0); // outer derivatives zeroed
63+
}
64+
} else {
65+
// Componentwise: T is a plain scalar-like type (e.g. double)
66+
value = static_cast<T>(r.value);
67+
for (int i = 0; i < N; ++i) {
68+
deriv[i] = static_cast<T>(r.deriv[i]);
69+
}
70+
}
71+
}
72+
73+
// Assignment operator
3474
template <typename R, typename = std::enable_if_t<is_scalar_type<R>::value>>
3575
A2D_FUNCTION inline ADScalar<T, N> &operator=(const R &r) {
3676
value = r;

tests/ad/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ add_executable(test_ad_expressions test_ad_expressions.cpp)
33
add_executable(test_a2dmat test_a2dmat.cpp)
44
add_executable(test_a2dmatinv test_a2dmatinv.cpp)
55
add_executable(test_a2dmatdet test_a2dmatdet.cpp)
6+
add_executable(test_adscalar test_adscalar.cpp)
67

78
target_compile_options(test_ad_expressions PRIVATE -fsanitize=address)
89
target_link_options(test_ad_expressions PRIVATE -fsanitize=address)
@@ -16,6 +17,8 @@ target_include_directories(test_a2dmatinv PRIVATE
1617
${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/tests)
1718
target_include_directories(test_a2dmatdet PRIVATE
1819
${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/tests)
20+
target_include_directories(test_adscalar PRIVATE
21+
${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/tests)
1922

2023
# For tests implmented using gtest, link them to gtest
2124
target_link_libraries(test_a2dmat PRIVATE gtest_main)
@@ -29,5 +32,6 @@ gtest_discover_tests(test_a2dmatdet)
2932

3033
# Add non-gtest tests manually so that ctest could recognize it's a test
3134
add_test(NAME test_ad_expressions COMMAND test_ad_expressions)
35+
add_test(NAME test_adscalar COMMAND test_adscalar)
3236

3337
add_subdirectory(core)

tests/ad/test_adscalar.cpp

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
#include <functional>
2+
#include <iostream>
3+
#include <vector>
4+
5+
#include "a2dcore.h"
6+
7+
using namespace A2D;
8+
9+
template <typename T>
10+
class ADScalarTest : public A2D::Test::A2DTest<T, T, T> {
11+
public:
12+
using Input = VarTuple<T, T>;
13+
using Output = VarTuple<T, T>;
14+
15+
std::string name() override { return "ADScalarTest"; }
16+
17+
// Evaluate f(x) = x^2 using ADScalar
18+
Output eval(const Input &x) override {
19+
T val;
20+
x.get_values(val);
21+
ADScalar<T, 1> ad_x(val);
22+
ADScalar<T, 1> ad_y = ad_x * ad_x;
23+
return MakeVarTuple<T>(ad_y.value);
24+
}
25+
26+
// Compute the derivative: df/dx = 2x using ADScalar
27+
void deriv(const Output &seed, const Input &x, Input &g) override {
28+
T val;
29+
x.get_values(val);
30+
T dy;
31+
seed.get_values(dy);
32+
ADScalar<T, 1> ad_x(val);
33+
ad_x.deriv[0] = dy;
34+
ADScalar<T, 1> ad_y = ad_x * ad_x;
35+
g.set_values(ad_y.deriv[0]);
36+
}
37+
38+
// Compute the second derivative: d^2f/dx^2 = 2 using ADScalar
39+
void hprod(const Output &seed, const Output &hval, const Input &x,
40+
const Input &p, Input &h) override {
41+
T val, pval, dy, ddy;
42+
x.get_values(val);
43+
p.get_values(pval);
44+
seed.get_values(dy);
45+
hval.get_values(ddy);
46+
ADScalar<T, 1> ad_x(val);
47+
ad_x.deriv[0] = pval;
48+
ADScalar<T, 1> ad_y = ad_x * ad_x;
49+
// Second derivative: d^2f/dx^2 * p = 2 * pval
50+
h.set_values(2.0 * pval * dy + 2.0 * ddy * val);
51+
}
52+
};
53+
54+
template <typename T>
55+
class ADScalarNestedTest : public A2D::Test::A2DTest<T, T, T> {
56+
public:
57+
using Input = VarTuple<T, T>;
58+
using Output = VarTuple<T, T>;
59+
60+
std::string name() override { return "ADScalarNestedTest"; }
61+
62+
// Evaluate f(x) = x^2 using nested ADScalar
63+
Output eval(const Input &x) override {
64+
T val;
65+
x.get_values(val);
66+
ADScalar<T, 1> ad_x(val);
67+
ADScalar<ADScalar<T, 1>, 1> nested_x(ad_x);
68+
ADScalar<ADScalar<T, 1>, 1> nested_y = nested_x * nested_x;
69+
return MakeVarTuple<T>(nested_y.value.value); // Unwrap both layers
70+
}
71+
72+
// Compute the derivative: df/dx = 2x using nested ADScalar
73+
void deriv(const Output &seed, const Input &x, Input &g) override {
74+
T val;
75+
x.get_values(val);
76+
T dy;
77+
seed.get_values(dy);
78+
79+
ADScalar<T, 1> ad_x(val);
80+
ad_x.deriv[0] = dy;
81+
82+
ADScalar<ADScalar<T, 1>, 1> nested_x(ad_x);
83+
84+
ADScalar<ADScalar<T, 1>, 1> nested_y = nested_x * nested_x;
85+
86+
g.set_values(
87+
nested_y.value.deriv[0]); // Unwrap derivative from inner ADScalar
88+
}
89+
90+
// Compute the second derivative: d^2f/dx^2 = 2 using nested ADScalar
91+
void hprod(const Output &seed, const Output &hval, const Input &x,
92+
const Input &p, Input &h) override {
93+
T val, pval, dy, ddy;
94+
x.get_values(val);
95+
p.get_values(pval);
96+
seed.get_values(dy);
97+
hval.get_values(ddy);
98+
ADScalar<T, 1> ad_x(val);
99+
ad_x.deriv[0] = pval;
100+
ADScalar<ADScalar<T, 1>, 1> nested_x(ad_x);
101+
nested_x.deriv[0] = ADScalar<T, 1>(0.0);
102+
ADScalar<ADScalar<T, 1>, 1> nested_y = nested_x * nested_x;
103+
// Second derivative: d^2f/dx^2 * p = 2 * pval
104+
h.set_values(2.0 * pval * dy + 2.0 * ddy * val);
105+
}
106+
};
107+
bool ADScalarTestAll(bool component, bool write_output) {
108+
using Tc = A2D_complex_t<double>;
109+
110+
bool passed = true;
111+
ADScalarTest<Tc> test;
112+
test.set_step_size(1e-30);
113+
114+
ADScalarNestedTest<Tc> nested_test;
115+
nested_test.set_step_size(1e-30);
116+
117+
passed = passed && Run(test, component, write_output);
118+
passed = passed && Run(nested_test, component, write_output);
119+
120+
return passed;
121+
}
122+
123+
int main(int argc, char *argv[]) {
124+
bool component = false; // Default to a projection test
125+
bool write_output = false; // Don't write output;
126+
127+
// Check for the write_output flag
128+
for (int i = 0; i < argc; i++) {
129+
std::string str(argv[i]);
130+
if (str.compare("--write_output") == 0) {
131+
write_output = true;
132+
}
133+
if (str.compare("--component") == 0) {
134+
component = true;
135+
}
136+
}
137+
138+
bool passed = ADScalarTestAll(component, write_output);
139+
140+
return passed ? 0 : 1;
141+
}

0 commit comments

Comments
 (0)