1010from functools import wraps
1111from importlib .util import find_spec
1212from shutil import which
13- from typing import TYPE_CHECKING
13+ from typing import TYPE_CHECKING , TypeVar
1414from unittest import SkipTest
1515
1616import matplotlib
1717import numpy as np
18+ import numpy .typing as npt
1819from more_itertools import always_iterable
1920from numpy .random import RandomState
2021from unyt .exceptions import UnitOperationError
@@ -91,8 +92,8 @@ def assert_rel_equal(a1, a2, decimals, err_msg="", verbose=True):
9192
9293# tested: volume integral is 1.
9394def cubicspline_python (
94- x : float | np .ndarray ,
95- ) -> np .ndarray :
95+ x : float | npt . NDArray [ np .floating ] ,
96+ ) -> npt . NDArray [ np .floating ] :
9697 """
9798 cubic spline SPH kernel function for testing against more
9899 effiecient cython methods
@@ -118,8 +119,12 @@ def cubicspline_python(
118119
119120
120121def integrate_kernel (
121- kernelfunc : Callable [[float ], float ], b : float , hsml : float
122- ) -> float :
122+ kernelfunc : Callable [
123+ [float | npt .NDArray [np .floating ]], float | npt .NDArray [np .floating ]
124+ ],
125+ b : float | npt .NDArray [np .floating ],
126+ hsml : float | npt .NDArray [np .floating ],
127+ ) -> npt .NDArray [np .floating ]:
123128 """
124129 integrates a kernel function over a line passing entirely
125130 through it
@@ -147,18 +152,21 @@ def integrate_kernel(
147152 dx = np .diff (xe , axis = 0 )
148153 spv = kernelfunc (np .sqrt (xc ** 2 + x ** 2 ))
149154 integral = np .sum (spv * dx , axis = 0 )
150- return pre * integral
155+ return np . atleast_1d ( pre * integral )
151156
152157
153158_zeroperiods = np .array ([0.0 , 0.0 , 0.0 ])
154159
155160
161+ _FloatingT = TypeVar ("_FloatingT" , bound = np .floating )
162+
163+
156164def distancematrix (
157- pos3_i0 : np . ndarray ,
158- pos3_i1 : np . ndarray ,
165+ pos3_i0 : npt . NDArray [ _FloatingT ] ,
166+ pos3_i1 : npt . NDArray [ _FloatingT ] ,
159167 periodic : tuple [bool , bool , bool ] = (True ,) * 3 ,
160- periods : np . ndarray = _zeroperiods ,
161- ) -> np . ndarray :
168+ periods : npt . NDArray [ _FloatingT ] = _zeroperiods ,
169+ ) -> npt . NDArray [ _FloatingT ] :
162170 """
163171 Calculates the distances between two arrays of points.
164172
0 commit comments