Skip to content

Commit fcf092f

Browse files
committed
test: Generalize assertion helper for N dimensions
Signed-off-by: Sietze van Buuren <[email protected]>
1 parent e676128 commit fcf092f

File tree

3 files changed

+293
-37
lines changed

3 files changed

+293
-37
lines changed

include/linear_interp.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,16 @@ class LinearInterp3D : public LinearInterpND<T, 3> {
280280
~LinearInterp3D() { }
281281
};
282282

283+
template <typename T>
284+
class LinearInterp4D : public LinearInterpND<T, 4> {
285+
using Vector = std::vector<T>;
286+
using Vector4 = cip::VectorN<T, 4>;
287+
public:
288+
explicit LinearInterp4D(const Vector &x, const Vector &y, const Vector &z, const Vector &w, const Vector4 &f)
289+
: LinearInterpND<T, 4>(f, x, y, z, w)
290+
{}
283291

292+
~LinearInterp4D() { }
293+
};
284294

285295
} // namespace cip

tests/assertion_helpers.hpp

Lines changed: 121 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,147 @@
11
#include <vector>
22
#include <gtest/gtest.h>
3-
3+
#include <array>
4+
#include <tuple>
5+
#include <type_traits>
6+
#include <sstream>
7+
#include <functional>
8+
#include <utility>
49

510
namespace cip {
611

7-
812
using Vector = std::vector<double>;
913
using Vector2 = std::vector<Vector>;
1014
using Vector3 = std::vector<Vector2>;
15+
using Vector4 = std::vector<Vector3>;
1116

1217
constexpr double TOLERANCE_DEFAULT = 5.0e-12;
1318

14-
template <typename T>
15-
testing::AssertionResult Interp1DAssertions(Vector x, Vector f, Vector x_fine, Vector f_fine, double tol=TOLERANCE_DEFAULT) {
16-
T interp(x, f);
17-
for ( auto i = 0; i < x_fine.size(); i++ ) {
18-
auto val = interp.eval(x_fine[i]);
19-
if (!testing::internal::DoubleNearPredFormat("expected", "actual", "tolerance", val, f_fine[i], tol) ) {
20-
return testing::AssertionFailure()
21-
<< "for x = " << x_fine[i] << " expected " << f_fine[i] << " but got " << val;
19+
// Helper class for N-dimensional interpolation assertions
20+
template <typename T, size_t N, typename FuncValT>
21+
class InterpNDAssertions {
22+
private:
23+
// Helper to access nested vector value at specific indices
24+
template <typename VecT>
25+
static double get_value(const VecT& val) {
26+
return val;
27+
}
28+
29+
template <typename VecT, typename... Indices>
30+
static double get_value(const VecT& vec, size_t idx, Indices... rest) {
31+
return get_value(vec[idx], rest...);
32+
}
33+
34+
// Helper for dimension name (x, y, z, w, etc.)
35+
static constexpr char dim_name(size_t i) {
36+
return i < 3 ? 'x' + i : 'w' + (i - 3);
37+
}
38+
39+
// Recursive template to build nested loops over N dimensions
40+
template <size_t Dim>
41+
static testing::AssertionResult check_points(
42+
const T& interp,
43+
const std::array<Vector, N>& fine_grid_vectors,
44+
const FuncValT& f_fine,
45+
std::array<size_t, N>& indices,
46+
std::array<double, N>& coords,
47+
double tol) {
48+
49+
// Base case: we've set up all N dimension indices
50+
if constexpr (Dim == N) {
51+
// Call eval with the coordinates using apply
52+
double val = std::apply([&interp](auto... args) {
53+
return interp.eval(args...);
54+
}, coords);
55+
56+
// Get expected value from nested vector using indices
57+
double expected = std::apply([&f_fine](auto... args) {
58+
return get_value(f_fine, args...);
59+
}, indices);
60+
61+
// Check if the values match within tolerance
62+
if (!testing::internal::DoubleNearPredFormat("expected", "actual", "tolerance", val, expected, tol)) {
63+
testing::AssertionResult failure = testing::AssertionFailure();
64+
65+
// Format dimensions for error message (x=1.23, y=4.56, etc.)
66+
for (size_t i = 0; i < N; ++i) {
67+
failure << (i == 0 ? "for " : ", ") << dim_name(i) << " = " << coords[i];
68+
}
69+
70+
failure << " expected " << expected << " but got " << val;
71+
return failure;
72+
}
73+
74+
return testing::AssertionSuccess();
75+
}
76+
// Recursive case: loop over the current dimension
77+
else {
78+
for (size_t i = 0; i < fine_grid_vectors[Dim].size(); ++i) {
79+
indices[Dim] = i;
80+
coords[Dim] = fine_grid_vectors[Dim][i];
81+
82+
// Recursively process the next dimension
83+
auto result = check_points<Dim + 1>(interp, fine_grid_vectors, f_fine, indices, coords, tol);
84+
if (!result) {
85+
return result;
86+
}
87+
}
88+
return testing::AssertionSuccess();
2289
}
2390
}
24-
return testing::AssertionSuccess();
25-
}
2691

92+
public:
93+
// Main assertion method
94+
static testing::AssertionResult test(
95+
const std::array<Vector, N>& grid_vectors,
96+
const FuncValT& f,
97+
const std::array<Vector, N>& fine_grid_vectors,
98+
const FuncValT& f_fine,
99+
double tol = TOLERANCE_DEFAULT) {
100+
101+
// Create the interpolator with a simple lambda + std::apply
102+
T interp = std::apply([&f](const auto&... grid_args) {
103+
return T(grid_args..., f);
104+
}, grid_vectors);
105+
106+
// Set up indices and coordinates arrays
107+
std::array<size_t, N> indices{};
108+
std::array<double, N> coords{};
109+
110+
// Start the recursive dimension traversal
111+
return check_points<0>(interp, fine_grid_vectors, f_fine, indices, coords, tol);
112+
}
113+
};
27114

115+
// Convenience wrapper functions to maintain backward compatibility
116+
template <typename T>
117+
testing::AssertionResult Interp1DAssertions(Vector x, Vector f, Vector x_fine, Vector f_fine, double tol=TOLERANCE_DEFAULT) {
118+
std::array<Vector, 1> grid_vectors = {x};
119+
std::array<Vector, 1> fine_grid_vectors = {x_fine};
120+
return InterpNDAssertions<T, 1, Vector>::test(grid_vectors, f, fine_grid_vectors, f_fine, tol);
121+
}
28122

29123
template <typename T>
30124
testing::AssertionResult Interp2DAssertions(const Vector &x, const Vector &y, const Vector2 &f, const Vector &x_fine, const Vector &y_fine, const Vector2 &f_fine, double tol=TOLERANCE_DEFAULT) {
31-
T interp2(x, y, f);
32-
for ( auto i = 0; i < x_fine.size(); ++i ) {
33-
for ( auto j = 0; j < y_fine.size(); ++j ) {
34-
auto val = interp2.eval(x_fine[i], y_fine[j]);
35-
if (!testing::internal::DoubleNearPredFormat("expected", "actual", "tolerance", val, f_fine[i][j], tol) ) {
36-
return testing::AssertionFailure()
37-
<< "for x = " << x_fine[i] << ", y = " << y_fine[j] << " expected " << f_fine[i][j] << " but got " << val;
38-
}
39-
}
40-
}
41-
return testing::AssertionSuccess();
125+
std::array<Vector, 2> grid_vectors = {x, y};
126+
std::array<Vector, 2> fine_grid_vectors = {x_fine, y_fine};
127+
return InterpNDAssertions<T, 2, Vector2>::test(grid_vectors, f, fine_grid_vectors, f_fine, tol);
42128
}
43129

44-
45130
template <typename T>
46131
testing::AssertionResult Interp3DAssertions(const Vector &x, const Vector &y, const Vector &z, const Vector3 &f, const Vector &x_fine, const Vector &y_fine, const Vector &z_fine, const Vector3 &f_fine, double tol=TOLERANCE_DEFAULT) {
47-
T interp3(x, y, z, f);
48-
for ( auto i = 0; i < x_fine.size(); ++i ) {
49-
for ( auto j = 0; j < y_fine.size(); ++j ) {
50-
for ( auto k = 0; k < y_fine.size(); ++k ) {
51-
auto val = interp3.eval(x_fine[i], y_fine[j], z_fine[k]);
52-
if (!testing::internal::DoubleNearPredFormat("expected", "actual", "tolerance", val, f_fine[i][j][k], tol) ) {
53-
return testing::AssertionFailure()
54-
<< "for x = " << x_fine[i] << ", y = " << y_fine[j] << ", z = " << z_fine[j] << " expected " << f_fine[i][j][k] << " but got " << val;
55-
}
56-
}
57-
}
58-
}
59-
return testing::AssertionSuccess();
132+
std::array<Vector, 3> grid_vectors = {x, y, z};
133+
std::array<Vector, 3> fine_grid_vectors = {x_fine, y_fine, z_fine};
134+
return InterpNDAssertions<T, 3, Vector3>::test(grid_vectors, f, fine_grid_vectors, f_fine, tol);
60135
}
61136

137+
template <typename T>
138+
testing::AssertionResult Interp4DAssertions(const Vector &x, const Vector &y, const Vector &z, const Vector &w,
139+
const Vector4 &f, const Vector &x_fine, const Vector &y_fine,
140+
const Vector &z_fine, const Vector &w_fine, const Vector4 &f_fine,
141+
double tol=TOLERANCE_DEFAULT) {
142+
std::array<Vector, 4> grid_vectors = {x, y, z, w};
143+
std::array<Vector, 4> fine_grid_vectors = {x_fine, y_fine, z_fine, w_fine};
144+
return InterpNDAssertions<T, 4, Vector4>::test(grid_vectors, f, fine_grid_vectors, f_fine, tol);
145+
}
62146

63147
} // namespace cip

tests/test_linear_interp.cpp

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using VectorN1 = cip::VectorN<double, 1>;
88
using VectorN2 = cip::VectorN<double, 2>;
99
using VectorN3 = cip::VectorN<double, 3>;
10+
using VectorN4 = cip::VectorN<double, 4>;
1011
using Span = std::span<const double>;
1112
using Pr = std::pair<size_t, size_t>;
1213

@@ -197,3 +198,164 @@ TEST(TestInterp3D, test_linear_interp_3d) {
197198
ASSERT_TRUE(cip::Interp3DAssertions<cip::LinearInterp3D<double>>(x, y, z, f, x_fine, y_fine, z_fine, f_fine));
198199

199200
}
201+
202+
/**
203+
* Test for 4D linear interpolation
204+
*
205+
* This test demonstrates using the InterpNDAssertions class with N=4 dimensions.
206+
* We use a simple 4D linear function f(x,y,z,w) = a*x + b*y + c*z + d*w + e
207+
* to generate the test data, which should be perfectly interpolated by a linear interpolator.
208+
*/
209+
TEST(TestInterp4D, test_linear_interp_4d) {
210+
// Create grid vectors for each dimension
211+
cip::Vector x = { 0.0, 1.0, 2.0 };
212+
cip::Vector y = { 0.0, 1.0, 2.0 };
213+
cip::Vector z = { 0.0, 1.0, 2.0 };
214+
cip::Vector w = { 0.0, 1.0, 2.0 };
215+
216+
// Define a linear function coefficients: f(x,y,z,w) = 2*x + 3*y - z + 0.5*w + 1
217+
const double a = 2.0;
218+
const double b = 3.0;
219+
const double c = -1.0;
220+
const double d = 0.5;
221+
const double e = 1.0;
222+
223+
// Create 4D function values array
224+
cip::Vector4 f;
225+
f.resize(x.size());
226+
for (size_t i = 0; i < x.size(); ++i) {
227+
f[i].resize(y.size());
228+
for (size_t j = 0; j < y.size(); ++j) {
229+
f[i][j].resize(z.size());
230+
for (size_t k = 0; k < z.size(); ++k) {
231+
f[i][j][k].resize(w.size());
232+
for (size_t l = 0; l < w.size(); ++l) {
233+
// Function: f(x,y,z,w) = 2*x + 3*y - z + 0.5*w + 1
234+
f[i][j][k][l] = a*x[i] + b*y[j] + c*z[k] + d*w[l] + e;
235+
}
236+
}
237+
}
238+
}
239+
240+
// Fine grid for testing interpolation
241+
cip::Vector x_fine = { 0.0, 0.5, 1.0, 1.5, 2.0 };
242+
cip::Vector y_fine = { 0.0, 0.5, 1.0, 1.5, 2.0 };
243+
cip::Vector z_fine = { 0.0, 0.5, 1.0, 1.5, 2.0 };
244+
cip::Vector w_fine = { 0.0, 0.5, 1.0, 1.5, 2.0 };
245+
246+
// Expected values at fine grid points
247+
cip::Vector4 f_fine;
248+
f_fine.resize(x_fine.size());
249+
for (size_t i = 0; i < x_fine.size(); ++i) {
250+
f_fine[i].resize(y_fine.size());
251+
for (size_t j = 0; j < y_fine.size(); ++j) {
252+
f_fine[i][j].resize(z_fine.size());
253+
for (size_t k = 0; k < z_fine.size(); ++k) {
254+
f_fine[i][j][k].resize(w_fine.size());
255+
for (size_t l = 0; l < w_fine.size(); ++l) {
256+
// Function: f(x,y,z,w) = 2*x + 3*y - z + 0.5*w + 1
257+
f_fine[i][j][k][l] = a*x_fine[i] + b*y_fine[j] + c*z_fine[k] + d*w_fine[l] + e;
258+
}
259+
}
260+
}
261+
}
262+
263+
// Test the 4D interpolation using our InterpNDAssertions class
264+
ASSERT_TRUE(cip::Interp4DAssertions<cip::LinearInterp4D<double>>(
265+
x, y, z, w, f, x_fine, y_fine, z_fine, w_fine, f_fine));
266+
}
267+
268+
// Direct test using the InterpNDAssertions template with N=4
269+
TEST(TestInterp4D, test_direct_nd_assertions) {
270+
// Create grid vectors for each dimension
271+
cip::Vector x = { 0.0, 1.0, 2.0 };
272+
cip::Vector y = { 0.0, 1.0, 2.0 };
273+
cip::Vector z = { 0.0, 1.0, 2.0 };
274+
cip::Vector w = { 0.0, 1.0, 2.0 };
275+
276+
// Create 4D function values array
277+
// Function is f(x,y,z,w) = x + y + z + w
278+
cip::Vector4 f;
279+
f.resize(x.size());
280+
for (size_t i = 0; i < x.size(); ++i) {
281+
f[i].resize(y.size());
282+
for (size_t j = 0; j < y.size(); ++j) {
283+
f[i][j].resize(z.size());
284+
for (size_t k = 0; k < z.size(); ++k) {
285+
f[i][j][k].resize(w.size());
286+
for (size_t l = 0; l < w.size(); ++l) {
287+
// Function: f(x,y,z,w) = x + y + z + w
288+
f[i][j][k][l] = x[i] + y[j] + z[k] + w[l];
289+
}
290+
}
291+
}
292+
}
293+
294+
// Fine grid for testing interpolation
295+
cip::Vector x_fine = { 0.0, 0.5, 1.0, 1.5, 2.0 };
296+
cip::Vector y_fine = { 0.0, 0.5, 1.0, 1.5, 2.0 };
297+
cip::Vector z_fine = { 0.0, 0.5, 1.0, 1.5, 2.0 };
298+
cip::Vector w_fine = { 0.0, 0.5, 1.0, 1.5, 2.0 };
299+
300+
// Expected values at fine grid points
301+
cip::Vector4 f_fine;
302+
f_fine.resize(x_fine.size());
303+
for (size_t i = 0; i < x_fine.size(); ++i) {
304+
f_fine[i].resize(y_fine.size());
305+
for (size_t j = 0; j < y_fine.size(); ++j) {
306+
f_fine[i][j].resize(z_fine.size());
307+
for (size_t k = 0; k < z_fine.size(); ++k) {
308+
f_fine[i][j][k].resize(w_fine.size());
309+
for (size_t l = 0; l < w_fine.size(); ++l) {
310+
// Function: f(x,y,z,w) = x + y + z + w
311+
f_fine[i][j][k][l] = x_fine[i] + y_fine[j] + z_fine[k] + w_fine[l];
312+
}
313+
}
314+
}
315+
}
316+
317+
// Pack grid vectors and fine grid vectors into arrays
318+
std::array<cip::Vector, 4> grid_vectors = {x, y, z, w};
319+
std::array<cip::Vector, 4> fine_grid_vectors = {x_fine, y_fine, z_fine, w_fine};
320+
321+
// Test the 4D interpolation directly using the InterpNDAssertions class
322+
auto result = cip::InterpNDAssertions<cip::LinearInterp4D<double>, 4, cip::Vector4>::test(
323+
grid_vectors, f, fine_grid_vectors, f_fine);
324+
ASSERT_TRUE(result);
325+
}
326+
327+
// Test a non-linear function with 4D interpolation
328+
TEST(TestInterp4D, test_nonlinear_4d) {
329+
// Create grid vectors for each dimension
330+
cip::Vector x = { 0.0, 1.0, 2.0 };
331+
cip::Vector y = { 0.0, 1.0, 2.0 };
332+
cip::Vector z = { 0.0, 1.0, 2.0 };
333+
cip::Vector w = { 0.0, 1.0, 2.0 };
334+
335+
// Create 4D function values array for a quadratic function
336+
// Function is f(x,y,z,w) = x^2 + y^2 + z^2 + w^2
337+
cip::Vector4 f;
338+
f.resize(x.size());
339+
for (size_t i = 0; i < x.size(); ++i) {
340+
f[i].resize(y.size());
341+
for (size_t j = 0; j < y.size(); ++j) {
342+
f[i][j].resize(z.size());
343+
for (size_t k = 0; k < z.size(); ++k) {
344+
f[i][j][k].resize(w.size());
345+
for (size_t l = 0; l < w.size(); ++l) {
346+
// Quadratic function
347+
f[i][j][k][l] = x[i]*x[i] + y[j]*y[j] + z[k]*z[k] + w[l]*w[l];
348+
}
349+
}
350+
}
351+
}
352+
353+
// Create interpolator
354+
cip::LinearInterp4D<double> interp(x, y, z, w, f);
355+
356+
// Test a few specific points without using fine grid
357+
// For interpolation at grid points, result should be exact
358+
EXPECT_DOUBLE_EQ(interp.eval(0.0, 0.0, 0.0, 0.0), 0.0);
359+
EXPECT_DOUBLE_EQ(interp.eval(1.0, 1.0, 1.0, 1.0), 4.0);
360+
EXPECT_DOUBLE_EQ(interp.eval(2.0, 2.0, 2.0, 2.0), 16.0);
361+
}

0 commit comments

Comments
 (0)