@@ -74,6 +74,37 @@ def theta_grid(
7474 return grid
7575
7676
77+ def _reshape_array_to_expected_shape (da , dims , ** bins ):
78+ if da .bins :
79+ da = da .bins .concat (set (da .dims ) - set (dims ))
80+ elif set (da .dims ) > set (dims ):
81+ raise ValueError (
82+ f'Histogram must have exactly the dimensions'
83+ f' { set (dims )} but got { set (da .dims )} '
84+ )
85+
86+ if not set (da .dims ).union (set (bins )) >= set (dims ):
87+ raise ValueError (
88+ f'Could not find bins for dimensions:'
89+ f' { set (dims ) - set (da .dims ).union (set (bins ))} '
90+ )
91+
92+ if da .bins or not set (da .dims ) == set (dims ):
93+ da = da .hist (** bins )
94+
95+ return da .transpose (dims )
96+
97+
98+ def _repeat_variable_argument (n , arg ):
99+ return (
100+ (None ,) * n
101+ if arg is None
102+ else (arg ,) * n
103+ if isinstance (arg , sc .Variable )
104+ else arg
105+ )
106+
107+
77108def wavelength_theta_figure (
78109 da : sc .DataArray | Sequence [sc .DataArray ],
79110 * ,
@@ -93,40 +124,29 @@ def wavelength_theta_figure(
93124 )
94125
95126 wavelength_bins , theta_bins = (
96- (None ,) * len (da )
97- if v is None
98- else (v ,) * len (da )
99- if isinstance (v , sc .Variable )
100- else v
101- for v in (wavelength_bins , theta_bins )
127+ _repeat_variable_argument (len (da ), arg ) for arg in (wavelength_bins , theta_bins )
102128 )
103129
104130 hs = []
105131 for d , wavelength_bin , theta_bin in zip (
106132 da , wavelength_bins , theta_bins , strict = True
107133 ):
108- if d .bins :
109- d = d .bins .concat (set (d .dims ) - {"wavelength" , "theta" })
110- all_coords = {* d .coords , * (d .bins or d ).coords }
111- if 'wavelength' not in all_coords or 'theta' not in all_coords :
112- raise ValueError ('Data must have wavelength and theta coord' )
113- if d .bins or set (d .dims ) != {"wavelength" , "theta" }:
114- bins = {}
115- if 'sample_rotation' in d .coords and 'detector_rotation' in d .coords :
134+ bins = {}
135+ if wavelength_bin is not None :
136+ bins ['wavelength' ] = wavelength_bin
137+
138+ if theta_bin is not None :
139+ bins ['theta' ] = theta_bin
140+ else :
141+ if (
142+ 'theta' not in d .dims
143+ and 'sample_rotation' in d .coords
144+ and 'detector_rotation' in d .coords
145+ ):
116146 bins ['theta' ] = theta_grid (
117147 nu = d .coords ['detector_rotation' ], mu = d .coords ['sample_rotation' ]
118148 )
119- if theta_bin is not None :
120- bins ['theta' ] = theta_bin
121- if wavelength_bin is not None :
122- bins ['wavelength' ] = wavelength_bin
123- if 'theta' not in d .dims and 'theta' not in bins :
124- raise ValueError ('No theta binning provided' )
125- if 'wavelength' not in d .dims and 'wavelength' not in bins :
126- raise ValueError ('No wavelength binning provided' )
127- d = d .hist (** bins )
128-
129- hs .append (d .transpose (('theta' , 'wavelength' )))
149+ hs .append (_reshape_array_to_expected_shape (d , ('theta' , 'wavelength' ), ** bins ))
130150
131151 kwargs .setdefault ('cbar' , True )
132152 kwargs .setdefault ('norm' , 'log' )
@@ -157,35 +177,27 @@ def q_theta_figure(
157177 )
158178
159179 q_bins , theta_bins = (
160- (None ,) * len (da )
161- if v is None
162- else (v ,) * len (da )
163- if isinstance (v , sc .Variable )
164- else v
165- for v in (q_bins , theta_bins )
180+ _repeat_variable_argument (len (da ), arg ) for arg in (q_bins , theta_bins )
166181 )
167182
168183 hs = []
169184 for d , q_bin , theta_bin in zip (da , q_bins , theta_bins , strict = True ):
170- if d .bins :
171- d = d .bins .concat (set (d .dims ) - {'theta' , 'Q' })
172-
173- all_coords = {* d .coords , * (d .bins or d ).coords }
174- if 'theta' not in all_coords or 'Q' not in all_coords :
175- raise ValueError ('Data must have theta and Q coord' )
176- if d .bins or set (d .dims ) != {"theta" , "Q" }:
177- bins = {}
178- if theta_bin is not None :
179- bins ['theta' ] = theta_bin
180- if q_bin is not None :
181- bins ['Q' ] = q_bin
182- if 'theta' not in d .dims and 'theta' not in bins :
183- raise ValueError ('No theta binning provided' )
184- if 'Q' not in d .dims and 'Q' not in bins :
185- raise ValueError ('No Q binning provided' )
186- d = d .hist (** bins )
187-
188- hs .append (d .transpose (('theta' , 'Q' )))
185+ bins = {}
186+ if q_bin is not None :
187+ bins ['Q' ] = q_bin
188+
189+ if theta_bin is not None :
190+ bins ['theta' ] = theta_bin
191+ else :
192+ if (
193+ 'theta' not in d .dims
194+ and 'sample_rotation' in d .coords
195+ and 'detector_rotation' in d .coords
196+ ):
197+ bins ['theta' ] = theta_grid (
198+ nu = d .coords ['detector_rotation' ], mu = d .coords ['sample_rotation' ]
199+ )
200+ hs .append (_reshape_array_to_expected_shape (d , ('theta' , 'Q' ), ** bins ))
189201
190202 kwargs .setdefault ('cbar' , True )
191203 kwargs .setdefault ('norm' , 'log' )
@@ -202,28 +214,17 @@ def wavelength_z_figure(
202214 if isinstance (da , sc .DataArray ):
203215 return wavelength_z_figure ((da ,), wavelength_bins = (wavelength_bins ,), ** kwargs )
204216
205- (wavelength_bins ,) = (
206- (None ,) * len (da )
207- if v is None
208- else (v ,) * len (da )
209- if isinstance (v , sc .Variable )
210- else v
211- for v in (wavelength_bins ,)
212- )
217+ wavelength_bins = _repeat_variable_argument (len (da ), wavelength_bins )
213218
214219 hs = []
215220 for d , wavelength_bin in zip (da , wavelength_bins , strict = True ):
216- if d .bins :
217- d = d .bins .concat (set (d .dims ) - {'blade' , 'wire' , 'wavelength' })
218- bins = {}
219- if wavelength_bin is not None :
220- bins ['wavelength' ] = wavelength_bin
221- if 'wavelength' not in d .dims and 'wavelength' not in bins :
222- raise ValueError ('No wavelength binning provided' )
223- d = d .hist (** bins )
221+ bins = {}
222+ if wavelength_bin is not None :
223+ bins ['wavelength' ] = wavelength_bin
224224
225+ d = _reshape_array_to_expected_shape (d , ("blade" , "wire" , "wavelength" ), ** bins )
225226 d = d .flatten (("blade" , "wire" ), to = "z_index" )
226- hs .append (d . transpose (( 'z_index' , 'wavelength' )) )
227+ hs .append (d )
227228
228229 kwargs .setdefault ('cbar' , True )
229230 kwargs .setdefault ('norm' , 'log' )
0 commit comments