Skip to content

Commit 79208a7

Browse files
committed
add to_ndarray method to QuantLib Array
1 parent bcc2754 commit 79208a7

File tree

3 files changed

+18
-6
lines changed

3 files changed

+18
-6
lines changed

quantlib/math/_array.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@ cdef extern from 'ql/math/array.hpp' namespace 'QuantLib':
1717
Real& at(Size i) except +IndexError
1818
Size size()
1919
Real& operator[](Size)
20+
const Real* begin()

quantlib/math/array.pyx

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,19 @@
77
FOR A PARTICULAR PURPOSE. See the license for more details.
88
"""
99

10-
include '../types.pxi'
10+
from quantlib.types cimport Real, Size
11+
from cpython.ref cimport Py_INCREF
1112
from libcpp.utility cimport move
13+
cimport numpy as np
14+
15+
np.import_array()
1216

1317
cdef class Array:
1418
"""
1519
1D array for linear algebra
1620
"""
1721

18-
def __init__(self, Size size, value=None):
22+
def __init__(self, Size size=0, value=None):
1923
if value is None:
2024
self._thisptr = move[_arr.Array](_arr.Array(size))
2125
else:
@@ -31,9 +35,16 @@ cdef class Array:
3135
raise IndexError("index {} is larger than the size of the array {}".
3236
format(key, self._thisptr.size()))
3337

34-
property size:
35-
def __get__(self):
36-
return self._thisptr.size()
38+
def __len__(self):
39+
return self._thisptr.size()
40+
41+
def to_ndarray(self):
42+
cdef np.npy_intp[1] dims
43+
dims[0] = self._thisptr.size()
44+
cdef arr = np.PyArray_SimpleNewFromData(1, &dims[0], np.NPY_DOUBLE, <void*>(self._thisptr.begin()))
45+
Py_INCREF(self)
46+
np.PyArray_SetBaseObject(arr, self)
47+
return arr
3748

3849
cpdef qlarray_from_pyarray(p):
3950
cdef Array x = Array(len(p))

test/test_math.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def test_array_1(self):
1010
x = Array(10, v)
1111

1212
self.assertEqual(x[4], v)
13-
self.assertEqual(x.size, 10)
13+
self.assertEqual(len(x), 10)
1414

1515
def test_array_2(self):
1616
v = 3.14

0 commit comments

Comments
 (0)