Skip to content

Commit 00b4b3a

Browse files
authored
Merge pull request #948 from tanocalogero/main
Fixed bugs in EllipticalCylinder and indices_in_cylinder.
2 parents 4d05f7f + 362314e commit 00b4b3a

File tree

4 files changed

+37
-23
lines changed

4 files changed

+37
-23
lines changed

changes/947.fix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed finding coordinates within elliptical cylinders

src/sisl/_indices.pyx

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -146,16 +146,17 @@ def indices(ints_st[::1] element, ints_st[::1] test_element, ints_st offset=0,
146146
@cython.wraparound(False)
147147
@cython.initializedcheck(False)
148148
def indices_in_cylinder(floats_st[:, ::1] dxyz, const floats_st R, const floats_st h):
149-
""" Indices for all coordinates that are within a cylinde radius `R` and height `h`
149+
""" Indices for all coordinates that are within a cylinder radius `R` and height `h`
150150
151151
Parameters
152152
----------
153153
dxyz :
154-
coordinates centered around the cylinder
154+
coordinates centered around the cylinder.
155+
The last axis is the cylinder height.
155156
R :
156-
radius of cylinder to check
157+
radius of cylinder to check.
157158
h :
158-
height of cylinder to check
159+
height of cylinder to check.
159160
160161
Returns
161162
-------
@@ -168,25 +169,23 @@ def indices_in_cylinder(floats_st[:, ::1] dxyz, const floats_st R, const floats_
168169
cdef ndarray[int32_t] IDX = np.empty([n], dtype=np.int32)
169170
cdef int[::1] idx = IDX
170171

171-
cdef floats_st R2 = R * R
172-
cdef floats_st L2
172+
cdef floats_st R2
173+
cdef floats_st hhalve, L2
173174
cdef Py_ssize_t i, j, m
174-
cdef bint skip
175+
176+
# Handle radius input
177+
R2 = R * R
178+
hhalve = h / 2
175179

176180
# Reset number of elements
177181
m = 0
178182

179183
with nogil:
180184
for i in range(n):
181-
skip = 0
182-
for j in range(nxyz):
183-
skip |= dxyz[i, j] > R
184-
if skip or dxyz[i, nxyz] > h: continue
185-
186-
L2 = 0.
187-
for j in range(nxyz):
188-
L2 += dxyz[i, j] * dxyz[i, j]
189-
if L2 > R2: continue
185+
if dxyz[i,nxyz] > hhalve or dxyz[i,nxyz] < -hhalve: continue
186+
# Calculate the distance of the circle
187+
L2 = (dxyz[i, 0]*dxyz[i, 0])/R2 + (dxyz[i, 1]*dxyz[i, 1])/R2
188+
if L2 > 1.0: continue
190189
idx[m] = <int> i
191190
m += 1
192191

src/sisl/shape/_cylinder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def __init__(
9797
self._h = h
9898

9999
def copy(self) -> Self:
100-
return self.__class__(self.radial_vector, self.height, self.center)
100+
return self.__class__(self.radial_vector, self.height, center=self.center)
101101

102102
@property
103103
def volume(self) -> float:
@@ -140,7 +140,7 @@ def scale(self, scale: SeqOrScalarFloat) -> Self:
140140
else:
141141
v = self._v * scale
142142
h = self._h * scale
143-
return self.__class__(v, h, self.center)
143+
return self.__class__(v, h, center=self.center)
144144

145145
def expand(self, radius: SeqOrScalarFloat) -> Self:
146146
"""Expand elliptical cylinder by a constant value along each vector and height
@@ -163,7 +163,7 @@ def expand(self, radius: SeqOrScalarFloat) -> Self:
163163
raise ValueError(
164164
f"{self.__class__.__name__}.expand requires the radius to be either (1,) or (3,)"
165165
)
166-
return self.__class__([v0, v1], h, self.center)
166+
return self.__class__([v0, v1], h, center=self.center)
167167

168168
@deprecate_argument(
169169
"tol",
@@ -191,7 +191,7 @@ def within_index(self, other, rtol: float = 1.0e-8):
191191
# Get indices where we should do the more
192192
# expensive exact check of being inside shape
193193
# I.e. this reduces the search space to the box
194-
return indices_in_cylinder(tmp, 1.0 + rtol, 1.0 + rtol)
194+
return indices_in_cylinder(tmp, 1.0 + rtol)
195195

196196
@deprecation(
197197
"toSphere is deprecated, use shape.to.Sphere(...) instead.", "0.15", "0.17"

src/sisl/shape/tests/test_cylinder.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,24 @@ def test_create_ellipticalcylinder():
3939

4040
def test_ellipticalcylinder_within():
4141
el = EllipticalCylinder(1.0, 1.0)
42-
# center of cylinder
43-
assert el.within_index([0, 0, 0])[0] == 0
42+
# points in an ellipsis
43+
points = [
44+
[0, 0, 0],
45+
[0, 0, 0.5],
46+
[0, 0, -0.5],
47+
[1, 0, -0.5],
48+
[0, 1, -0.5],
49+
]
50+
assert len(el.within_index(points)) == len(points)
51+
4452
# should not be in a circle
45-
assert el.within_index([0.2, 0.2, 0.9])[0] == 0
53+
points = [
54+
[0, 0, 0.6],
55+
[0, 0, -0.6],
56+
[0.2, 0.2, 0.9],
57+
[0.2, 0.2, -0.9],
58+
]
59+
assert len(el.within_index(points)) == 0
4660

4761

4862
def test_tosphere():

0 commit comments

Comments
 (0)