|
| 1 | +# !/usr/bin/env python |
| 2 | +# -*- coding: utf-8 -*- |
| 3 | +# |
| 4 | +# Project: Azimuthal integration |
| 5 | +# https://github.com/silx-kit/pyFAI |
| 6 | +# |
| 7 | +# Copyright (C) 2025-2026 European Synchrotron Radiation Facility, Grenoble, France |
| 8 | +# |
| 9 | +# Principal author: Jérôme Kieffer (Jerome.Kieffer@ESRF.eu) |
| 10 | +# |
| 11 | +# Permission is hereby granted, free of charge, to any person obtaining a copy |
| 12 | +# of this software and associated documentation files (the "Software"), to deal |
| 13 | +# in the Software without restriction, including without limitation the rights |
| 14 | +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
| 15 | +# copies of the Software, and to permit persons to whom the Software is |
| 16 | +# furnished to do so, subject to the following conditions: |
| 17 | +# |
| 18 | +# The above copyright notice and this permission notice shall be included in |
| 19 | +# all copies or substantial portions of the Software. |
| 20 | +# |
| 21 | +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 22 | +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 23 | +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 24 | +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 25 | +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| 26 | +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN |
| 27 | +# THE SOFTWARE. |
| 28 | +# |
| 29 | + |
| 30 | +"""Multi-module detectors: |
| 31 | +
|
| 32 | +This module contains some helper function to define a detector from several modules |
| 33 | +and later-on refine this module position from powder diffraction data |
| 34 | +as demonstrated in https://doi.org/10.3390/cryst12020255 |
| 35 | +""" |
| 36 | + |
| 37 | +__author__ = "Jérôme Kieffer" |
| 38 | +__contact__ = "Jerome.Kieffer@ESRF.eu" |
| 39 | +__license__ = "MIT" |
| 40 | +__copyright__ = "European Synchrotron Radiation Facility, Grenoble, France" |
| 41 | +__date__ = "13/01/2026" |
| 42 | +__status__ = "development" |
| 43 | + |
| 44 | +from math import sin, cos, pi |
| 45 | +from dataclasses import dataclass |
| 46 | +import numpy |
| 47 | +from scipy import ndimage, optimize |
| 48 | +from ..control_points import ControlPoints |
| 49 | +from ..ext import _geometry |
| 50 | +from ..io.ponifile import PoniFile |
| 51 | +from ..third_party.classproperties import classproperty |
| 52 | + |
| 53 | + |
| 54 | +module_d = numpy.dtype( |
| 55 | + [ |
| 56 | + ("d0", numpy.float64), |
| 57 | + ("d1", numpy.float64), |
| 58 | + ("ring", numpy.int32), |
| 59 | + ("module", numpy.int32), |
| 60 | + ] |
| 61 | +) |
| 62 | + |
| 63 | + |
| 64 | +# Those are the optimizable parameters ... 2 translations and one rotation. |
| 65 | +@dataclass |
| 66 | +class ModuleParam: |
| 67 | + d0: float = 0.0 |
| 68 | + d1: float = 0.0 |
| 69 | + rot: float = 0.0 |
| 70 | + |
| 71 | + def set(self, iterable): |
| 72 | + self.d0, self.d1, self.rot = iterable[:3] |
| 73 | + |
| 74 | + def get(self): |
| 75 | + return (self.d0, self.d1, self.rot) |
| 76 | + |
| 77 | + @classproperty |
| 78 | + def nb_param(cls): |
| 79 | + return len(cls.__dataclass_fields__) |
| 80 | + |
| 81 | + |
| 82 | +@dataclass |
| 83 | +class PoniParam: |
| 84 | + dist: float = 0.0 |
| 85 | + poni1: float = 0.0 |
| 86 | + poni2: float = 0.0 |
| 87 | + rot1: float = 0.0 |
| 88 | + rot2: float = 0.0 |
| 89 | + # rot3:float=0.0 |
| 90 | + # wavelength:float=0.0 |
| 91 | + |
| 92 | + @classproperty |
| 93 | + def nb_param(self): |
| 94 | + return len(self.__dataclass_fields__) |
| 95 | + |
| 96 | + |
| 97 | +class SingleModule: |
| 98 | + def __init__(self, detector, mask, index=None, fixed=False): |
| 99 | + self.parent_detector = detector |
| 100 | + self.parent_index = index |
| 101 | + if (index is not None) and index <= mask.max(): |
| 102 | + self.mask = mask == index |
| 103 | + else: |
| 104 | + self.mask = mask |
| 105 | + self.fixed = False |
| 106 | + self.param = ModuleParam() |
| 107 | + self.center = None |
| 108 | + self.bounding_box = None |
| 109 | + self.calc_bounding_box() |
| 110 | + |
| 111 | + def __repr__(self): |
| 112 | + return ( |
| 113 | + f"Module centered at ({self.center[0, 0]:.1f}, {self.center[1, 0]:.1f})" |
| 114 | + + (", fixed." if self.fixed else ".") |
| 115 | + ) |
| 116 | + |
| 117 | + def calc_bounding_box(self): |
| 118 | + d0, d1 = numpy.where(self.mask) |
| 119 | + d0m = d0.min() |
| 120 | + d0M = d0.max() |
| 121 | + d1m = d1.min() |
| 122 | + d1M = d1.max() |
| 123 | + self.center = numpy.atleast_2d([0.5 * (d0M + d0m + 1), 0.5 * (d1M + d1m + 1)]).T |
| 124 | + self.bounding_box = (slice(d0m, d0M + 1), slice(d1m, d1M + 1)) |
| 125 | + return self.bounding_box |
| 126 | + |
| 127 | + def calc_displacement_map(self, d1=None, d2=None, param=None): |
| 128 | + if d1 is None and d2 is None: |
| 129 | + full_detector = True |
| 130 | + p1, p2, _ = self.parent_detector.calc_cartesian_positions() |
| 131 | + d1 = p1 / self.parent_detector.pixel1 |
| 132 | + d2 = p2 / self.parent_detector.pixel2 |
| 133 | + mp1 = d1[self.mask] |
| 134 | + mp2 = d2[self.mask] |
| 135 | + else: |
| 136 | + full_detector = False |
| 137 | + mp1 = d1 |
| 138 | + mp2 = d2 |
| 139 | + |
| 140 | + param = param or self.param |
| 141 | + |
| 142 | + mpc = numpy.vstack((mp1.ravel(), mp2.ravel())) |
| 143 | + if not self.fixed: |
| 144 | + self.center |
| 145 | + mpc -= self.center |
| 146 | + rot = param.rot |
| 147 | + c, s = cos(rot), sin(rot) |
| 148 | + rotm = numpy.array([[c, -s], [s, c]]) |
| 149 | + mpc = ( |
| 150 | + numpy.dot(rotm, mpc) |
| 151 | + + self.center |
| 152 | + + numpy.atleast_2d([param.d0, param.d1]).T |
| 153 | + ) |
| 154 | + if full_detector: |
| 155 | + mshape = mp1.shape |
| 156 | + p1[self.mask] = mpc[0].reshape(mshape) |
| 157 | + p2[self.mask] = mpc[1].reshape(mshape) |
| 158 | + else: |
| 159 | + p1, p2 = mpc |
| 160 | + return p1, p2 |
| 161 | + |
| 162 | + def calc_position(self, d1=None, d2=None, param=None): |
| 163 | + d1, d2 = self.calc_displacement_map(d1, d2, param) |
| 164 | + return d1 * self.parent_detector.pixel1, d2 * self.parent_detector.pixel2 |
| 165 | + |
| 166 | + |
| 167 | +class MultiModule: |
| 168 | + """Split a detector in several modules""" |
| 169 | + |
| 170 | + def __init__(self): |
| 171 | + self.modules = {} # this is contains all of modules |
| 172 | + self.lmask = None |
| 173 | + self.detector = None |
| 174 | + self.nb_modules = 0 |
| 175 | + |
| 176 | + def __repr__(self): |
| 177 | + return f"MultiModule with {self.nb_modules} modules:\n" + "\n".join( |
| 178 | + f" {i:2d}: {j}" for i, j in self.modules.items() |
| 179 | + ) |
| 180 | + |
| 181 | + def build_labels(self): |
| 182 | + self.lmask, self.nb_modules = ndimage.label( |
| 183 | + numpy.logical_not(self.detector.mask) |
| 184 | + ) |
| 185 | + |
| 186 | + @classmethod |
| 187 | + def from_detector(cls, detector): |
| 188 | + """Alternative constructor |
| 189 | +
|
| 190 | + :param detector: ensure the mask is definied""" |
| 191 | + self = cls() |
| 192 | + if detector.mask is None: |
| 193 | + raise RuntimeError("`detector` must provide an actual mask") |
| 194 | + self.detector = detector |
| 195 | + self.build_labels() |
| 196 | + for l in range(1, self.nb_modules + 1): # noqa: E741 |
| 197 | + self.modules[l] = SingleModule(detector, self.lmask, index=l, fixed=False) |
| 198 | + return self |
| 199 | + |
| 200 | + @property |
| 201 | + def shape(self): |
| 202 | + return self.detector.shape |
| 203 | + |
| 204 | + def calc_displacement_map(self): |
| 205 | + p1, p2, _ = self.detector.calc_cartesian_positions() |
| 206 | + p1 /= self.detector.pixel1 |
| 207 | + p2 /= self.detector.pixel2 |
| 208 | + |
| 209 | + for l in range(1, self.nb_modules + 1): # noqa: E741 |
| 210 | + m = self.modules[l] |
| 211 | + mp1, mp2 = m.calc_displacement_map() |
| 212 | + p1[m.mask] = mp1[m.mask] |
| 213 | + p2[m.mask] = mp2[m.mask] |
| 214 | + |
| 215 | + return p1, p2 |
| 216 | + |
| 217 | + @property |
| 218 | + def free_modules(self): |
| 219 | + return sum(not m.fixed for m in self.modules.values()) |
| 220 | + |
| 221 | + |
| 222 | +class MultiModuleRefinement(MultiModule): |
| 223 | + def __init__(self): |
| 224 | + super().__init__() |
| 225 | + self.modulated_points = {} # key: npt filename, value record array with coordinates, ring & module |
| 226 | + self.calibrants = {} # contains the different calibrant objects for each control-point file |
| 227 | + self._q_theo = {} |
| 228 | + self.ponis = {} # relative to control-point files #Unused ? |
| 229 | + |
| 230 | + def calc_cp_positions(self, param=None, key=None, center=True): |
| 231 | + """Calculate the physical position for control points of a given registered calibrant""" |
| 232 | + mcp = self.modulated_points[key] |
| 233 | + p1 = mcp.d0.copy() |
| 234 | + p2 = mcp.d1.copy() |
| 235 | + param_idx = 0 |
| 236 | + center = 0.5 if center else 0 |
| 237 | + for l in range(1, self.nb_modules + 1): # noqa: E741 |
| 238 | + m = self.modules[l] |
| 239 | + mask = mcp.module == l |
| 240 | + valid = mcp[mask] |
| 241 | + sub_param = ( |
| 242 | + None |
| 243 | + if param is None or m.fixed |
| 244 | + else ModuleParam(*param[3 * param_idx : 3 * (param_idx + 1)]) |
| 245 | + ) |
| 246 | + param_idx += 0 if m.fixed else 1 |
| 247 | + mp1, mp2 = m.calc_position( |
| 248 | + d1=valid.d0 + center, d2=valid.d1 + center, param=sub_param |
| 249 | + ) |
| 250 | + p1[mask] = mp1 |
| 251 | + p2[mask] = mp2 |
| 252 | + return p1, p2 |
| 253 | + |
| 254 | + def print_control_points_per_module(self, filename): |
| 255 | + if filename not in self.modulated_points: |
| 256 | + print(f"No control-point file named {filename}. Did you load it ?") |
| 257 | + else: |
| 258 | + print(filename, ":", self.calibrants.get(filename)) |
| 259 | + modulated_cp = self.modulated_points[filename] |
| 260 | + for l in range(1, self.nb_modules + 1): # noqa: E741 |
| 261 | + print(l, (modulated_cp.module == l).sum()) |
| 262 | + |
| 263 | + def load_control_points(self, filename, poni=None, verbose=False): |
| 264 | + """ |
| 265 | + :param filename: file with control points |
| 266 | + :param poni: file with the (uncorrected) detector position |
| 267 | + :param verbose: set to True to print out the number of control points per module |
| 268 | + """ |
| 269 | + cp = ControlPoints(filename) |
| 270 | + self.calibrants[filename] = cp.calibrant |
| 271 | + if poni: |
| 272 | + self.ponis[filename] = PoniFile(poni) |
| 273 | + # build modulated list of control points |
| 274 | + d0 = [] |
| 275 | + d1 = [] |
| 276 | + ring = [] |
| 277 | + modules = [] |
| 278 | + for i in cp.getList(): |
| 279 | + d0.append(i[0]) |
| 280 | + d1.append(i[1]) |
| 281 | + ring.append(i[2]) |
| 282 | + modules.append(0) |
| 283 | + modulated_cp = numpy.rec.fromarrays((d0, d1, ring, modules), dtype=module_d) |
| 284 | + linear = numpy.round(modulated_cp.d0).astype(numpy.int32) * self.shape[ |
| 285 | + -1 |
| 286 | + ] + numpy.round(modulated_cp.d1).astype(numpy.int32) |
| 287 | + modulated_cp.module = self.lmask.ravel()[linear] |
| 288 | + self.modulated_points[filename] = modulated_cp |
| 289 | + if verbose: |
| 290 | + self.print_control_points_per_module(filename) |
| 291 | + |
| 292 | + def init_q_theo(self, force=False): |
| 293 | + if force or not self._q_theo: |
| 294 | + self._q_theo = { |
| 295 | + key: 20.0 |
| 296 | + * pi |
| 297 | + / numpy.array(calibrant.dspacing)[self.modulated_points[key].ring] |
| 298 | + for key, calibrant in self.calibrants.items() |
| 299 | + } |
| 300 | + |
| 301 | + def residu(self, param=None): |
| 302 | + """Calculate the delta_q value between the expected ring position and the actual one""" |
| 303 | + if not self._q_theo: |
| 304 | + self.init_q_theo() |
| 305 | + module_param = param[ |
| 306 | + : ModuleParam.nb_param * sum(not m.fixed for m in self.modules.values()) |
| 307 | + ] |
| 308 | + delta = [] |
| 309 | + for idx, (key, calibrant) in enumerate(self.calibrants.items()): |
| 310 | + # print(key) |
| 311 | + tmp_e = self._q_theo[ |
| 312 | + key |
| 313 | + ] # This is the theoritical q_value for the given ring (in nm^-1) |
| 314 | + # print("exp", len(tmp_e), tmp_e) |
| 315 | + dp1, dp2 = self.calc_cp_positions(param=module_param, key=key) |
| 316 | + # print("dp", len(dp1), len(dp2)) |
| 317 | + |
| 318 | + start_idx = ( |
| 319 | + ModuleParam.nb_param * self.free_modules + idx * PoniParam.nb_param |
| 320 | + ) |
| 321 | + end_idx = start_idx + PoniParam.nb_param |
| 322 | + poni_param = PoniParam(*param[start_idx:end_idx]) |
| 323 | + # print(poni_param) |
| 324 | + tmp_c = _geometry.calc_q( |
| 325 | + poni_param.dist, |
| 326 | + poni_param.rot1, |
| 327 | + poni_param.rot2, |
| 328 | + 0.0, |
| 329 | + dp1 - poni_param.poni1, |
| 330 | + dp2 - poni_param.poni2, |
| 331 | + calibrant.wavelength, |
| 332 | + ) |
| 333 | + # print("residu", tmp_e, tmp_c) |
| 334 | + delta.append(tmp_c - tmp_e) |
| 335 | + return numpy.concatenate(delta) |
| 336 | + |
| 337 | + @property |
| 338 | + def nb_param(self): |
| 339 | + """Number of parameters for the refinement""" |
| 340 | + free = sum(not m.fixed for m in self.modules.values()) |
| 341 | + return free * ModuleParam.nb_param + PoniParam.nb_param * len(self.calibrants) |
| 342 | + |
| 343 | + def init_param(self): |
| 344 | + """Generate the numpy array with all parameters""" |
| 345 | + param = numpy.zeros(self.nb_param) |
| 346 | + idx = 0 |
| 347 | + for m in self.modules.values(): |
| 348 | + if m.fixed: |
| 349 | + continue |
| 350 | + for i, n in enumerate(ModuleParam.__dataclass_fields__, start=idx): |
| 351 | + param[i] = m.param.__getattribute__(n) |
| 352 | + idx += ModuleParam.nb_param |
| 353 | + for p in self.ponis.values(): |
| 354 | + for i, n in enumerate(PoniParam.__dataclass_fields__, start=idx): |
| 355 | + param[i] = p.__getattribute__(n) |
| 356 | + idx += PoniParam.nb_param |
| 357 | + return param |
| 358 | + |
| 359 | + def print_param(self, param): |
| 360 | + idx = 0 |
| 361 | + for i, m in self.modules.items(): |
| 362 | + if m.fixed: |
| 363 | + print(f"module #{i:2d}: Fixed") |
| 364 | + else: |
| 365 | + res = f"module #{i:2d}:" |
| 366 | + for i, n in enumerate(ModuleParam.__dataclass_fields__, start=idx): |
| 367 | + res += f" {n:5s}= {param[i]}," |
| 368 | + idx += ModuleParam.nb_param |
| 369 | + print(res) |
| 370 | + for p in self.ponis: |
| 371 | + res = f"{p}:" |
| 372 | + for i, n in enumerate(PoniParam.__dataclass_fields__, start=idx): |
| 373 | + res += f" {n:5s}= {param[i]:6f}," |
| 374 | + print(res) |
| 375 | + idx += PoniParam.nb_param |
| 376 | + |
| 377 | + def cost(self, param): |
| 378 | + delta = self.residu(param) |
| 379 | + return numpy.dot(delta, delta) |
| 380 | + |
| 381 | + def refine(self, param, method="SLSQP"): |
| 382 | + method = "Nelder-Mead" if method.lower() == "simplex" else method |
| 383 | + return optimize.minimize(self.cost, param, method=method) |
0 commit comments