@@ -145,17 +145,18 @@ def indices(ints_st[::1] element, ints_st[::1] test_element, ints_st offset=0,
145145@ cython.boundscheck (False )
146146@ cython.wraparound (False )
147147@ cython.initializedcheck (False )
148- def indices_in_cylinder (floats_st[:, ::1] dxyz , R , const floats_st h ):
149- """ Indices for all coordinates that are within a cylinde radius `R` and height `h`
148+ def indices_in_cylinder (floats_st[:, ::1] dxyz , const floats_st R , const floats_st h ):
149+ """ Indices for all coordinates that are within a cylinder radius `R` and height `h`
150150
151151 Parameters
152152 ----------
153153 dxyz :
154- coordinates centered around the cylinder
154+ coordinates centered around the cylinder.
155+ The last axis is the cylinder height.
155156 R :
156- radius of cylinder to check
157+ radius of cylinder to check.
157158 h :
158- height of cylinder to check
159+ height of cylinder to check.
159160
160161 Returns
161162 -------
@@ -168,34 +169,22 @@ def indices_in_cylinder(floats_st[:, ::1] dxyz, R, const floats_st h):
168169 cdef ndarray[int32_t] IDX = np.empty([n], dtype = np.int32)
169170 cdef int [::1 ] idx = IDX
170171
171- cdef floats_st Rx, Ry
172- cdef floats_st L2
172+ cdef floats_st R2
173+ cdef floats_st hhalve, L2
173174 cdef Py_ssize_t i, j, m
174175
175176 # Handle radius input
176- if isinstance (R, (float , int )):
177- Rx = Ry = R
178- else :
179- R = np.asarray(R, dtype = np.float64)
180- if R.shape == (2 ,):
181- Rx, Ry = R[0 ], R[1 ]
182- elif R.shape == (2 ,3 ):
183- from sisl._core._dtypes import fnorm
184- Rx = fnorm(R[0 ])
185- Ry = fnorm(R[1 ])
186- else :
187- raise ValueError (f" Unsupported radius shape: {R.shape}" )
188-
189- cdef floats_st Rx2 = Rx * Rx
190- cdef floats_st Ry2 = Ry * Ry
177+ R2 = R * R
178+ hhalve = h / 2
191179
192180 # Reset number of elements
193181 m = 0
194182
195183 with nogil:
196184 for i in range (n):
197- if dxyz[i,nxyz] > 0.5 * h or dxyz[i,nxyz] < - 0.5 * h: continue
198- L2 = (dxyz[i, 0 ]* dxyz[i, 0 ])/ Rx2 + (dxyz[i, 1 ]* dxyz[i, 1 ])/ Ry2
185+ if dxyz[i,nxyz] > hhalve or dxyz[i,nxyz] < - hhalve: continue
186+ # Calculate the distance of the circle
187+ L2 = (dxyz[i, 0 ]* dxyz[i, 0 ])/ R2 + (dxyz[i, 1 ]* dxyz[i, 1 ])/ R2
199188 if L2 > 1.0 : continue
200189 idx[m] = < int > i
201190 m += 1
0 commit comments