Skip to content

Commit d8ed320

Browse files
committed
simplified changes and speeded it up
1 parent d815517 commit d8ed320

File tree

3 files changed

+31
-28
lines changed

3 files changed

+31
-28
lines changed

src/sisl/_indices.pyx

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -145,17 +145,18 @@ def indices(ints_st[::1] element, ints_st[::1] test_element, ints_st offset=0,
145145
@cython.boundscheck(False)
146146
@cython.wraparound(False)
147147
@cython.initializedcheck(False)
148-
def indices_in_cylinder(floats_st[:, ::1] dxyz, R, const floats_st h):
149-
""" Indices for all coordinates that are within a cylinde radius `R` and height `h`
148+
def indices_in_cylinder(floats_st[:, ::1] dxyz, const floats_st R, const floats_st h):
149+
""" Indices for all coordinates that are within a cylinder radius `R` and height `h`
150150
151151
Parameters
152152
----------
153153
dxyz :
154-
coordinates centered around the cylinder
154+
coordinates centered around the cylinder.
155+
The last axis is the cylinder height.
155156
R :
156-
radius of cylinder to check
157+
radius of cylinder to check.
157158
h :
158-
height of cylinder to check
159+
height of cylinder to check.
159160
160161
Returns
161162
-------
@@ -168,34 +169,22 @@ def indices_in_cylinder(floats_st[:, ::1] dxyz, R, const floats_st h):
168169
cdef ndarray[int32_t] IDX = np.empty([n], dtype=np.int32)
169170
cdef int[::1] idx = IDX
170171

171-
cdef floats_st Rx, Ry
172-
cdef floats_st L2
172+
cdef floats_st R2
173+
cdef floats_st hhalve, L2
173174
cdef Py_ssize_t i, j, m
174175

175176
# Handle radius input
176-
if isinstance(R, (float, int)):
177-
Rx = Ry = R
178-
else:
179-
R = np.asarray(R, dtype=np.float64)
180-
if R.shape == (2,):
181-
Rx, Ry = R[0], R[1]
182-
elif R.shape == (2,3):
183-
from sisl._core._dtypes import fnorm
184-
Rx = fnorm(R[0])
185-
Ry = fnorm(R[1])
186-
else:
187-
raise ValueError(f"Unsupported radius shape: {R.shape}")
188-
189-
cdef floats_st Rx2 = Rx * Rx
190-
cdef floats_st Ry2 = Ry * Ry
177+
R2 = R * R
178+
hhalve = h / 2
191179

192180
# Reset number of elements
193181
m = 0
194182

195183
with nogil:
196184
for i in range(n):
197-
if dxyz[i,nxyz] > 0.5*h or dxyz[i,nxyz] < -0.5*h: continue
198-
L2 = (dxyz[i, 0]*dxyz[i, 0])/Rx2 + (dxyz[i, 1]*dxyz[i, 1])/Ry2
185+
if dxyz[i,nxyz] > hhalve or dxyz[i,nxyz] < -hhalve: continue
186+
# Calculate the distance of the circle
187+
L2 = (dxyz[i, 0]*dxyz[i, 0])/R2 + (dxyz[i, 1]*dxyz[i, 1])/R2
199188
if L2 > 1.0: continue
200189
idx[m] = <int> i
201190
m += 1

src/sisl/shape/_cylinder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def within_index(self, other, rtol: float = 1.0e-8):
191191
# Get indices where we should do the more
192192
# expensive exact check of being inside shape
193193
# I.e. this reduces the search space to the box
194-
return indices_in_cylinder(tmp, 1.0 + rtol, 1.0 + rtol)
194+
return indices_in_cylinder(tmp, 1.0 + rtol)
195195

196196
@deprecation(
197197
"toSphere is deprecated, use shape.to.Sphere(...) instead.", "0.15", "0.17"

src/sisl/shape/tests/test_cylinder.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,24 @@ def test_create_ellipticalcylinder():
3939

4040
def test_ellipticalcylinder_within():
4141
el = EllipticalCylinder(1.0, 1.0)
42-
# center of cylinder
43-
assert el.within_index([0, 0, 0])[0] == 0
42+
# points in an ellipsis
43+
points = [
44+
[0, 0, 0],
45+
[0, 0, 0.5],
46+
[0, 0, -0.5],
47+
[1, 0, -0.5],
48+
[0, 1, -0.5],
49+
]
50+
assert len(el.within_index(points)) == len(points)
51+
4452
# should not be in a circle
45-
assert el.within_index([0.2, 0.2, 0.9])[0] == 0
53+
points = [
54+
[0, 0, 0.6],
55+
[0, 0, -0.6],
56+
[0.2, 0.2, 0.9],
57+
[0.2, 0.2, -0.9],
58+
]
59+
assert len(el.within_index(points)) == 0
4660

4761

4862
def test_tosphere():

0 commit comments

Comments
 (0)