Skip to content

Commit 7cd2333

Browse files
committed
Added in unit handling for Astropy
The model can now be applied to a spectrum which can be fit by the astropy fitting routines with units handled correctly. However there is still an issue with fitting theta as a free parameter, practically this isnt an issue, however should in principle work.
1 parent b15fe40 commit 7cd2333

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

sunkit_spex/models/physical/albedo.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,14 +82,22 @@ class Albedo(FittableModel):
8282
)
8383
anisotropy = Parameter(default=1, description="The anisotropy used for albedo correction", fixed=True)
8484

85+
_input_units_allow_dimensionless = True
86+
8587
def __init__(self, *args, **kwargs):
8688
self.energy_edges = kwargs.pop("energy_edges")
8789
super().__init__(*args, **kwargs)
8890

8991
def evaluate(self, spectrum, theta, anisotropy):
90-
albedo_matrix = get_albedo_matrix(self.energy_edges, self.theta, anisotropy)
92+
if hasattr(theta, "unit"):
93+
albedo_matrix = get_albedo_matrix(self.energy_edges, theta, anisotropy)
94+
else:
95+
albedo_matrix = get_albedo_matrix(self.energy_edges, theta*u.deg, anisotropy)
9196
return spectrum + spectrum @ albedo_matrix
9297

98+
def _parameter_units_for_data_units(self, inputs_unit, outputs_unit):
99+
return {"theta": u.deg}
100+
93101

94102
@lru_cache
95103
def _get_green_matrix(theta: float) -> RegularGridInterpolator:
@@ -179,7 +187,7 @@ def _calculate_albedo_matrix(energy_edges: tuple[float], theta: float, anisotrop
179187

180188

181189
@u.quantity_input
182-
def get_albedo_matrix(energy_edges: Quantity[u.keV], theta: Quantity[u.deg], anisotropy=1):
190+
def get_albedo_matrix(energy_edges: Quantity[u.keV], theta:Quantity[u.deg], anisotropy=1):
183191
r"""
184192
Get albedo correction matrix.
185193
@@ -211,6 +219,7 @@ def get_albedo_matrix(energy_edges: Quantity[u.keV], theta: Quantity[u.deg], ani
211219
if energy_edges[0].to_value(u.keV) < 3 or energy_edges[-1].to_value(u.keV) > 600:
212220
raise ValueError("Supported energy range 3 <= E <= 600 keV")
213221
theta = np.array(theta).squeeze() << theta.unit
222+
# theta = np.array(theta)*u.deg
214223
if np.abs(theta) > 90 * u.deg:
215224
raise ValueError(f"Theta must be between -90 and 90 degrees: {theta}.")
216225
anisotropy = np.array(anisotropy).squeeze()

0 commit comments

Comments
 (0)