Skip to content

Commit 3cc141f

Browse files
committed
reverse iterate
1 parent 719e3fa commit 3cc141f

File tree

1 file changed

+55
-54
lines changed

1 file changed

+55
-54
lines changed

src/lib.rs

Lines changed: 55 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,12 @@ use pyo3::ffi::{self, Py_ssize_t, PyDictObject, PyObject, PyTypeObject};
66
use pyo3::types::{PyList, PyModule, PyType};
77
use pyo3::{PyTypeInfo, prelude::*};
88

9-
pub struct DictValuesIter {
9+
pub struct ReverseDictValuesIter {
1010
entries: *const pydict::PyDictUnicodeEntry,
1111
current: usize,
12-
end: usize,
1312
}
1413

15-
impl DictValuesIter {
14+
impl ReverseDictValuesIter {
1615
/// Creates a new iterator over dictionary values
1716
///
1817
/// # Safety
@@ -30,21 +29,20 @@ impl DictValuesIter {
3029

3130
Self {
3231
entries,
33-
current: 0,
34-
end: n,
32+
current: n,
3533
}
3634
}
3735
}
3836
}
3937

40-
impl Iterator for DictValuesIter {
38+
impl Iterator for ReverseDictValuesIter {
4139
type Item = *mut PyObject;
4240

4341
fn next(&mut self) -> Option<Self::Item> {
4442
// Skip null entries until we find a valid one or reach the end
45-
while self.current < self.end {
46-
let entry = &unsafe { *self.entries.add(self.current) };
47-
self.current += 1;
43+
while self.current > 0 {
44+
self.current -= 1;
45+
let entry: &pydict::PyDictUnicodeEntry = &unsafe { *self.entries.add(self.current) };
4846

4947
if !entry.me_value.is_null() {
5048
return Some(entry.me_value);
@@ -71,27 +69,23 @@ fn get_instance_dict_fast(obj: *mut PyObject) -> Option<*mut PyObject> {
7169
}
7270
}
7371

74-
unsafe fn is_subtype(
75-
subtype: *mut pyo3::ffi::PyTypeObject,
76-
base: *mut pyo3::ffi::PyTypeObject,
72+
fn isinstance_of_ast(
73+
obj: *mut PyObject,
74+
base_ast_and_expr_type: (*mut PyTypeObject, *mut PyTypeObject),
7775
) -> bool {
78-
// Walk up the inheritance chain via tp_base, max 3 jumps
79-
let mut current = subtype;
80-
for _ in 0..3 {
81-
current = unsafe { (*current).tp_base };
82-
if current.is_null() {
83-
return false;
84-
}
85-
if current == base {
86-
return true;
87-
}
76+
let subtype = unsafe { ffi::Py_TYPE(obj) };
77+
let first_supertype = unsafe { (*subtype).tp_base };
78+
if first_supertype.is_null() {
79+
return false;
80+
}
81+
let (base_ast_type, base_expr_type) = base_ast_and_expr_type;
82+
if first_supertype == base_ast_type {
83+
return true;
8884
}
8985

90-
false
91-
}
92-
93-
fn isinstance_of_ast(obj: *mut PyObject, base_ast_type: &*mut PyTypeObject) -> bool {
94-
unsafe { is_subtype(ffi::Py_TYPE(obj), *base_ast_type) }
86+
let second_supertype = unsafe { (*first_supertype).tp_base };
87+
88+
second_supertype == base_ast_type || second_supertype == base_expr_type
9589
}
9690

9791
fn is_list(obj: *mut PyObject, py_list_type: *mut PyTypeObject) -> bool {
@@ -106,27 +100,32 @@ fn get_item_of_list(obj: *mut PyObject, index: Py_ssize_t) -> *mut PyObject {
106100
unsafe { ffi::PyList_GET_ITEM(obj, index) }
107101
}
108102

109-
fn walk_node(
103+
fn walk_node_iterative(
110104
node: *mut PyObject,
111-
base_ast_type: &*mut PyTypeObject,
105+
base_ast_and_expr_type: (*mut PyTypeObject, *mut PyTypeObject),
112106
py_list_type: *mut PyTypeObject,
113107
result_list: &mut Vec<*mut PyObject>,
114108
) -> PyResult<()> {
115-
result_list.push(node);
116-
117-
// Recursively walk through child nodes
118-
let Some(dict) = get_instance_dict_fast(node) else {
119-
return Ok(());
120-
};
121-
for item_ptr in unsafe { DictValuesIter::new(dict.cast::<PyDictObject>()) } {
122-
if isinstance_of_ast(item_ptr, base_ast_type) {
123-
walk_node(item_ptr, base_ast_type, py_list_type, result_list)?;
124-
} else if is_list(item_ptr, py_list_type) {
125-
let length = get_length_of_list(item_ptr);
126-
for i in 0..length {
127-
let item_ptr = get_item_of_list(item_ptr, i);
128-
if isinstance_of_ast(item_ptr, base_ast_type) {
129-
walk_node(item_ptr, base_ast_type, py_list_type, result_list)?;
109+
let mut stack = vec![node];
110+
111+
while let Some(current_node) = stack.pop() {
112+
result_list.push(current_node);
113+
114+
// Walk through child nodes
115+
let Some(dict) = get_instance_dict_fast(current_node) else {
116+
continue;
117+
};
118+
119+
for item_ptr in unsafe { ReverseDictValuesIter::new(dict.cast::<PyDictObject>()) } {
120+
if isinstance_of_ast(item_ptr, base_ast_and_expr_type) {
121+
stack.push(item_ptr);
122+
} else if is_list(item_ptr, py_list_type) {
123+
let length = get_length_of_list(item_ptr);
124+
for i in (0..length).rev() {
125+
let item_ptr = get_item_of_list(item_ptr, i);
126+
if isinstance_of_ast(item_ptr, base_ast_and_expr_type) {
127+
stack.push(item_ptr);
128+
}
130129
}
131130
}
132131
}
@@ -136,36 +135,38 @@ fn walk_node(
136135
}
137136

138137
thread_local! {
139-
static BASE_AST_TYPE: RefCell<Option<*mut PyTypeObject>> = const { RefCell::new(None) };
138+
static BASE_AST_TYPE_AND_EXPR: RefCell<Option<(*mut PyTypeObject, *mut PyTypeObject)>> = const { RefCell::new(None) };
140139
}
141140

142141
#[inline(never)]
143-
fn get_base_ast_type<'py>(py: Python<'py>) -> PyResult<*mut PyTypeObject> {
142+
fn get_base_ast_type<'py>(py: Python<'py>) -> PyResult<(*mut PyTypeObject, *mut PyTypeObject)> {
144143
let ast_module = py.import("ast")?;
145144
let ast_class = ast_module.getattr("AST")?.cast_into::<PyType>()?;
145+
let expr_class = ast_module.getattr("expr")?.cast_into::<PyType>()?;
146146

147-
Ok(ast_class.as_type_ptr())
147+
Ok((ast_class.as_type_ptr(), expr_class.as_type_ptr()))
148148
}
149149

150150
#[pyfunction]
151151
fn walk<'py>(py: Python, node: Bound<'py, PyAny>) -> PyResult<Py<PyList>> {
152152
let mut result_list = Vec::new();
153153

154154
// Initialize if needed (separate step with mutable borrow)
155-
BASE_AST_TYPE.with(|cache| {
155+
BASE_AST_TYPE_AND_EXPR.with(|cache| {
156156
if cache.borrow().is_none() {
157157
*cache.borrow_mut() = Some(get_base_ast_type(py)?);
158158
}
159159
Ok::<(), PyErr>(())
160160
})?;
161161

162162
// Now use immutable borrow for the actual work
163-
BASE_AST_TYPE.with(|cache| {
163+
BASE_AST_TYPE_AND_EXPR.with(|cache| {
164164
let cache_ref = cache.borrow();
165-
let base_ast_type = cache_ref.as_ref().unwrap();
166-
walk_node(
165+
let base_ast_and_expr_type = cache_ref.as_ref().unwrap();
166+
167+
walk_node_iterative(
167168
node.as_ptr(),
168-
base_ast_type,
169+
*base_ast_and_expr_type,
169170
PyList::type_object_raw(py),
170171
&mut result_list,
171172
)?;
@@ -200,7 +201,7 @@ mod tests {
200201
Python::attach(|py| {
201202
let dict = PyDict::new(py);
202203
let dict_ptr = dict.as_ptr() as *mut pyo3::ffi::PyDictObject;
203-
let values = unsafe { DictValuesIter::new(dict_ptr) }.collect::<Vec<_>>();
204+
let values = unsafe { ReverseDictValuesIter::new(dict_ptr) }.collect::<Vec<_>>();
204205
assert_eq!(values.len(), 0);
205206
});
206207
}
@@ -216,7 +217,7 @@ mod tests {
216217
dict.set_item("c", 3).unwrap();
217218

218219
let dict_ptr = dict.as_ptr() as *mut pyo3::ffi::PyDictObject;
219-
let values = unsafe { DictValuesIter::new(dict_ptr) }.collect::<Vec<_>>();
220+
let values = unsafe { ReverseDictValuesIter::new(dict_ptr) }.collect::<Vec<_>>();
220221
assert_eq!(values.len(), 3);
221222
});
222223
}

0 commit comments

Comments
 (0)