Skip to content

Commit f40fe2a

Browse files
committed
Faster conversion from numpy array to matrix mod 2
1 parent c9dd1e8 commit f40fe2a

File tree

3 files changed

+52
-3
lines changed

3 files changed

+52
-3
lines changed

src/sage/matrix/matrix_mod2_dense.pyx

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ from cysignals.memory cimport check_malloc, sig_free
108108
from cysignals.signals cimport sig_on, sig_str, sig_off
109109

110110
cimport sage.matrix.matrix_dense as matrix_dense
111-
from sage.matrix.args cimport SparseEntry, MatrixArgs_init
111+
from sage.matrix.args cimport SparseEntry, MatrixArgs_init, MA_ENTRIES_NDARRAY
112112
from libc.stdio cimport *
113113
from sage.structure.element cimport (Matrix, Vector)
114114
from sage.modules.free_module_element cimport FreeModuleElement
@@ -257,8 +257,25 @@ cdef class Matrix_mod2_dense(matrix_dense.Matrix_dense): # dense or sparse
257257
[]
258258
sage: Matrix(GF(2),0,2)
259259
[]
260+
261+
Make sure construction from numpy array is reasonably fast::
262+
263+
sage: # needs numpy
264+
sage: import numpy as np
265+
sage: n = 5000
266+
sage: M = matrix(GF(2), np.random.randint(0, 2, (n, n))) # around 700ms
267+
268+
Unsupported numpy data types (slower but still works)::
269+
270+
sage: # needs numpy
271+
sage: n = 100
272+
sage: M = matrix(GF(2), np.random.randint(0, 2, (n, n)).astype(np.float32))
260273
"""
261274
ma = MatrixArgs_init(parent, entries)
275+
if ma.get_type() == MA_ENTRIES_NDARRAY:
276+
from ..modules.numpy_util import set_matrix_mod2_from_numpy
277+
if set_matrix_mod2_from_numpy(self, ma.entries):
278+
return
262279
for t in ma.iter(coerce, True):
263280
se = <SparseEntry>t
264281
mzd_write_bit(self._entries, se.i, se.j, se.entry)

src/sage/modules/numpy_util.pxd

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from libc.stdint cimport uintptr_t
22
from sage.libs.m4ri cimport *
3+
from sage.matrix.matrix_mod2_dense cimport Matrix_mod2_dense
4+
5+
cpdef int set_matrix_mod2_from_numpy(Matrix_mod2_dense a, b) except -1
36

47
cpdef int set_mzd_from_numpy(uintptr_t entries_addr, Py_ssize_t degree, x) except -1
58
# Note: we don't actually need ``cimport`` to work, which means this header file is not used in practice

src/sage/modules/numpy_util.pyx

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
# sage.doctest: optional - numpy
2+
# cython: fast_getattr=False
3+
# https://github.com/cython/cython/issues/6442
24
r"""
35
Utility functions for numpy.
46
"""
57

68
cimport numpy as np
79
import numpy as np
8-
from sage.libs.m4ri cimport *
9-
from libc.stdint cimport uintptr_t
1010

1111

1212
ctypedef fused numpy_integral:
@@ -64,3 +64,32 @@ cpdef int set_mzd_from_numpy(uintptr_t entries_addr, Py_ssize_t degree, x) excep
6464
mzd_write_bit(entries, 0, i, x_bool[i])
6565
return True
6666
return False
67+
68+
69+
cpdef int _set_matrix_mod2_from_numpy_helper(Matrix_mod2_dense a, np.ndarray[numpy_integral, ndim=2] b) except -1:
70+
"""
71+
Internal function, helper for :func:`set_matrix_mod2_from_numpy`.
72+
"""
73+
if not (a.nrows() == b.shape[0] and a.ncols() == b.shape[1]):
74+
raise ValueError("shape mismatch")
75+
for i in range(b.shape[0]):
76+
for j in range(b.shape[1]):
77+
a.set_unsafe_int(i, j, b[i, j] & 1)
78+
return True
79+
80+
81+
cpdef int set_matrix_mod2_from_numpy(Matrix_mod2_dense a, b) except -1:
82+
"""
83+
Try to set the entries of a matrix from a numpy array.
84+
85+
INPUT:
86+
87+
- ``a`` -- the destination matrix
88+
- ``b`` -- a numpy array, must have dimension 2 and the same shape as ``a``
89+
90+
OUTPUT: ``True`` if successful, ``False`` otherwise. May throw ``ValueError``.
91+
"""
92+
try:
93+
return (<object>_set_matrix_mod2_from_numpy_helper)(a, b) # https://github.com/cython/cython/issues/6588
94+
except TypeError:
95+
return False

0 commit comments

Comments
 (0)