Skip to content

Commit 60bf718

Browse files
committed
refactor: extract repeated logic
1 parent 1390c80 commit 60bf718

File tree

2 files changed

+83
-82
lines changed

2 files changed

+83
-82
lines changed

src/ess/amor/utils.py

Lines changed: 68 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,37 @@ def theta_grid(
7474
return grid
7575

7676

77+
def _reshape_array_to_expected_shape(da, dims, **bins):
78+
if da.bins:
79+
da = da.bins.concat(set(da.dims) - set(dims))
80+
elif set(da.dims) > set(dims):
81+
raise ValueError(
82+
f'Histogram must have exactly the dimensions'
83+
f' {set(dims)} but got {set(da.dims)}'
84+
)
85+
86+
if not set(da.dims).union(set(bins)) >= set(dims):
87+
raise ValueError(
88+
f'Could not find bins for dimensions:'
89+
f' {set(dims) - set(da.dims).union(set(bins))}'
90+
)
91+
92+
if da.bins or not set(da.dims) == set(dims):
93+
da = da.hist(**bins)
94+
95+
return da.transpose(dims)
96+
97+
98+
def _repeat_variable_argument(n, arg):
99+
return (
100+
(None,) * n
101+
if arg is None
102+
else (arg,) * n
103+
if isinstance(arg, sc.Variable)
104+
else arg
105+
)
106+
107+
77108
def wavelength_theta_figure(
78109
da: sc.DataArray | Sequence[sc.DataArray],
79110
*,
@@ -93,40 +124,29 @@ def wavelength_theta_figure(
93124
)
94125

95126
wavelength_bins, theta_bins = (
96-
(None,) * len(da)
97-
if v is None
98-
else (v,) * len(da)
99-
if isinstance(v, sc.Variable)
100-
else v
101-
for v in (wavelength_bins, theta_bins)
127+
_repeat_variable_argument(len(da), arg) for arg in (wavelength_bins, theta_bins)
102128
)
103129

104130
hs = []
105131
for d, wavelength_bin, theta_bin in zip(
106132
da, wavelength_bins, theta_bins, strict=True
107133
):
108-
if d.bins:
109-
d = d.bins.concat(set(d.dims) - {"wavelength", "theta"})
110-
all_coords = {*d.coords, *(d.bins or d).coords}
111-
if 'wavelength' not in all_coords or 'theta' not in all_coords:
112-
raise ValueError('Data must have wavelength and theta coord')
113-
if d.bins or set(d.dims) != {"wavelength", "theta"}:
114-
bins = {}
115-
if 'sample_rotation' in d.coords and 'detector_rotation' in d.coords:
134+
bins = {}
135+
if wavelength_bin is not None:
136+
bins['wavelength'] = wavelength_bin
137+
138+
if theta_bin is not None:
139+
bins['theta'] = theta_bin
140+
else:
141+
if (
142+
'theta' not in d.dims
143+
and 'sample_rotation' in d.coords
144+
and 'detector_rotation' in d.coords
145+
):
116146
bins['theta'] = theta_grid(
117147
nu=d.coords['detector_rotation'], mu=d.coords['sample_rotation']
118148
)
119-
if theta_bin is not None:
120-
bins['theta'] = theta_bin
121-
if wavelength_bin is not None:
122-
bins['wavelength'] = wavelength_bin
123-
if 'theta' not in d.dims and 'theta' not in bins:
124-
raise ValueError('No theta binning provided')
125-
if 'wavelength' not in d.dims and 'wavelength' not in bins:
126-
raise ValueError('No wavelength binning provided')
127-
d = d.hist(**bins)
128-
129-
hs.append(d.transpose(('theta', 'wavelength')))
149+
hs.append(_reshape_array_to_expected_shape(d, ('theta', 'wavelength'), **bins))
130150

131151
kwargs.setdefault('cbar', True)
132152
kwargs.setdefault('norm', 'log')
@@ -157,35 +177,27 @@ def q_theta_figure(
157177
)
158178

159179
q_bins, theta_bins = (
160-
(None,) * len(da)
161-
if v is None
162-
else (v,) * len(da)
163-
if isinstance(v, sc.Variable)
164-
else v
165-
for v in (q_bins, theta_bins)
180+
_repeat_variable_argument(len(da), arg) for arg in (q_bins, theta_bins)
166181
)
167182

168183
hs = []
169184
for d, q_bin, theta_bin in zip(da, q_bins, theta_bins, strict=True):
170-
if d.bins:
171-
d = d.bins.concat(set(d.dims) - {'theta', 'Q'})
172-
173-
all_coords = {*d.coords, *(d.bins or d).coords}
174-
if 'theta' not in all_coords or 'Q' not in all_coords:
175-
raise ValueError('Data must have theta and Q coord')
176-
if d.bins or set(d.dims) != {"theta", "Q"}:
177-
bins = {}
178-
if theta_bin is not None:
179-
bins['theta'] = theta_bin
180-
if q_bin is not None:
181-
bins['Q'] = q_bin
182-
if 'theta' not in d.dims and 'theta' not in bins:
183-
raise ValueError('No theta binning provided')
184-
if 'Q' not in d.dims and 'Q' not in bins:
185-
raise ValueError('No Q binning provided')
186-
d = d.hist(**bins)
187-
188-
hs.append(d.transpose(('theta', 'Q')))
185+
bins = {}
186+
if q_bin is not None:
187+
bins['Q'] = q_bin
188+
189+
if theta_bin is not None:
190+
bins['theta'] = theta_bin
191+
else:
192+
if (
193+
'theta' not in d.dims
194+
and 'sample_rotation' in d.coords
195+
and 'detector_rotation' in d.coords
196+
):
197+
bins['theta'] = theta_grid(
198+
nu=d.coords['detector_rotation'], mu=d.coords['sample_rotation']
199+
)
200+
hs.append(_reshape_array_to_expected_shape(d, ('theta', 'Q'), **bins))
189201

190202
kwargs.setdefault('cbar', True)
191203
kwargs.setdefault('norm', 'log')
@@ -202,28 +214,17 @@ def wavelength_z_figure(
202214
if isinstance(da, sc.DataArray):
203215
return wavelength_z_figure((da,), wavelength_bins=(wavelength_bins,), **kwargs)
204216

205-
(wavelength_bins,) = (
206-
(None,) * len(da)
207-
if v is None
208-
else (v,) * len(da)
209-
if isinstance(v, sc.Variable)
210-
else v
211-
for v in (wavelength_bins,)
212-
)
217+
wavelength_bins = _repeat_variable_argument(len(da), wavelength_bins)
213218

214219
hs = []
215220
for d, wavelength_bin in zip(da, wavelength_bins, strict=True):
216-
if d.bins:
217-
d = d.bins.concat(set(d.dims) - {'blade', 'wire', 'wavelength'})
218-
bins = {}
219-
if wavelength_bin is not None:
220-
bins['wavelength'] = wavelength_bin
221-
if 'wavelength' not in d.dims and 'wavelength' not in bins:
222-
raise ValueError('No wavelength binning provided')
223-
d = d.hist(**bins)
221+
bins = {}
222+
if wavelength_bin is not None:
223+
bins['wavelength'] = wavelength_bin
224224

225+
d = _reshape_array_to_expected_shape(d, ("blade", "wire", "wavelength"), **bins)
225226
d = d.flatten(("blade", "wire"), to="z_index")
226-
hs.append(d.transpose(('z_index', 'wavelength')))
227+
hs.append(d)
227228

228229
kwargs.setdefault('cbar', True)
229230
kwargs.setdefault('norm', 'log')

tests/amor/utils_test.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,13 @@ def q_bins():
3737

3838

3939
def test_wavelength_figure_table(da, wavelength_bins, theta_bins):
40-
with pytest.raises(ValueError, match='binning provided'):
40+
with pytest.raises(ValueError, match='Could not find bins'):
4141
wavelength_theta_figure(da)
4242

43-
with pytest.raises(ValueError, match='binning provided'):
43+
with pytest.raises(ValueError, match='Could not find bins'):
4444
wavelength_theta_figure(da, wavelength_bins=wavelength_bins)
4545

46-
with pytest.raises(ValueError, match='binning provided'):
46+
with pytest.raises(ValueError, match='Could not find bins'):
4747
wavelength_theta_figure(da, theta_bins=theta_bins)
4848

4949
assert wavelength_theta_figure(
@@ -52,10 +52,10 @@ def test_wavelength_figure_table(da, wavelength_bins, theta_bins):
5252

5353

5454
def test_wavelength_figure_binned(da, wavelength_bins, theta_bins):
55-
with pytest.raises(ValueError, match='binning provided'):
55+
with pytest.raises(ValueError, match='Could not find bins'):
5656
wavelength_theta_figure(da.bin(wavelength=3))
5757

58-
with pytest.raises(ValueError, match='binning provided'):
58+
with pytest.raises(ValueError, match='Could not find bins'):
5959
wavelength_theta_figure(da.bin(theta=3))
6060

6161
assert wavelength_theta_figure(da.bin(wavelength=3, theta=3))
@@ -64,10 +64,10 @@ def test_wavelength_figure_binned(da, wavelength_bins, theta_bins):
6464

6565

6666
def test_wavelength_figure_hist(da, wavelength_bins, theta_bins):
67-
with pytest.raises(ValueError, match='must have wavelength and theta coord'):
67+
with pytest.raises(ValueError, match='Could not find bins'):
6868
wavelength_theta_figure(da.hist(wavelength=3))
6969

70-
with pytest.raises(ValueError, match='must have wavelength and theta coord'):
70+
with pytest.raises(ValueError, match='Could not find bins'):
7171
wavelength_theta_figure(da.hist(theta=3))
7272

7373
assert wavelength_theta_figure(da.hist(wavelength=3, theta=3))
@@ -140,23 +140,23 @@ def test_wavelength_figure_accepts_additional_plot_kwargs(
140140

141141

142142
def test_q_figure_table(da, q_bins, theta_bins):
143-
with pytest.raises(ValueError, match='binning provided'):
143+
with pytest.raises(ValueError, match='Could not find bins'):
144144
q_theta_figure(da)
145145

146-
with pytest.raises(ValueError, match='binning provided'):
146+
with pytest.raises(ValueError, match='Could not find bins'):
147147
q_theta_figure(da, q_bins=q_bins)
148148

149-
with pytest.raises(ValueError, match='binning provided'):
149+
with pytest.raises(ValueError, match='Could not find bins'):
150150
q_theta_figure(da, theta_bins=theta_bins)
151151

152152
assert q_theta_figure(da, theta_bins=theta_bins, q_bins=q_bins)
153153

154154

155155
def test_q_figure_binned(da, q_bins, theta_bins):
156-
with pytest.raises(ValueError, match='binning provided'):
156+
with pytest.raises(ValueError, match='Could not find bins'):
157157
q_theta_figure(da.bin(Q=3))
158158

159-
with pytest.raises(ValueError, match='binning provided'):
159+
with pytest.raises(ValueError, match='Could not find bins'):
160160
q_theta_figure(da.bin(theta=3))
161161

162162
assert q_theta_figure(da.bin(Q=3, theta=3))
@@ -165,10 +165,10 @@ def test_q_figure_binned(da, q_bins, theta_bins):
165165

166166

167167
def test_q_figure_hist(da, q_bins, theta_bins):
168-
with pytest.raises(ValueError, match='must have theta and Q coord'):
168+
with pytest.raises(ValueError, match='Could not find bins'):
169169
q_theta_figure(da.hist(Q=3))
170170

171-
with pytest.raises(ValueError, match='must have theta and Q coord'):
171+
with pytest.raises(ValueError, match='Could not find bins'):
172172
q_theta_figure(da.hist(theta=3))
173173

174174
assert q_theta_figure(da.hist(Q=3, theta=3))
@@ -228,7 +228,7 @@ def test_q_figure_can_pass_additional_plot_kwargs(da, q_bins, theta_bins):
228228
def test_z_figure_binned(da, wavelength_bins):
229229
da = da.group('z_index').fold('z_index', dims=('blade', 'wire'), shape=(2, 5))
230230

231-
with pytest.raises(ValueError, match='binning provided'):
231+
with pytest.raises(ValueError, match='Could not find bins'):
232232
wavelength_z_figure(da)
233233

234234
assert wavelength_z_figure(da.bin(wavelength=3))

0 commit comments

Comments
 (0)