Skip to content

Commit 52ee6b8

Browse files
authored
Merge pull request #229 from isuruf/index
Add indexing matrix with an array
2 parents 72a54f3 + 691952a commit 52ee6b8

File tree

2 files changed

+105
-26
lines changed

2 files changed

+105
-26
lines changed

symengine/lib/symengine_wrapper.pyx

Lines changed: 65 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3154,7 +3154,28 @@ cdef class DenseMatrixBase(MatrixBase):
31543154
return [self.get(i // self.ncols(), i % self.ncols()) for i in range(*item.indices(len(self)))]
31553155
elif isinstance(item, int):
31563156
return self.get(item // self.ncols(), item % self.ncols())
3157-
elif isinstance(item, tuple):
3157+
elif isinstance(item, tuple) and len(item) == 2:
3158+
if is_sequence(item[0]) or is_sequence(item[1]):
3159+
if isinstance(item[0], slice):
3160+
row_iter = range(*item[0].indices(self.rows))
3161+
elif is_sequence(item[0]):
3162+
row_iter = item[0]
3163+
else:
3164+
row_iter = [item[0]]
3165+
3166+
if isinstance(item[1], slice):
3167+
col_iter = range(*item[1].indices(self.rows))
3168+
elif is_sequence(item[1]):
3169+
col_iter = item[1]
3170+
else:
3171+
col_iter = [item[1]]
3172+
3173+
v = []
3174+
for row in row_iter:
3175+
for col in col_iter:
3176+
v.append(self.get(row, col))
3177+
return self.__class__(len(row_iter), len(col_iter), v)
3178+
31583179
if isinstance(item[0], int) and isinstance(item[1], int):
31593180
return self.get(item[0], item[1])
31603181
else:
@@ -3169,6 +3190,8 @@ cdef class DenseMatrixBase(MatrixBase):
31693190
if (s[1] < 0 or s[1] > self.cols or s[1] >= s[3] or s[3] < 0 or s[3] > self.cols):
31703191
raise IndexError
31713192
return self._submatrix(*s)
3193+
elif is_sequence(item):
3194+
return [self.get(ind // self.ncols(), ind % self.ncols()) for ind in item]
31723195
else:
31733196
raise NotImplementedError
31743197

@@ -3181,32 +3204,48 @@ cdef class DenseMatrixBase(MatrixBase):
31813204
for i in range(*key.indices(len(self))):
31823205
self.set(i // self.ncols(), i % self.ncols(), value[k])
31833206
k = k + 1
3184-
elif isinstance(key, tuple):
3185-
if isinstance(key[0], int):
3186-
if isinstance(key[1], int):
3187-
self.set(key[0], key[1], value)
3188-
else:
3189-
k = 0
3190-
for i in range(*key[1].indices(self.cols)):
3191-
self.set(key[0], i, value[k])
3192-
k = k + 1
3207+
elif isinstance(key, tuple) and len(key) == 2:
3208+
if isinstance(key[0], slice):
3209+
row_iter = range(*key[0].indices(self.rows))
3210+
elif is_sequence(key[0]):
3211+
row_iter = key[0]
31933212
else:
3194-
if isinstance(key[1], int):
3195-
k = 0
3196-
for i in range(*key[0].indices(self.rows)):
3197-
self.set(i, key[1], value[k])
3198-
k = k + 1
3199-
else:
3200-
k = 0
3201-
for i in range(*key[0].indices(self.rows)):
3202-
l = 0
3203-
for j in range(*key[1].indices(self.cols)):
3204-
try:
3205-
self.set(i, j, value[k, l])
3206-
except TypeError:
3207-
self.set(i, j, value[k][l])
3208-
l = l + 1
3209-
k = k + 1
3213+
row_iter = [key[0]]
3214+
3215+
if isinstance(key[1], slice):
3216+
col_iter = range(*key[1].indices(self.rows))
3217+
elif is_sequence(key[1]):
3218+
col_iter = key[1]
3219+
else:
3220+
col_iter = [key[1]]
3221+
3222+
for r, row in enumerate(row_iter):
3223+
for c, col in enumerate(col_iter):
3224+
if not is_sequence(value):
3225+
self.set(row, col, value)
3226+
continue
3227+
try:
3228+
self.set(row, col, value[r, c])
3229+
continue
3230+
except TypeError:
3231+
pass
3232+
try:
3233+
self.set(row, col, value[r][c])
3234+
continue
3235+
except TypeError:
3236+
pass
3237+
3238+
if len(row_iter) == 1:
3239+
self.set(row, col, value[c])
3240+
continue
3241+
3242+
if len(col_iter) == 1:
3243+
self.set(row, col, value[r])
3244+
continue
3245+
3246+
elif is_sequence(key) and is_sequence(value):
3247+
for val, ind in zip(value, key):
3248+
self.set(ind // self.ncols(), ind % self.ncols(), val)
32103249
else:
32113250
raise NotImplementedError
32123251

symengine/tests/test_matrices.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ def test_get_item():
6262
assert A[1:, :3] == DenseMatrix(2, 3, [4, 5, 6, 7, 8, 9])
6363
assert A[1:] == [2, 3, 4, 5, 6, 7, 8, 9]
6464
assert A[-2:] == [8, 9]
65+
assert A[[0, 2], 0] == DenseMatrix(2, 1, [1, 7])
66+
assert A[[0, 2], [0]] == DenseMatrix(2, 1, [1, 7])
67+
assert A[0, [0, 2]] == DenseMatrix(1, 2, [1, 3])
68+
assert A[[0], [0, 2]] == DenseMatrix(1, 2, [1, 3])
6569

6670
raises(IndexError, lambda: A[-10])
6771
raises(IndexError, lambda: A[9])
@@ -89,6 +93,42 @@ def test_set_item():
8993
A[0::2, :] = [[1, 2, 3], [4, 5, 6]]
9094
assert A == DenseMatrix(3, 3, [1, 2, 3, 4, 14, 1, 4, 5, 6])
9195

96+
B = DenseMatrix(A)
97+
B[[0, 2], 0] = -1
98+
assert B == DenseMatrix(3, 3, [-1, 2, 3, 4, 14, 1, -1, 5, 6])
99+
100+
B = DenseMatrix(A)
101+
B[[0, 2], 0] = [-1, -2]
102+
assert B == DenseMatrix(3, 3, [-1, 2, 3, 4, 14, 1, -2, 5, 6])
103+
104+
B = DenseMatrix(A)
105+
B[[0, 2], 0] = [[-1], [-2]]
106+
assert B == DenseMatrix(3, 3, [-1, 2, 3, 4, 14, 1, -2, 5, 6])
107+
108+
B = DenseMatrix(A)
109+
B[[0, 2], [0]] = [-1, -2]
110+
assert B == DenseMatrix(3, 3, [-1, 2, 3, 4, 14, 1, -2, 5, 6])
111+
112+
B = DenseMatrix(A)
113+
B[[0, 2], [0]] = [[-1], [-2]]
114+
assert B == DenseMatrix(3, 3, [-1, 2, 3, 4, 14, 1, -2, 5, 6])
115+
116+
B = DenseMatrix(A)
117+
B[0, [0, 2]] = [-1, -2]
118+
assert B == DenseMatrix(3, 3, [-1, 2, -2, 4, 14, 1, 4, 5, 6])
119+
120+
B = DenseMatrix(A)
121+
B[0, [0, 2]] = -1
122+
assert B == DenseMatrix(3, 3, [-1, 2, -1, 4, 14, 1, 4, 5, 6])
123+
124+
B = DenseMatrix(A)
125+
B[:, [0, 2]] = -1
126+
assert B == DenseMatrix(3, 3, [-1, 2, -1, -1, 14, -1, -1, 5, -1])
127+
128+
B = DenseMatrix(A)
129+
B[[0, 1], [0, 2]] = -1
130+
assert B == DenseMatrix(3, 3, [-1, 2, -1, -1, 14, -1, 4, 5, 6])
131+
92132

93133
def test_set():
94134
i7 = Integer(7)

0 commit comments

Comments
 (0)