@@ -6,13 +6,12 @@ use pyo3::ffi::{self, Py_ssize_t, PyDictObject, PyObject, PyTypeObject};
66use pyo3:: types:: { PyList , PyModule , PyType } ;
77use pyo3:: { PyTypeInfo , prelude:: * } ;
88
9- pub struct DictValuesIter {
9+ pub struct ReverseDictValuesIter {
1010 entries : * const pydict:: PyDictUnicodeEntry ,
1111 current : usize ,
12- end : usize ,
1312}
1413
15- impl DictValuesIter {
14+ impl ReverseDictValuesIter {
1615 /// Creates a new iterator over dictionary values
1716 ///
1817 /// # Safety
@@ -30,21 +29,20 @@ impl DictValuesIter {
3029
3130 Self {
3231 entries,
33- current : 0 ,
34- end : n,
32+ current : n,
3533 }
3634 }
3735 }
3836}
3937
40- impl Iterator for DictValuesIter {
38+ impl Iterator for ReverseDictValuesIter {
4139 type Item = * mut PyObject ;
4240
4341 fn next ( & mut self ) -> Option < Self :: Item > {
4442 // Skip null entries until we find a valid one or reach the end
45- while self . current < self . end {
46- let entry = & unsafe { * self . entries . add ( self . current ) } ;
47- self . current += 1 ;
43+ while self . current > 0 {
44+ self . current -= 1 ;
45+ let entry : & pydict :: PyDictUnicodeEntry = & unsafe { * self . entries . add ( self . current ) } ;
4846
4947 if !entry. me_value . is_null ( ) {
5048 return Some ( entry. me_value ) ;
@@ -71,27 +69,23 @@ fn get_instance_dict_fast(obj: *mut PyObject) -> Option<*mut PyObject> {
7169 }
7270}
7371
74- unsafe fn is_subtype (
75- subtype : * mut pyo3 :: ffi :: PyTypeObject ,
76- base : * mut pyo3 :: ffi :: PyTypeObject ,
72+ fn isinstance_of_ast (
73+ obj : * mut PyObject ,
74+ base_ast_and_expr_type : ( * mut PyTypeObject , * mut PyTypeObject ) ,
7775) -> bool {
78- // Walk up the inheritance chain via tp_base, max 3 jumps
79- let mut current = subtype;
80- for _ in 0 ..3 {
81- current = unsafe { ( * current) . tp_base } ;
82- if current. is_null ( ) {
83- return false ;
84- }
85- if current == base {
86- return true ;
87- }
76+ let subtype = unsafe { ffi:: Py_TYPE ( obj) } ;
77+ let first_supertype = unsafe { ( * subtype) . tp_base } ;
78+ if first_supertype. is_null ( ) {
79+ return false ;
80+ }
81+ let ( base_ast_type, base_expr_type) = base_ast_and_expr_type;
82+ if first_supertype == base_ast_type {
83+ return true ;
8884 }
8985
90- false
91- }
92-
93- fn isinstance_of_ast ( obj : * mut PyObject , base_ast_type : & * mut PyTypeObject ) -> bool {
94- unsafe { is_subtype ( ffi:: Py_TYPE ( obj) , * base_ast_type) }
86+ let second_supertype = unsafe { ( * first_supertype) . tp_base } ;
87+
88+ second_supertype == base_ast_type || second_supertype == base_expr_type
9589}
9690
9791fn is_list ( obj : * mut PyObject , py_list_type : * mut PyTypeObject ) -> bool {
@@ -106,27 +100,32 @@ fn get_item_of_list(obj: *mut PyObject, index: Py_ssize_t) -> *mut PyObject {
106100 unsafe { ffi:: PyList_GET_ITEM ( obj, index) }
107101}
108102
109- fn walk_node (
103+ fn walk_node_iterative (
110104 node : * mut PyObject ,
111- base_ast_type : & * mut PyTypeObject ,
105+ base_ast_and_expr_type : ( * mut PyTypeObject , * mut PyTypeObject ) ,
112106 py_list_type : * mut PyTypeObject ,
113107 result_list : & mut Vec < * mut PyObject > ,
114108) -> PyResult < ( ) > {
115- result_list. push ( node) ;
116-
117- // Recursively walk through child nodes
118- let Some ( dict) = get_instance_dict_fast ( node) else {
119- return Ok ( ( ) ) ;
120- } ;
121- for item_ptr in unsafe { DictValuesIter :: new ( dict. cast :: < PyDictObject > ( ) ) } {
122- if isinstance_of_ast ( item_ptr, base_ast_type) {
123- walk_node ( item_ptr, base_ast_type, py_list_type, result_list) ?;
124- } else if is_list ( item_ptr, py_list_type) {
125- let length = get_length_of_list ( item_ptr) ;
126- for i in 0 ..length {
127- let item_ptr = get_item_of_list ( item_ptr, i) ;
128- if isinstance_of_ast ( item_ptr, base_ast_type) {
129- walk_node ( item_ptr, base_ast_type, py_list_type, result_list) ?;
109+ let mut stack = vec ! [ node] ;
110+
111+ while let Some ( current_node) = stack. pop ( ) {
112+ result_list. push ( current_node) ;
113+
114+ // Walk through child nodes
115+ let Some ( dict) = get_instance_dict_fast ( current_node) else {
116+ continue ;
117+ } ;
118+
119+ for item_ptr in unsafe { ReverseDictValuesIter :: new ( dict. cast :: < PyDictObject > ( ) ) } {
120+ if isinstance_of_ast ( item_ptr, base_ast_and_expr_type) {
121+ stack. push ( item_ptr) ;
122+ } else if is_list ( item_ptr, py_list_type) {
123+ let length = get_length_of_list ( item_ptr) ;
124+ for i in ( 0 ..length) . rev ( ) {
125+ let item_ptr = get_item_of_list ( item_ptr, i) ;
126+ if isinstance_of_ast ( item_ptr, base_ast_and_expr_type) {
127+ stack. push ( item_ptr) ;
128+ }
130129 }
131130 }
132131 }
@@ -136,36 +135,38 @@ fn walk_node(
136135}
137136
138137thread_local ! {
139- static BASE_AST_TYPE : RefCell <Option <* mut PyTypeObject >> = const { RefCell :: new( None ) } ;
138+ static BASE_AST_TYPE_AND_EXPR : RefCell <Option <( * mut PyTypeObject , * mut PyTypeObject ) >> = const { RefCell :: new( None ) } ;
140139}
141140
142141#[ inline( never) ]
143- fn get_base_ast_type < ' py > ( py : Python < ' py > ) -> PyResult < * mut PyTypeObject > {
142+ fn get_base_ast_type < ' py > ( py : Python < ' py > ) -> PyResult < ( * mut PyTypeObject , * mut PyTypeObject ) > {
144143 let ast_module = py. import ( "ast" ) ?;
145144 let ast_class = ast_module. getattr ( "AST" ) ?. cast_into :: < PyType > ( ) ?;
145+ let expr_class = ast_module. getattr ( "expr" ) ?. cast_into :: < PyType > ( ) ?;
146146
147- Ok ( ast_class. as_type_ptr ( ) )
147+ Ok ( ( ast_class. as_type_ptr ( ) , expr_class . as_type_ptr ( ) ) )
148148}
149149
150150#[ pyfunction]
151151fn walk < ' py > ( py : Python , node : Bound < ' py , PyAny > ) -> PyResult < Py < PyList > > {
152152 let mut result_list = Vec :: new ( ) ;
153153
154154 // Initialize if needed (separate step with mutable borrow)
155- BASE_AST_TYPE . with ( |cache| {
155+ BASE_AST_TYPE_AND_EXPR . with ( |cache| {
156156 if cache. borrow ( ) . is_none ( ) {
157157 * cache. borrow_mut ( ) = Some ( get_base_ast_type ( py) ?) ;
158158 }
159159 Ok :: < ( ) , PyErr > ( ( ) )
160160 } ) ?;
161161
162162 // Now use immutable borrow for the actual work
163- BASE_AST_TYPE . with ( |cache| {
163+ BASE_AST_TYPE_AND_EXPR . with ( |cache| {
164164 let cache_ref = cache. borrow ( ) ;
165- let base_ast_type = cache_ref. as_ref ( ) . unwrap ( ) ;
166- walk_node (
165+ let base_ast_and_expr_type = cache_ref. as_ref ( ) . unwrap ( ) ;
166+
167+ walk_node_iterative (
167168 node. as_ptr ( ) ,
168- base_ast_type ,
169+ * base_ast_and_expr_type ,
169170 PyList :: type_object_raw ( py) ,
170171 & mut result_list,
171172 ) ?;
@@ -200,7 +201,7 @@ mod tests {
200201 Python :: attach ( |py| {
201202 let dict = PyDict :: new ( py) ;
202203 let dict_ptr = dict. as_ptr ( ) as * mut pyo3:: ffi:: PyDictObject ;
203- let values = unsafe { DictValuesIter :: new ( dict_ptr) } . collect :: < Vec < _ > > ( ) ;
204+ let values = unsafe { ReverseDictValuesIter :: new ( dict_ptr) } . collect :: < Vec < _ > > ( ) ;
204205 assert_eq ! ( values. len( ) , 0 ) ;
205206 } ) ;
206207 }
@@ -216,7 +217,7 @@ mod tests {
216217 dict. set_item ( "c" , 3 ) . unwrap ( ) ;
217218
218219 let dict_ptr = dict. as_ptr ( ) as * mut pyo3:: ffi:: PyDictObject ;
219- let values = unsafe { DictValuesIter :: new ( dict_ptr) } . collect :: < Vec < _ > > ( ) ;
220+ let values = unsafe { ReverseDictValuesIter :: new ( dict_ptr) } . collect :: < Vec < _ > > ( ) ;
220221 assert_eq ! ( values. len( ) , 3 ) ;
221222 } ) ;
222223 }
0 commit comments