Skip to content

Commit 0d2a7da

Browse files
authored
Merge pull request #67 from jcapriot/prism_cython
Prism cython
2 parents 5408065 + 3cf25a7 commit 0d2a7da

File tree

8 files changed

+396
-581
lines changed

8 files changed

+396
-581
lines changed

.github/workflows/test_with_conda.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ jobs:
3939
matplotlib
4040
jupyter
4141
utm
42+
numba
4243
pytest
4344
pytest-cov
4445
sphinx
@@ -87,6 +88,7 @@ jobs:
8788
matplotlib
8889
jupyter
8990
utm
91+
numba
9092
pytest
9193
pytest-cov
9294
sphinx

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,11 @@ docs/api/generated/*
9898

9999
# Jupyter
100100
*.ipynb
101+
102+
#Cython generated files
101103
geoana/kernels/_extensions/rTE.cpp
104+
geoana/kernels/_extensions/potential_field_prism.c
105+
geoana/kernels/_extensions/potential_field_prism_api.h
102106

103107
# setuptools_scm
104108
geoana/version.py
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
try:
2+
# register numba jitable versions of the prism functions
3+
# if numba is available (and this module is installed).
4+
from numba.extending import (
5+
overload,
6+
get_cython_function_address
7+
)
8+
from numba import types
9+
import ctypes
10+
11+
from .potential_field_prism import (
12+
prism_f,
13+
prism_fz,
14+
prism_fzz,
15+
prism_fzx,
16+
prism_fzy,
17+
prism_fzzz,
18+
prism_fxxy,
19+
prism_fxxz,
20+
prism_fxyz,
21+
)
22+
funcs = [
23+
prism_f,
24+
prism_fz,
25+
prism_fzz,
26+
prism_fzx,
27+
prism_fzy,
28+
prism_fzzz,
29+
prism_fxxy,
30+
prism_fxxz,
31+
prism_fxyz,
32+
]
33+
34+
def _numba_register_prism_func(prism_func):
35+
module = 'geoana.kernels._extensions.potential_field_prism'
36+
name = prism_func.__name__
37+
38+
func_address = get_cython_function_address(module, name)
39+
func_type = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double, ctypes.c_double, ctypes.c_double)
40+
c_func = func_type(func_address)
41+
42+
@overload(prism_func)
43+
def numba_func(x, y, z):
44+
if isinstance(x, types.Float):
45+
if isinstance(y, types.Float):
46+
if isinstance(z, types.Float):
47+
def f(x, y, z):
48+
return c_func(x, y, z)
49+
return f
50+
for func in funcs:
51+
_numba_register_prism_func(func)
52+
53+
except ImportError as err:
54+
pass

0 commit comments

Comments
 (0)