Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion curepy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,12 @@
# Utilities
from curepy.utilities.plotting import plot_corner
from curepy.utilities.maths import lnlike
from curepy.utilities.distributions import ln_uniform, ln_normal, ln_multi_normal, ln_trunc_normal
from curepy.utilities.distributions import (
ln_uniform,
ln_normal,
ln_multi_normal,
ln_trunc_normal,
)
from curepy.utilities.utilities import flatten_array, reshape_array, format_correlation

from ._version import get_versions
Expand Down
33 changes: 31 additions & 2 deletions curepy/container/measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def __init__(
u_y_rand: Optional[np.ndarray] = None,
u_y_syst: Optional[np.ndarray] = None,
corr_y: Optional[Union[str, np.ndarray]] = None,
skip_invcov: bool = False,
) -> None:
"""
Container class for measurement variable data.
Expand All @@ -29,6 +30,7 @@ def __init__(
Accepted values: ``None``, ``"rand"`` (random), ``"syst"``
(systematic), or a square matrix whose side length equals the
length of ``y``.
:param skip_invcov: If ``True``, skip the computation of the inverse covariance matrix (which is only needed for certain retrieval methods like optimal estimation).
"""

u_y_total, corr_y = self._format_uncertainty(
Expand All @@ -43,9 +45,13 @@ def __init__(

self.corr_y = util.format_correlation(self.y_flat, corr_y)

self.corr_y, self.cholesky, self.W = self.return_corr_cholesky_whitening(
self.corr_y
)

self._check_shapes(self.y_flat, self.u_y_flat, self.corr_y)

if corr_y is not None:
if corr_y is not None and not skip_invcov:
self.invcov = self.calculate_inv_cov(self.u_y_flat, self.corr_y)
else:
self.invcov = None
Expand Down Expand Up @@ -136,6 +142,30 @@ def _format_uncertainty(u_total, u_rand, u_syst, corr):
tot_corr = cm.convert_cov_to_corr(tot_cov, tot)
return tot, tot_corr

@staticmethod
def return_corr_cholesky_whitening(corr: Optional[np.ndarray]) -> tuple:
    """
    Factorise a correlation matrix and build the matching whitening matrix.

    The Cholesky factor is computed with ``np.linalg.cholesky`` and the
    whitening matrix ``W`` is obtained by solving ``cholesky @ W = I``,
    i.e. ``W`` is the inverse of the Cholesky factor.

    :param corr: Correlation matrix, or ``None``.
    :returns: Tuple of ``(corr, cholesky, W)``. All three entries are
        ``None`` when ``corr`` is ``None``. When the factorisation of
        ``corr`` raises ``np.linalg.LinAlgError``, the returned ``corr``
        is the matrix produced by ``cm.nearestPD_cholesky`` and
        ``cholesky`` / ``W`` are derived from that matrix instead.
    """

    def _factorise(matrix):
        # Cholesky factor and its inverse (solved against the identity).
        lower = np.linalg.cholesky(matrix)
        identity = np.eye(lower.shape[0])
        return lower, np.linalg.solve(lower, identity)

    # Nothing to factorise when no correlation information is supplied.
    if corr is None:
        return None, None, None

    matrix = corr
    try:
        lower, whitening = _factorise(matrix)
    except np.linalg.LinAlgError:
        # If the correlation matrix is not positive definite, use the
        # nearest positive definite matrix and factorise that instead.
        matrix = cm.nearestPD_cholesky(corr, return_cholesky=False, corr=True)
        lower, whitening = _factorise(matrix)

    return matrix, lower, whitening

@staticmethod
def calculate_inv_cov(unc: np.ndarray, corr: np.ndarray) -> np.ndarray:
"""
Expand All @@ -151,5 +181,4 @@ def calculate_inv_cov(unc: np.ndarray, corr: np.ndarray) -> np.ndarray:
if np.array_equal(cov, np.diag(np.diag(cov))):
return np.diag(1 / np.diag(cov))
else:
# might need a check for PD here
return np.linalg.inv(cov)
2 changes: 2 additions & 0 deletions curepy/container/tests/test_measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def test_calculate_inv_cov(self, mock_convert_corr_to_cov, mock_inv):
def test_init_format_correlation_called(
self, mock_format, mock_check, mock_convert
):
# Configure mock to return a valid correlation matrix
mock_format.return_value = np.eye(len(y))

meas = Measurement(y, u_y, corr_y="rand")

Expand Down
43 changes: 26 additions & 17 deletions curepy/retrieval_methods/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,35 +127,44 @@ def find_chisum(
).flatten()
)
diff = modelled_data - self.retrieval_input.measurement_obj.y_flat

# Only normalize by u_y_flat if it's available
if self.retrieval_input.measurement_obj.u_y_flat is not None:
diff_norm = diff / self.retrieval_input.measurement_obj.u_y_flat
else:
diff_norm = diff

if np.isfinite(np.sum(diff)):
if self.retrieval_input.measurement_obj.invcov is None:
return np.sum(
(diff) ** 2 / self.retrieval_input.measurement_obj.u_y_flat**2
)
if self.retrieval_input.measurement_obj.cholesky is None:
chisq = np.sum(
(diff_norm) ** 2
) # this is equivalent to using an identity matrix for the inverse covariance, which is appropriate when only uncorrelated uncertainties are available

else:
if len(repeat_dims) == 0:
return np.dot(
np.dot(diff.T, self.retrieval_input.measurement_obj.invcov),
diff,
)
y = self.retrieval_input.measurement_obj.W @ diff_norm
chisq = y.T @ y
elif len(repeat_dims) == 1:
sum = 0
for i in range(diff.shape[repeat_dims[0]]):
diffi = np.take(diff, i, repeat_dims[0])
sum += np.dot(
np.dot(
diffi.T, self.retrieval_input.measurement_obj.invcov
),
diffi,
)
return sum
diff_norm_i = np.take(diff_norm, i, repeat_dims[0])
y = self.retrieval_input.measurement_obj.W @ diff_norm_i
sum += y.T @ y
chisq = sum
else:
raise ValueError(
"Methods for multiple repeat dimensions are not yet implemented,"
)
else:
print("The difference between model and observations is infinite")
return np.inf
chisq = np.inf

if chisq < 0:
raise ValueError(
"The chi-squared cost is negative, which should not be possible. Check the inputs and the measurement function for errors."
)

return chisq

def lnprob(self, theta: np.ndarray) -> float:
"""
Expand Down
2 changes: 1 addition & 1 deletion curepy/retrieval_methods/optimal_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _run_retrieval(
self.retrieval_input.measurement_function_obj.initial_guess
)

res = minimize(-self.lnprob, theta_0)
res = minimize(lambda theta: -self.lnprob(theta), theta_0)

if self.Jx is None:
Jx = self.calculate_Jx(res.x)
Expand Down
9 changes: 8 additions & 1 deletion curepy/retrieval_methods/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ def make_mock_retrieval_input_for_chisum(
invcov=None,
u_y=None,
b=None,
L=None,
W=None,
):
retrieval_input = RetrievalInput()
# measurement function object
Expand All @@ -27,6 +29,9 @@ def make_mock_retrieval_input_for_chisum(
retrieval_input.measurement_obj.y_flat = np.array(y_flat)
retrieval_input.measurement_obj.invcov = invcov
retrieval_input.measurement_obj.u_y_flat = u_y
retrieval_input.measurement_obj.cholesky = L
retrieval_input.measurement_obj.W = W

# ancillary
retrieval_input.ancillary_obj = MagicMock()
retrieval_input.ancillary_obj.b = b
Expand Down Expand Up @@ -137,7 +142,7 @@ def test_multiple_repeat_dims_raises_error(self):
retrieval_input = make_mock_retrieval_input_for_chisum(
measurement_function_output=np.array([1.0, 2.0]),
y_flat=np.array([1.0, 1.0]),
invcov=np.eye(2),
L=np.eye(2),
u_y=None,
b=None,
)
Expand Down Expand Up @@ -181,6 +186,8 @@ def test_chisum_with_invcov_no_repeat(self):
invcov=invcov,
u_y=None,
b=None,
L=np.eye(2),
W=np.eye(2),
)

dr = DummyRetrieval()
Expand Down
2 changes: 1 addition & 1 deletion examples/multidimensional_MCMC_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def quadratic(a, b, c, x, d):
y = data + noise

meas_func = MeasurementFunction(quadratic, [0.5, 0.2, -10])
meas = Measurement(y, noise, "rand")
meas = Measurement(y, noise, corr_y="rand")
ancill = AncillaryParameter([x, d], [None, 1], [None, None], b_MC_steps=3)

inputs = RetrievalInput(meas_func, meas, ancill)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
MeasurementFunction,
Measurement,
AncillaryParameter,
LPU,
OE,
RetrievalInput,
)

Expand All @@ -24,7 +24,7 @@ def quadratic(a, b, c, x, d):
y = data + noise

meas_func = MeasurementFunction(quadratic, [0.5, 0.2, -10])
meas = Measurement(y, noise, "rand")
meas = Measurement(y, noise, corr_y="rand")
ancill = AncillaryParameter(
[x, d],
[0.01 * np.ones_like(x), 1],
Expand All @@ -41,7 +41,7 @@ def quadratic(a, b, c, x, d):

inputs = RetrievalInput(meas_func, meas, ancill)

ret = LPU()
ret = OE()

results = ret.run_retrieval(inputs)

Expand Down
2 changes: 1 addition & 1 deletion examples/simple_quadratic_MCMC_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def quadratic(a, b, c, x, d):
y = data + noise

meas_func = MeasurementFunction(quadratic, [0.5, 0.2, -10])
meas = Measurement(y, noise, np.eye(len(x)))
meas = Measurement(y, noise, corr_y=np.eye(len(x)))
ancill = AncillaryParameter([x, d], [None, 1], [np.eye(len(x)), None], b_MC_steps=3)
prior = Prior(
["normal"] * 3,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
MeasurementFunction,
Measurement,
AncillaryParameter,
LPU,
OE,
RetrievalInput,
Prior,
)
Expand All @@ -26,7 +26,7 @@ def quadratic(a, b, c, x, d):
y = data + noise

meas_func = MeasurementFunction(quadratic, [0.5, 0.2, -10])
meas = Measurement(y, noise, np.eye(len(x)))
meas = Measurement(y, noise, corr_y=np.eye(len(x)))
ancill = AncillaryParameter([x, d], [None, 0.05], [None, None])
prior = Prior(
["normal"] * 3,
Expand All @@ -36,12 +36,12 @@ def quadratic(a, b, c, x, d):

inputs = RetrievalInput(meas_func, meas, ancill, prior)

ret = LPU()
ret = OE()

results = ret.run_retrieval(inputs)

print(results.values)
print(results.uncertainties)
plt.plot(x, quadratic(*results.values, x, d))
plt.scatter(x, y, alpha=0.5, c="orange")
plt.savefig(os.path.join(example_dir, "LPU_test.png"))
plt.savefig(os.path.join(example_dir, "OE_test.png"))
Loading