Skip to content

Commit 6837ecb

Browse files
committed
[vector] Supported list/tuple indexing in vector slicing
1 parent 7acf581 commit 6837ecb

File tree

2 files changed

+59
-4
lines changed

2 files changed

+59
-4
lines changed

include/eigenpy/std-vector.hpp

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,9 @@ struct overload_base_get_item_for_std_vector
8686
template <class Class>
8787
void visit(Class &cl) const {
8888
cl.def("__getitem__", &base_get_item_int)
89-
.def("__getitem__", &base_get_item_slice);
89+
.def("__getitem__", &base_get_item_slice)
90+
.def("__getitem__", &base_get_item_list)
91+
.def("__getitem__", &base_get_item_tuple);
9092
}
9193

9294
private:
@@ -109,9 +111,6 @@ struct overload_base_get_item_for_std_vector
109111
static boost::python::object base_get_item_slice(
110112
boost::python::back_reference<Container&> container,
111113
boost::python::slice slice) {
112-
113-
namespace bp = boost::python;
114-
115114
bp::list out;
116115
try {
117116
auto rng = slice.get_indices(container.get().begin(), container.get().end());
@@ -131,6 +130,54 @@ struct overload_base_get_item_for_std_vector
131130
return out;
132131
}
133132

133+
static bp::object base_get_item_list(bp::back_reference<Container&> c, bp::list idxs) {
134+
const Py_ssize_t m = bp::len(idxs);
135+
bp::list out;
136+
for (Py_ssize_t k = 0; k < m; ++k) {
137+
bp::object obj = idxs[k];
138+
bp::extract<long> ei(obj);
139+
if (!ei.check()) {
140+
PyErr_SetString(PyExc_TypeError, "indices must be integers");
141+
bp::throw_error_already_set();
142+
}
143+
auto idx = normalize_index(c.get().size(), ei());
144+
out.append(elem_ref(c.get(), idx));
145+
}
146+
return out;
147+
}
148+
149+
static bp::object base_get_item_tuple(bp::back_reference<Container&> c, bp::tuple idxs) {
150+
const Py_ssize_t m = bp::len(idxs);
151+
bp::list out;
152+
for (Py_ssize_t k = 0; k < m; ++k) {
153+
bp::object obj = idxs[k];
154+
bp::extract<long> ei(obj);
155+
if (!ei.check()) {
156+
PyErr_SetString(PyExc_TypeError, "indices must be integers");
157+
bp::throw_error_already_set();
158+
}
159+
auto idx = normalize_index(c.get().size(), ei());
160+
out.append(elem_ref(c.get(), idx));
161+
}
162+
return out;
163+
}
164+
165+
static index_type normalize_index(std::size_t n, long i) {
166+
long idx = i;
167+
if (idx < 0) idx += static_cast<long>(n);
168+
if (idx < 0 || idx >= static_cast<long>(n)) {
169+
PyErr_SetString(PyExc_IndexError, "index out of range");
170+
bp::throw_error_already_set();
171+
}
172+
return static_cast<index_type>(idx);
173+
}
174+
175+
static bp::object elem_ref(Container& c, index_type i) {
176+
typename bp::to_python_indirect<value_type&,
177+
bp::detail::make_reference_holder> conv;
178+
return bp::object(bp::handle<>(conv(c[i])));
179+
}
180+
134181
static index_type convert_index(Container &container, PyObject *i_) {
135182
bp::extract<long> i(i_);
136183
if (i.check()) {

unittest/python/test_std_vector.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,11 @@ def checkZero(v):
107107
checkAllValues(l6[1:], l6.tolist()[1:])
108108
checkAllValues(l6[:-1], l6.tolist()[:-1])
109109
checkAllValues(l6[::2], l6.tolist()[::2])
110+
L = [0, 2]
111+
L6_copy = l6[L]
112+
for k, i in enumerate(L):
113+
checkAllValues(L6_copy[k], l6[i])
114+
T = (0, 2)
115+
L6_copy = l6[T]
116+
for k, i in enumerate(L):
117+
checkAllValues(L6_copy[k], l6[i])

0 commit comments

Comments
 (0)