Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 193 additions & 0 deletions src/grid/cubic.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@

from grid.basegrid import Grid, OneDGrid

from collections import deque
from typing import Optional, Callable
from scipy.spatial import cKDTree


class _HyperRectangleGrid(Grid):
def __init__(self, points, weights, shape):
Expand Down Expand Up @@ -1013,3 +1017,192 @@ def generate_cube(self, fname, data, atcoords, atnums, pseudo_numbers=None):
row_data = data.flat[i : i + num_chunks]
f.write((row_data.size * " {:12.5E}").format(*row_data))
f.write("\n")


class AdaptiveUniformGrid:
    """Wrapper that adds adaptive refinement on top of a ``UniformGrid`` instance.

    The wrapped uniform grid serves as a coarse starting point.  A recursive
    subdivision algorithm then produces a new, non-uniform grid whose points
    are concentrated in regions where a gradient-based local error estimate is
    large, yielding more efficient and accurate integration.

    The main entry point is the `refinement` method; the wrapped grid object
    is never modified.
    """

    def __init__(self, uniform_grid: UniformGrid):
        """Initialize the wrapper around an existing uniform grid.

        Parameters
        ----------
        uniform_grid : UniformGrid
            The coarse, uniform grid that will serve as the starting point for refinement.

        Raises
        ------
        ValueError
            If `uniform_grid` is not a UniformGrid instance.
        """
        if not isinstance(uniform_grid, UniformGrid):
            raise ValueError("The input grid should be a UniformGrid instance.")
        self.grid = uniform_grid
        # Number of spatial dimensions, taken from the wrapped grid.
        self.ndim = uniform_grid.ndim

    def _estimate_error(self, points: np.ndarray, evaluated_points: dict) -> np.ndarray:
        """Estimate a local error for each point from finite-difference gradients.

        For every point, the largest ``|Δf| / distance`` over its nearest
        neighbors is scaled by the average neighbor distance, giving an error
        proxy in the units of the function value.

        Parameters
        ----------
        points : np.ndarray
            Point coordinates, one row per point.
        evaluated_points : dict
            Maps ``tuple(point)`` to the function value at that point; must
            contain an entry for every row of `points`.

        Returns
        -------
        np.ndarray
            One error estimate per point (all zeros when fewer than two points).
        """
        errors = np.zeros(len(points))
        if len(points) <= 1:
            return errors

        # k-d tree for nearest-neighbor queries in real coordinates.
        tree = cKDTree(points)
        # 2*ndim face neighbors plus the query point itself.
        k_neighbors = self.ndim * 2 + 1
        for i, point in enumerate(points):
            k = min(k_neighbors, len(points))
            distances, indices = tree.query(point, k=k)
            # Drop the first entry: the query point is its own nearest neighbor.
            neighbor_points = points[indices[1:]]
            neighbor_distances = distances[1:]
            if len(neighbor_points) == 0:
                continue

            # NOTE(review): dict keys are float tuples, so lookups rely on the
            # caller having inserted values under byte-identical coordinates —
            # confirm all points flow through the same arithmetic path.
            point_val = evaluated_points[tuple(point)]

            max_grad_mag = 0.0
            avg_local_spacing = np.mean(neighbor_distances)
            for neighbor, dist in zip(neighbor_points, neighbor_distances):
                if dist == 0:
                    continue

                neighbor_val = evaluated_points[tuple(neighbor)]
                # One-sided finite-difference gradient magnitude toward this neighbor.
                grad_mag = abs(point_val - neighbor_val) / dist
                if grad_mag > max_grad_mag:
                    max_grad_mag = grad_mag

            # Error proxy: steepest local gradient times the local spacing.
            errors[i] = max_grad_mag * avg_local_spacing

        return errors

    def _find_neighbors(self, point: np.ndarray, spacing: float) -> list:
        """Return the 2*ndim points at distance `spacing` along each grid axis.

        Notes
        -----
        The standard method of finding neighbors by converting point indices to
        integer coordinates (i, j, k) is not used here, because the refinement
        process adds points that do not lie on the original structured grid,
        where an index-based mapping would fail.  Instead, neighbors are
        generated in real (x, y, z) coordinates as ``point ± spacing * a_i``,
        with ``a_i`` the unit vector along grid axis i.
        """
        neighbors = []
        for axis in self.grid.axes:
            # Normalize so the offset distance equals `spacing` regardless of
            # the axis vector's length.
            axis_direction = axis / np.linalg.norm(axis)
            neighbors.append(point + spacing * axis_direction)
            neighbors.append(point - spacing * axis_direction)
        return neighbors

    def refinement(
        self, func: Callable, tolerance: float = 1e-4, min_spacing: Optional[float] = None
    ) -> dict:
        """Drive the adaptive refinement process and return the results.

        Starting from the wrapped uniform grid, points whose local error
        estimate exceeds `tolerance` are recursively subdivided (spacing
        halved) until the local error drops below `tolerance` or the spacing
        falls below `min_spacing`.  The original grid object is not modified.

        Parameters
        ----------
        func : Callable
            The function to be integrated; called with an (N, ndim) array of
            points and expected to return N values.
        tolerance : float, optional
            The error tolerance for a local point.
        min_spacing : float, optional
            The minimum allowed spacing for subdivision.  Defaults to 1/100 of
            the initial average spacing.

        Returns
        -------
        dict
            Keys: ``"integral"`` (final integral value), ``"final_grid"`` (the
            refined grid object), ``"num_points"``, and ``"num_evaluations"``.
        """

        # Initialization: work on copies so the wrapped grid is left untouched.
        initial_points = self.grid.points.copy()
        initial_weights = self.grid.weights.copy()
        # Scalar spacing estimate: the mean axis-vector length.
        # NOTE(review): for strongly anisotropic axes this single scalar is an
        # approximation — confirm it is acceptable for the intended grids.
        initial_avg_spacing = np.mean([np.linalg.norm(axis) for axis in self.grid.axes])
        if min_spacing is None:
            min_spacing = initial_avg_spacing / 100

        # Cache of function evaluations keyed by each point's coordinate tuple.
        evaluated_points = {}
        initial_values = func(initial_points)
        for i, p in enumerate(initial_points):
            evaluated_points[tuple(p)] = initial_values[i]

        # Total weight of the original grid; used to renormalize the refined
        # weights during finalization.
        original_total_volume = np.sum(initial_weights)

        # Refinement process: flag coarse points whose error estimate is high.
        errors = self._estimate_error(initial_points, evaluated_points)
        high_error_indices = np.where(errors > tolerance)[0]

        if high_error_indices.size == 0:
            # Nothing to refine: integrate on the original uniform grid.
            final_integral = self.grid.integrate(initial_values)
            return {
                "integral": final_integral,
                "final_grid": self.grid,
                "num_points": len(initial_points),
                "num_evaluations": len(evaluated_points),
            }

        final_points = []
        unnormalized_weights = []

        # Low-error points are kept as-is with a uniform heuristic cell weight.
        keep_mask = np.ones(len(initial_points), dtype=bool)
        keep_mask[high_error_indices] = False

        retained_points = initial_points[keep_mask]
        final_points.extend(list(retained_points))
        retained_weight = initial_avg_spacing ** self.ndim
        unnormalized_weights.extend([retained_weight] * len(retained_points))

        # High-error points enter a FIFO queue together with their current
        # (coarse) spacing; each pass halves the spacing around them.
        refinement_queue = deque(
            [(initial_points[idx], initial_avg_spacing) for idx in high_error_indices]
        )
        processed_points_set = {tuple(p) for p in initial_points}

        while refinement_queue:
            point, spacing = refinement_queue.popleft()
            half_spacing = spacing / 2
            if half_spacing < min_spacing:
                # Resolution floor reached: accept the point at its current cell size.
                final_points.append(point)
                unnormalized_weights.append(spacing**self.ndim)
                continue

            neighbors = self._find_neighbors(point, half_spacing)

            # Evaluate only the neighbors not already in the cache.
            new_points_to_eval = [p for p in neighbors if tuple(p) not in evaluated_points]
            if new_points_to_eval:
                new_values = func(np.array(new_points_to_eval))
                for new_point, new_value in zip(new_points_to_eval, new_values):
                    evaluated_points[tuple(new_point)] = new_value

            # Re-estimate the local error at the finer spacing.
            point_val = evaluated_points[tuple(point)]
            max_grad_mag = 0
            for neighbor in neighbors:
                neighbor_val = evaluated_points[tuple(neighbor)]
                grad_mag = abs(point_val - neighbor_val) / half_spacing
                if grad_mag > max_grad_mag:
                    max_grad_mag = grad_mag

            local_error = max_grad_mag * half_spacing

            if local_error < tolerance:
                # Converged locally: keep the point with its pre-subdivision cell volume.
                final_points.append(point)
                unnormalized_weights.append(spacing**self.ndim)
            else:
                # Still too steep: requeue the point at half spacing and enqueue
                # each not-yet-seen neighbor as a new refinement candidate.
                refinement_queue.append((point, half_spacing))
                for neighbor in neighbors:
                    child_tuple = tuple(neighbor)
                    if child_tuple not in processed_points_set:
                        refinement_queue.append((neighbor, half_spacing))
                        processed_points_set.add(child_tuple)

        # Finalization: rescale the heuristic cell weights so their sum matches
        # the original grid's total weight, then integrate on the refined grid.
        final_points = np.array(final_points)
        final_weights = np.array(unnormalized_weights)
        current_total_volume = np.sum(final_weights)
        if current_total_volume > 0:
            final_weights *= original_total_volume / current_total_volume
        final_grid = Grid(final_points, final_weights)
        final_values = np.array([evaluated_points[tuple(p)] for p in final_points])
        final_integral = final_grid.integrate(final_values)

        return {
            "integral": final_integral,
            "final_grid": final_grid,
            "num_points": len(final_points),
            "num_evaluations": len(evaluated_points),
        }
101 changes: 101 additions & 0 deletions src/grid/tests/test_cubic.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
from grid.cubic import Tensor1DGrids, UniformGrid, _HyperRectangleGrid
from grid.onedgrid import GaussLaguerre, MidPoint

import pytest
import copy
from grid.cubic import AdaptiveUniformGrid, Grid


class TestHyperRectangleGrid(TestCase):
r"""Test HyperRectangleGrid class."""
Expand Down Expand Up @@ -1050,3 +1054,100 @@ def test_uniformgrid_points_without_rotate(self):
]
)
assert_allclose(grid.points, expected, rtol=1.0e-7, atol=1.0e-7)


class TestAdaptiveUniformGrid:
    """Tests for the new AdaptiveUniformGrid wrapper class."""

    # Each scenario bundles a coarse-grid definition, the integrand, its exact
    # integral, and the refinement tolerance under a human-readable id.
    case_3d_single_peak = {
        "id": "3D_Single_Peak",
        "grid_setup": {
            "origin": np.array([-2.0, -2.0, -2.0]),
            "axes": np.diag([1.0, 1.0, 1.0]),
            "shape": np.array([5, 5, 5]),
        },
        "func": lambda points: np.exp(-20 * (points**2).sum(axis=1)),
        # exp(-a*r^2) integrates to (pi/a)^(3/2) over R^3.
        "analytical_integral": (np.pi / 20) ** 1.5,
        "tolerance": 1e-5,
    }

    case_2d_single_peak = {
        "id": "2D_Single_Peak_Centered",
        "grid_setup": {
            "origin": np.array([-4.0, -4.0]),
            "axes": np.diag([1.0, 1.0]),
            "shape": np.array([9, 9]),
        },
        "func": lambda points: np.exp(-50 * (points**2).sum(axis=1)),
        # exp(-a*r^2) integrates to pi/a over R^2.
        "analytical_integral": np.pi / 50,
        "tolerance": 1e-4,
    }

    case_2d_multi_peak = {
        "id": "2D_Multi_Peak_Centered",
        "grid_setup": {
            "origin": np.array([-4.0, -4.0]),
            "axes": np.diag([1.0, 1.0]),
            "shape": np.array([9, 9]),
        },
        "func": lambda points: (
            np.exp(-30 * ((points - np.array([1.0, 1.0])) ** 2).sum(axis=1))
            + np.exp(-30 * ((points - np.array([-1.0, -1.0])) ** 2).sum(axis=1))
        ),
        # Two identical Gaussians, so twice the single-peak integral.
        "analytical_integral": 2 * (np.pi / 30),
        "tolerance": 1e-3,
    }

    @pytest.mark.parametrize(
        "test_case",
        [case_3d_single_peak, case_2d_single_peak, case_2d_multi_peak],
        # Use the 'id' field for clear test names in the report.
        ids=lambda tc: tc["id"],
    )
    def test_refinement_improves_accuracy(self, test_case):
        """Check that the adaptively refined grid yields a more accurate integral."""
        # Build the coarse uniform grid for this scenario.
        setup = test_case["grid_setup"]
        coarse_grid = UniformGrid(
            setup["origin"], setup["axes"], setup["shape"], weight="Rectangle"
        )
        integrand = test_case["func"]
        exact = test_case["analytical_integral"]
        refine_tol = test_case["tolerance"]

        # Baseline: integrate on the unrefined grid.
        coarse_integral = coarse_grid.integrate(integrand(coarse_grid.points))
        coarse_error = abs(coarse_integral - exact)

        # Refine and integrate again.
        wrapper = AdaptiveUniformGrid(coarse_grid)
        result = wrapper.refinement(func=integrand, tolerance=refine_tol)
        refined_error = abs(result["integral"] - exact)

        n_coarse = coarse_grid.size
        n_refined = result["num_points"]
        if refined_error > 0:
            reduction = coarse_error / refined_error
        else:
            reduction = float("inf")

        # Diagnostic report (visible with pytest -s).
        print(f"\n--- Test Scenario: {test_case['id']} ---")
        print(f"{'Metric':<25} | {'Initial':<25} | {'Refined':<25}")
        print("-" * 80)
        print(f"{'Number of Points':<25} | {n_coarse:<25} | {n_refined:<25}")
        print(f"{'Integration Error':<25} | {coarse_error:<25.10e} | {refined_error:<25.10e}")
        print("-" * 80)
        print(f"Error Reduction Factor: {reduction:.2f}x")
        print(f"Point Count Increase: {n_refined - n_coarse}")
        print("--------------------------------------")

        assert refined_error < coarse_error
        assert result["num_points"] > coarse_grid.size