Skip to content

Commit c02c052

Browse files
authored
Merge pull request #2701 from silx-kit/2602_copy_geometry_refinement
Fix the copy of a `GeometryRefinement` object
2 parents 3c37f90 + af1468e commit c02c052

File tree

2 files changed

+83
-44
lines changed

2 files changed

+83
-44
lines changed

src/pyFAI/geometryRefinement.py

Lines changed: 46 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
__contact__ = "Jerome.Kieffer@ESRF.eu"
3232
__license__ = "MIT"
3333
__copyright__ = "European Synchrotron Radiation Facility, Grenoble, France"
34-
__date__ = "31/10/2025"
34+
__date__ = "27/11/2025"
3535
__status__ = "development"
3636

3737
import os
@@ -42,11 +42,12 @@
4242
import numpy
4343
import math
4444
from math import pi
45+
from scipy.optimize import fmin, leastsq, fmin_slsqp
4546
from .integrator.azimuthal import AzimuthalIntegrator
4647
from .calibrant import Calibrant, CALIBRANT_FACTORY
4748
from .utils.ellipse import fit_ellipse
48-
from .utils.decorators import deprecated
49-
from scipy.optimize import fmin, leastsq, fmin_slsqp
49+
from .utils.decorators import deprecated, deprecated_args
50+
5051

5152
logger = logging.getLogger(__name__)
5253

@@ -71,8 +72,26 @@
7172

7273

7374
class GeometryRefinement(AzimuthalIntegrator):
75+
_IMMUTABLE_ATTRS = AzimuthalIntegrator._IMMUTABLE_ATTRS + (
76+
"_dist_min",
77+
"_dist_max",
78+
"_poni1_min",
79+
"_poni1_max",
80+
"_poni2_min",
81+
"_poni2_max",
82+
"_rot1_min",
83+
"_rot1_max",
84+
"_rot2_min",
85+
"_rot2_max",
86+
"_rot3_min",
87+
"_rot3_max",
88+
"_wavelength_min",
89+
"_wavelength_max",
90+
)
91+
7492
PARAM_ORDER = ("dist", "poni1", "poni2", "rot1", "rot2", "rot3", "wavelength")
7593

94+
@deprecated_args({"splinefile":"splineFile"}, since_version="2025.10")
7695
def __init__(
7796
self,
7897
data=None,
@@ -85,7 +104,7 @@ def __init__(
85104
rot3=0,
86105
pixel1=None,
87106
pixel2=None,
88-
splineFile=None,
107+
splinefile=None,
89108
detector=None,
90109
wavelength=None,
91110
**kwargs,
@@ -105,7 +124,7 @@ def __init__(
105124
:param rot3: guessed tilt of the detector around the incoming beam axis (optional, in rad)
106125
:param pixel1: Pixel size along the vertical direction of the detector (in m), almost mandatory
107126
:param pixel2: Pixel size along the horizontal direction of the detector (in m), almost mandatory
108-
:param splineFile: file describing the detector as 2 cubic splines. Replaces pixel1 & pixel2
127+
:param splinefile: file describing the detector as 2 cubic splines. Replaces pixel1 & pixel2
109128
:param detector: name of the detector or Detector instance. Replaces splineFile, pixel1 & pixel2
110129
:param wavelength: wavelength in m (1.54e-10)
111130
@@ -129,7 +148,7 @@ def __init__(
129148
if (
130149
(pixel1 is None)
131150
and (pixel2 is None)
132-
and (splineFile is None)
151+
and (splinefile is None)
133152
and (detector is None)
134153
):
135154
raise RuntimeError(
@@ -144,7 +163,7 @@ def __init__(
144163
rot3,
145164
pixel1,
146165
pixel2,
147-
splineFile,
166+
splinefile,
148167
detector,
149168
wavelength=wavelength,
150169
**kwargs,
@@ -189,9 +208,23 @@ def __init__(
189208
self._wavelength_min = 1e-15
190209
self._wavelength_max = 100.0e-10
191210

211+
def __copy__(self):
212+
""":return: a shallow copy of itself."""
213+
new = self.__class__(data=self.data,
214+
detector=self.detector,
215+
calibrant = self.calibrant)
216+
for key in self._IMMUTABLE_ATTRS:
217+
new.__setattr__(key, self.__getattribute__(key))
218+
new.param = [new.__getattribute__(key) for key in self.PARAM_ORDER]
219+
new._cached_array = self._cached_array.copy()
220+
return new
221+
192222
def __deepcopy__(self, memo=None):
193-
if memo is None:
194-
memo = {}
223+
"""deep copy helper function
224+
225+
:param memo: dict with modified objects
226+
:return: a deep copy of itself."""
227+
195228
data = copy.deepcopy(self.data, memo=memo)
196229
dist = copy.deepcopy(self._dist, memo=memo)
197230
poni1 = copy.deepcopy(self._poni1, memo=memo)
@@ -201,7 +234,7 @@ def __deepcopy__(self, memo=None):
201234
rot3 = copy.deepcopy(self._rot3, memo=memo)
202235
pixel1 = copy.deepcopy(self.detector.pixel1, memo=memo)
203236
pixel2 = copy.deepcopy(self.detector.pixel2, memo=memo)
204-
splineFile = copy.deepcopy(self.detector.splineFile, memo=memo)
237+
splinefile = copy.deepcopy(self.detector.splinefile, memo=memo)
205238
detector = copy.deepcopy(self.detector, memo=memo)
206239
wavelength = copy.deepcopy(self.wavelength, memo=memo)
207240
calibrant = copy.deepcopy(self.calibrant, memo=memo)
@@ -216,45 +249,17 @@ def __deepcopy__(self, memo=None):
216249
rot3=rot3,
217250
pixel1=pixel1,
218251
pixel2=pixel2,
219-
splineFile=splineFile,
252+
splinefile=splinefile,
220253
detector=detector,
221254
wavelength=wavelength,
222255
calibrant=calibrant,
223256
)
224-
numerical = [
225-
"_dist",
226-
"_poni1",
227-
"_poni2",
228-
"_rot1",
229-
"_rot2",
230-
"_rot3",
231-
"chiDiscAtPi",
232-
"_dssa_order",
233-
"_wavelength",
234-
"_oversampling",
235-
"_correct_solid_angle_for_spline",
236-
"_transmission_normal",
237-
"_dist_min",
238-
"_dist_max",
239-
"_poni1_min",
240-
"_poni1_max",
241-
"_poni2_min",
242-
"_poni2_max",
243-
"_rot1_min",
244-
"_rot1_max",
245-
"_rot2_min",
246-
"_rot2_max",
247-
"_rot3_min",
248-
"_rot3_max",
249-
"_wavelength_min",
250-
"_wavelength_max",
251-
]
252257
memo[id(self)] = new
253-
for key in numerical:
258+
for key in self._IMMUTABLE_ATTRS:
254259
old_value = self.__getattribute__(key)
255260
memo[id(old_value)] = old_value
256261
new.__setattr__(key, old_value)
257-
new_param = [new._dist, new._poni1, new._poni2, new._rot1, new._rot2, new._rot3]
262+
new_param = [new.__getattribute__(key) for key in self.PARAM_ORDER]
258263
memo[id(self.param)] = new_param
259264
new.param = new_param
260265
cached = {}

src/pyFAI/test/test_geometry_refinement.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,14 @@
3232
__contact__ = "Jerome.Kieffer@ESRF.eu"
3333
__license__ = "MIT"
3434
__copyright__ = "European Synchrotron Radiation Facility, Grenoble, France"
35-
__date__ = "10/10/2025"
35+
__date__ = "27/11/2025"
3636

3737
import unittest
3838
import os
3939
import numpy
4040
import random
4141
import logging
42+
import copy
4243
from .utilstest import UtilsTest
4344
from .. import geometryRefinement
4445
from .. import calibrant
@@ -839,9 +840,11 @@ def test_synthetic(self):
839840
mycalibrant.wavelength = 1e-10
840841
r2 = GeometryRefinement(data, calibrant=mycalibrant, detector="Fairchild",
841842
wavelength=mycalibrant.wavelength)
843+
r3 = copy.copy(r2)
844+
r4 = copy.deepcopy(r2)
842845
# print(r2)
843846
r2.guess_poni()
844-
# print(r2)
847+
#print(r2)
845848
r2.refine2(10000000, fix=[])
846849
ref = {"dist": (0.1, 1e-5), # value, tolerance
847850
"poni1": (0.05, 1e-5),
@@ -850,11 +853,42 @@ def test_synthetic(self):
850853
"poni2": (0.06, 1e-5),
851854
"rot1": (0.07, 1e-4),
852855
"wavelength": (1e-10, 1e-10)}
853-
# print(r2)
856+
print(r2)
854857
for key in ref.keys():
855858
self.assertAlmostEqual(ref[key][0], r2.__getattribute__(key), delta=ref[key][1],
856859
msg="%s is %s, I expected %s%s%s" % (key, r2.__getattribute__(key), ref[key], os.linesep, r2))
857860

861+
# test the copy
862+
self.assertEqual(r3.calibrant, r2.calibrant)
863+
self.assertTrue(numpy.all(r3.data==r2.data))
864+
r3.guess_poni()
865+
r3.refine2(10000000, fix=[])
866+
for k in r2._IMMUTABLE_ATTRS:
867+
self.assertEqual(r3.__getattribute__(k), r2.__getattribute__(k), k)
868+
for key in ref.keys():
869+
self.assertAlmostEqual(r3.__getattribute__(key), r2.__getattribute__(key), delta=ref[key][1],
870+
msg="%s is %s, I expected %s%s%s" % (key, r3.__getattribute__(key), ref[key], os.linesep, r3))
871+
872+
# test the deep-copy
873+
self.assertEqual(r4.calibrant, r2.calibrant)
874+
self.assertTrue(numpy.all(r4.data == r2.data))
875+
r4.guess_poni()
876+
r4.refine2(10000000, fix=[])
877+
for k in r2._IMMUTABLE_ATTRS:
878+
self.assertEqual(r4.__getattribute__(k), r2.__getattribute__(k), k)
879+
880+
881+
for key in ref.keys():
882+
self.assertAlmostEqual(r4.__getattribute__(key), r2.__getattribute__(key), delta=ref[key][1],
883+
msg="%s is %s, I expected %s%s%s" % (key, r4.__getattribute__(key), ref[key], os.linesep, r4))
884+
885+
# Mutation check, done last:
886+
r4.data[...] = 4
887+
self.assertFalse(numpy.all(r4.data == r2.data))
888+
# works also because data are copied in constructor ...
889+
r2.data[...] = 2
890+
self.assertFalse(numpy.all(r3.data == r2.data))
891+
858892

859893
def suite():
860894
testsuite = unittest.TestSuite()

0 commit comments

Comments
 (0)