Skip to content

Commit 53d9b08

Browse files
authored
Merge pull request #2748 from kif/2747_multi_module
Start to implement a multi-module detector
2 parents 005db36 + 857a914 commit 53d9b08

File tree

6 files changed

+468
-3
lines changed

6 files changed

+468
-3
lines changed

src/pyFAI/detectors/meson.build

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ py.install_sources(
1212
'_rayonix.py',
1313
'_xspectrum.py',
1414
'sensors.py',
15-
'orientation.py'],
15+
'orientation.py',
16+
'multi_module.py'],
1617
pure: false, # Will be installed next to binaries
1718
subdir: 'pyFAI/detectors' # Folder relative to site-packages to install to
1819
)
Lines changed: 383 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,383 @@
1+
# !/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Project: Azimuthal integration
5+
# https://github.com/silx-kit/pyFAI
6+
#
7+
# Copyright (C) 2025-2026 European Synchrotron Radiation Facility, Grenoble, France
8+
#
9+
# Principal author: Jérôme Kieffer (Jerome.Kieffer@ESRF.eu)
10+
#
11+
# Permission is hereby granted, free of charge, to any person obtaining a copy
12+
# of this software and associated documentation files (the "Software"), to deal
13+
# in the Software without restriction, including without limitation the rights
14+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15+
# copies of the Software, and to permit persons to whom the Software is
16+
# furnished to do so, subject to the following conditions:
17+
#
18+
# The above copyright notice and this permission notice shall be included in
19+
# all copies or substantial portions of the Software.
20+
#
21+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27+
# THE SOFTWARE.
28+
#
29+
30+
"""Multi-module detectors:
31+
32+
This module contains some helper function to define a detector from several modules
33+
and later-on refine this module position from powder diffraction data
34+
as demonstrated in https://doi.org/10.3390/cryst12020255
35+
"""
36+
37+
__author__ = "Jérôme Kieffer"
38+
__contact__ = "Jerome.Kieffer@ESRF.eu"
39+
__license__ = "MIT"
40+
__copyright__ = "European Synchrotron Radiation Facility, Grenoble, France"
41+
__date__ = "13/01/2026"
42+
__status__ = "development"
43+
44+
from math import sin, cos, pi
45+
from dataclasses import dataclass
46+
import numpy
47+
from scipy import ndimage, optimize
48+
from ..control_points import ControlPoints
49+
from ..ext import _geometry
50+
from ..io.ponifile import PoniFile
51+
from ..third_party.classproperties import classproperty
52+
53+
54+
module_d = numpy.dtype(
55+
[
56+
("d0", numpy.float64),
57+
("d1", numpy.float64),
58+
("ring", numpy.int32),
59+
("module", numpy.int32),
60+
]
61+
)
62+
63+
64+
# Those are the optimizable parameters ... 2 translations and one rotation.
65+
@dataclass
66+
class ModuleParam:
67+
d0: float = 0.0
68+
d1: float = 0.0
69+
rot: float = 0.0
70+
71+
def set(self, iterable):
72+
self.d0, self.d1, self.rot = iterable[:3]
73+
74+
def get(self):
75+
return (self.d0, self.d1, self.rot)
76+
77+
@classproperty
78+
def nb_param(cls):
79+
return len(cls.__dataclass_fields__)
80+
81+
82+
@dataclass
83+
class PoniParam:
84+
dist: float = 0.0
85+
poni1: float = 0.0
86+
poni2: float = 0.0
87+
rot1: float = 0.0
88+
rot2: float = 0.0
89+
# rot3:float=0.0
90+
# wavelength:float=0.0
91+
92+
@classproperty
93+
def nb_param(self):
94+
return len(self.__dataclass_fields__)
95+
96+
97+
class SingleModule:
98+
def __init__(self, detector, mask, index=None, fixed=False):
99+
self.parent_detector = detector
100+
self.parent_index = index
101+
if (index is not None) and index <= mask.max():
102+
self.mask = mask == index
103+
else:
104+
self.mask = mask
105+
self.fixed = False
106+
self.param = ModuleParam()
107+
self.center = None
108+
self.bounding_box = None
109+
self.calc_bounding_box()
110+
111+
def __repr__(self):
112+
return (
113+
f"Module centered at ({self.center[0, 0]:.1f}, {self.center[1, 0]:.1f})"
114+
+ (", fixed." if self.fixed else ".")
115+
)
116+
117+
def calc_bounding_box(self):
118+
d0, d1 = numpy.where(self.mask)
119+
d0m = d0.min()
120+
d0M = d0.max()
121+
d1m = d1.min()
122+
d1M = d1.max()
123+
self.center = numpy.atleast_2d([0.5 * (d0M + d0m + 1), 0.5 * (d1M + d1m + 1)]).T
124+
self.bounding_box = (slice(d0m, d0M + 1), slice(d1m, d1M + 1))
125+
return self.bounding_box
126+
127+
def calc_displacement_map(self, d1=None, d2=None, param=None):
128+
if d1 is None and d2 is None:
129+
full_detector = True
130+
p1, p2, _ = self.parent_detector.calc_cartesian_positions()
131+
d1 = p1 / self.parent_detector.pixel1
132+
d2 = p2 / self.parent_detector.pixel2
133+
mp1 = d1[self.mask]
134+
mp2 = d2[self.mask]
135+
else:
136+
full_detector = False
137+
mp1 = d1
138+
mp2 = d2
139+
140+
param = param or self.param
141+
142+
mpc = numpy.vstack((mp1.ravel(), mp2.ravel()))
143+
if not self.fixed:
144+
self.center
145+
mpc -= self.center
146+
rot = param.rot
147+
c, s = cos(rot), sin(rot)
148+
rotm = numpy.array([[c, -s], [s, c]])
149+
mpc = (
150+
numpy.dot(rotm, mpc)
151+
+ self.center
152+
+ numpy.atleast_2d([param.d0, param.d1]).T
153+
)
154+
if full_detector:
155+
mshape = mp1.shape
156+
p1[self.mask] = mpc[0].reshape(mshape)
157+
p2[self.mask] = mpc[1].reshape(mshape)
158+
else:
159+
p1, p2 = mpc
160+
return p1, p2
161+
162+
def calc_position(self, d1=None, d2=None, param=None):
163+
d1, d2 = self.calc_displacement_map(d1, d2, param)
164+
return d1 * self.parent_detector.pixel1, d2 * self.parent_detector.pixel2
165+
166+
167+
class MultiModule:
168+
"""Split a detector in several modules"""
169+
170+
def __init__(self):
171+
self.modules = {} # this is contains all of modules
172+
self.lmask = None
173+
self.detector = None
174+
self.nb_modules = 0
175+
176+
def __repr__(self):
177+
return f"MultiModule with {self.nb_modules} modules:\n" + "\n".join(
178+
f" {i:2d}: {j}" for i, j in self.modules.items()
179+
)
180+
181+
def build_labels(self):
182+
self.lmask, self.nb_modules = ndimage.label(
183+
numpy.logical_not(self.detector.mask)
184+
)
185+
186+
@classmethod
187+
def from_detector(cls, detector):
188+
"""Alternative constructor
189+
190+
:param detector: ensure the mask is definied"""
191+
self = cls()
192+
if detector.mask is None:
193+
raise RuntimeError("`detector` must provide an actual mask")
194+
self.detector = detector
195+
self.build_labels()
196+
for l in range(1, self.nb_modules + 1): # noqa: E741
197+
self.modules[l] = SingleModule(detector, self.lmask, index=l, fixed=False)
198+
return self
199+
200+
@property
201+
def shape(self):
202+
return self.detector.shape
203+
204+
def calc_displacement_map(self):
205+
p1, p2, _ = self.detector.calc_cartesian_positions()
206+
p1 /= self.detector.pixel1
207+
p2 /= self.detector.pixel2
208+
209+
for l in range(1, self.nb_modules + 1): # noqa: E741
210+
m = self.modules[l]
211+
mp1, mp2 = m.calc_displacement_map()
212+
p1[m.mask] = mp1[m.mask]
213+
p2[m.mask] = mp2[m.mask]
214+
215+
return p1, p2
216+
217+
@property
218+
def free_modules(self):
219+
return sum(not m.fixed for m in self.modules.values())
220+
221+
222+
class MultiModuleRefinement(MultiModule):
223+
def __init__(self):
224+
super().__init__()
225+
self.modulated_points = {} # key: npt filename, value record array with coordinates, ring & module
226+
self.calibrants = {} # contains the different calibrant objects for each control-point file
227+
self._q_theo = {}
228+
self.ponis = {} # relative to control-point files #Unused ?
229+
230+
def calc_cp_positions(self, param=None, key=None, center=True):
231+
"""Calculate the physical position for control points of a given registered calibrant"""
232+
mcp = self.modulated_points[key]
233+
p1 = mcp.d0.copy()
234+
p2 = mcp.d1.copy()
235+
param_idx = 0
236+
center = 0.5 if center else 0
237+
for l in range(1, self.nb_modules + 1): # noqa: E741
238+
m = self.modules[l]
239+
mask = mcp.module == l
240+
valid = mcp[mask]
241+
sub_param = (
242+
None
243+
if param is None or m.fixed
244+
else ModuleParam(*param[3 * param_idx : 3 * (param_idx + 1)])
245+
)
246+
param_idx += 0 if m.fixed else 1
247+
mp1, mp2 = m.calc_position(
248+
d1=valid.d0 + center, d2=valid.d1 + center, param=sub_param
249+
)
250+
p1[mask] = mp1
251+
p2[mask] = mp2
252+
return p1, p2
253+
254+
def print_control_points_per_module(self, filename):
255+
if filename not in self.modulated_points:
256+
print(f"No control-point file named {filename}. Did you load it ?")
257+
else:
258+
print(filename, ":", self.calibrants.get(filename))
259+
modulated_cp = self.modulated_points[filename]
260+
for l in range(1, self.nb_modules + 1): # noqa: E741
261+
print(l, (modulated_cp.module == l).sum())
262+
263+
def load_control_points(self, filename, poni=None, verbose=False):
264+
"""
265+
:param filename: file with control points
266+
:param poni: file with the (uncorrected) detector position
267+
:param verbose: set to True to print out the number of control points per module
268+
"""
269+
cp = ControlPoints(filename)
270+
self.calibrants[filename] = cp.calibrant
271+
if poni:
272+
self.ponis[filename] = PoniFile(poni)
273+
# build modulated list of control points
274+
d0 = []
275+
d1 = []
276+
ring = []
277+
modules = []
278+
for i in cp.getList():
279+
d0.append(i[0])
280+
d1.append(i[1])
281+
ring.append(i[2])
282+
modules.append(0)
283+
modulated_cp = numpy.rec.fromarrays((d0, d1, ring, modules), dtype=module_d)
284+
linear = numpy.round(modulated_cp.d0).astype(numpy.int32) * self.shape[
285+
-1
286+
] + numpy.round(modulated_cp.d1).astype(numpy.int32)
287+
modulated_cp.module = self.lmask.ravel()[linear]
288+
self.modulated_points[filename] = modulated_cp
289+
if verbose:
290+
self.print_control_points_per_module(filename)
291+
292+
def init_q_theo(self, force=False):
293+
if force or not self._q_theo:
294+
self._q_theo = {
295+
key: 20.0
296+
* pi
297+
/ numpy.array(calibrant.dspacing)[self.modulated_points[key].ring]
298+
for key, calibrant in self.calibrants.items()
299+
}
300+
301+
def residu(self, param=None):
302+
"""Calculate the delta_q value between the expected ring position and the actual one"""
303+
if not self._q_theo:
304+
self.init_q_theo()
305+
module_param = param[
306+
: ModuleParam.nb_param * sum(not m.fixed for m in self.modules.values())
307+
]
308+
delta = []
309+
for idx, (key, calibrant) in enumerate(self.calibrants.items()):
310+
# print(key)
311+
tmp_e = self._q_theo[
312+
key
313+
] # This is the theoritical q_value for the given ring (in nm^-1)
314+
# print("exp", len(tmp_e), tmp_e)
315+
dp1, dp2 = self.calc_cp_positions(param=module_param, key=key)
316+
# print("dp", len(dp1), len(dp2))
317+
318+
start_idx = (
319+
ModuleParam.nb_param * self.free_modules + idx * PoniParam.nb_param
320+
)
321+
end_idx = start_idx + PoniParam.nb_param
322+
poni_param = PoniParam(*param[start_idx:end_idx])
323+
# print(poni_param)
324+
tmp_c = _geometry.calc_q(
325+
poni_param.dist,
326+
poni_param.rot1,
327+
poni_param.rot2,
328+
0.0,
329+
dp1 - poni_param.poni1,
330+
dp2 - poni_param.poni2,
331+
calibrant.wavelength,
332+
)
333+
# print("residu", tmp_e, tmp_c)
334+
delta.append(tmp_c - tmp_e)
335+
return numpy.concatenate(delta)
336+
337+
@property
338+
def nb_param(self):
339+
"""Number of parameters for the refinement"""
340+
free = sum(not m.fixed for m in self.modules.values())
341+
return free * ModuleParam.nb_param + PoniParam.nb_param * len(self.calibrants)
342+
343+
def init_param(self):
344+
"""Generate the numpy array with all parameters"""
345+
param = numpy.zeros(self.nb_param)
346+
idx = 0
347+
for m in self.modules.values():
348+
if m.fixed:
349+
continue
350+
for i, n in enumerate(ModuleParam.__dataclass_fields__, start=idx):
351+
param[i] = m.param.__getattribute__(n)
352+
idx += ModuleParam.nb_param
353+
for p in self.ponis.values():
354+
for i, n in enumerate(PoniParam.__dataclass_fields__, start=idx):
355+
param[i] = p.__getattribute__(n)
356+
idx += PoniParam.nb_param
357+
return param
358+
359+
def print_param(self, param):
360+
idx = 0
361+
for i, m in self.modules.items():
362+
if m.fixed:
363+
print(f"module #{i:2d}: Fixed")
364+
else:
365+
res = f"module #{i:2d}:"
366+
for i, n in enumerate(ModuleParam.__dataclass_fields__, start=idx):
367+
res += f" {n:5s}= {param[i]},"
368+
idx += ModuleParam.nb_param
369+
print(res)
370+
for p in self.ponis:
371+
res = f"{p}:"
372+
for i, n in enumerate(PoniParam.__dataclass_fields__, start=idx):
373+
res += f" {n:5s}= {param[i]:6f},"
374+
print(res)
375+
idx += PoniParam.nb_param
376+
377+
def cost(self, param):
378+
delta = self.residu(param)
379+
return numpy.dot(delta, delta)
380+
381+
def refine(self, param, method="SLSQP"):
382+
method = "Nelder-Mead" if method.lower() == "simplex" else method
383+
return optimize.minimize(self.cost, param, method=method)

0 commit comments

Comments
 (0)