Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 36 additions & 18 deletions mkl_umath/src/_patch_numpy.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,16 @@ from libc.stdlib cimport free, malloc

cnp.import_umath()

# Inline C helper exposing the raw `types` array of a ufunc object.
#
# The C function casts `u` to `PyUFuncObject *` without any type check and
# returns its `types` member as a plain `char *` (the cast also drops any
# `const` qualifier the header may declare). Callers must therefore pass a
# genuine NumPy ufunc. The returned pointer is borrowed from the ufunc and
# must not be freed.
# NOTE(review): presumably this exists so the code does not depend on how
# the installed numpy .pxd declares `ufunc.types` -- confirm.
cdef extern from *:
    """
    #include "numpy/ufuncobject.h"
    static inline char* _get_ufunc_types(PyObject *u) {
        return (char *)((PyUFuncObject *)u)->types;
    }
    """
    # Declared `noexcept`: the helper only reads a struct field and never
    # sets a Python exception.
    char* _get_ufunc_types(object u) noexcept


ctypedef struct function_info:
cnp.PyUFuncGenericFunction original_function
cnp.PyUFuncGenericFunction patch_function
Expand All @@ -53,66 +63,74 @@ cdef class _patch_impl:
functions_dict = dict()

def __cinit__(self):
    """Build the table pairing each mkl_umath loop with NumPy's loop.

    For every ufunc exposed by ``mu`` and every type signature that ufunc
    provides, the NumPy ufunc of the same name is searched for a loop
    with an identical type signature. Each match is stored in
    ``self.functions`` (original loop, patch loop, and a copy of the type
    signature), and its index is recorded in ``self.functions_dict``
    under the key ``(ufunc_name, type_string)``.

    Raises
    ------
    MemoryError
        If a C allocation fails.
    RuntimeError
        If NumPy has no loop matching one of the mkl_umath signatures.
    """
    cdef int pi, oi, i, nargs, slot
    cdef int expected_count
    cdef char* patch_types
    cdef char* orig_types
    cdef int* signature

    # Put the object in a state __dealloc__ can handle before anything
    # can fail: __dealloc__ still runs when __cinit__ raises.
    self.functions = NULL
    self.functions_count = 0

    umaths = [x for x in dir(mu) if isinstance(getattr(mu, x), np.ufunc)]

    # Upper bound on the number of table entries: one per patch loop.
    expected_count = 0
    for umath in umaths:
        mkl_umath_func = getattr(mu, umath)
        expected_count += mkl_umath_func.ntypes

    self.functions = <function_info *> malloc(
        expected_count * sizeof(function_info)
    )
    # malloc(0) is allowed to return NULL, so only a non-empty request
    # that comes back NULL is an allocation failure.
    if self.functions is NULL and expected_count > 0:
        raise MemoryError()

    for umath in umaths:
        patch_umath = getattr(mu, umath)
        c_patch_umath = <cnp.ufunc>patch_umath
        c_orig_umath = <cnp.ufunc>getattr(np, umath)
        nargs = c_patch_umath.nargs
        patch_types = _get_ufunc_types(c_patch_umath)
        orig_types = _get_ufunc_types(c_orig_umath)
        for pi in range(c_patch_umath.ntypes):
            # Linear search for the NumPy loop whose nargs type codes all
            # equal those of patch loop `pi`.
            oi = 0
            while oi < c_orig_umath.ntypes:
                found = True
                for i in range(nargs):
                    if (
                        patch_types[pi * nargs + i]
                        != orig_types[oi * nargs + i]
                    ):
                        found = False
                        break
                if found:
                    break
                oi = oi + 1

            if oi >= c_orig_umath.ntypes:
                raise RuntimeError(
                    f"Unable to find original function for: {umath} "
                    f"{patch_umath.types[pi]}"
                )

            # Allocate and fill the signature before touching the table so
            # a failed allocation leaves no half-initialised entry behind.
            signature = <int *> malloc(nargs * sizeof(int))
            if signature is NULL and nargs > 0:
                raise MemoryError()
            for i in range(nargs):
                signature[i] = patch_types[pi * nargs + i]

            slot = self.functions_count
            self.functions[slot].original_function = (
                c_orig_umath.functions[oi]
            )
            self.functions[slot].patch_function = (
                c_patch_umath.functions[pi]
            )
            self.functions[slot].signature = signature
            # Count the entry as soon as it owns `signature`, so that
            # __dealloc__ frees it even if the dict update below raises.
            self.functions_count += 1
            self.functions_dict[(umath, patch_umath.types[pi])] = slot

def __dealloc__(self):
    """Release the loop table and every per-entry signature array.

    Only the first ``self.functions_count`` entries are visited; the
    visible ``__cinit__`` advances that counter once per populated entry.
    The NULL guard covers an object whose table was never allocated.
    """
    cdef Py_ssize_t i

    if self.functions is not NULL:
        for i in range(self.functions_count):
            free(self.functions[i].signature)
        free(self.functions)

cdef int _replace_loop(
self,
Expand Down
Loading