1+ mod pydict;
2+
13use std:: cell:: RefCell ;
24
3- use fastset:: Set ;
4- use pyo3:: ffi:: { self , Py_ssize_t , PyObject , PyTypeObject } ;
5- use pyo3:: types:: { PyList , PyModule , PyString , PyType } ;
5+ use pyo3:: ffi:: { self , Py_ssize_t , PyDictObject , PyObject , PyTypeObject } ;
6+ use pyo3:: types:: { PyList , PyModule , PyType } ;
67use pyo3:: { PyTypeInfo , prelude:: * } ;
78
8- pub struct BorrowedDictIter {
9- dict : * mut PyObject ,
10- ppos : ffi :: Py_ssize_t ,
11- len : ffi :: Py_ssize_t ,
9+ pub struct DictValuesIter {
10+ entries : * const pydict :: PyDictUnicodeEntry ,
11+ current : usize ,
12+ end : usize ,
1213}
1314
14- impl Iterator for BorrowedDictIter {
15- type Item = * mut PyObject ;
16-
17- #[ inline]
18- fn next ( & mut self ) -> Option < Self :: Item > {
19- let mut value: * mut PyObject = std:: ptr:: null_mut ( ) ;
20-
21- // Safety: self.dict lives sufficiently long that the pointer is not dangling
22- if unsafe { ffi:: PyDict_Next ( self . dict , & mut self . ppos , std:: ptr:: null_mut ( ) , & mut value) }
23- != 0
24- {
25- self . len -= 1 ;
26- Some ( value)
27- } else {
28- None
15+ impl DictValuesIter {
16+ /// Creates a new iterator over dictionary values
17+ ///
18+ /// # Safety
19+ ///
20+ /// The caller must ensure that:
21+ /// - `obj` is a valid pointer to a `PyDictObject`
22+ /// - The dictionary remains valid for the lifetime of the iterator
23+ /// - The dictionary is not modified while iterating
24+ pub unsafe fn new ( obj : * mut PyDictObject ) -> Self {
25+ unsafe {
26+ let dict = & * obj;
27+ let keys = & * dict. ma_keys . cast :: < pydict:: PyDictKeysObject > ( ) ;
28+ let entries = keys. unicode_entries ( ) ;
29+ let n = keys. dk_nentries as usize ;
30+
31+ Self {
32+ entries,
33+ current : 0 ,
34+ end : n,
35+ }
2936 }
3037 }
31-
32- #[ inline]
33- fn size_hint ( & self ) -> ( usize , Option < usize > ) {
34- let len = self . len ( ) ;
35- ( len, Some ( len) )
36- }
37-
38- #[ inline]
39- fn count ( self ) -> usize
40- where
41- Self : Sized ,
42- {
43- self . len ( )
44- }
4538}
4639
47- impl ExactSizeIterator for BorrowedDictIter {
48- fn len ( & self ) -> usize {
49- self . len as usize
50- }
51- }
40+ impl Iterator for DictValuesIter {
41+ type Item = * mut PyObject ;
5242
53- fn dict_len ( dict : * mut PyObject ) -> Py_ssize_t {
54- unsafe { ffi:: PyDict_Size ( dict) }
55- }
43+ fn next ( & mut self ) -> Option < Self :: Item > {
44+ // Skip null entries until we find a valid one or reach the end
45+ while self . current < self . end {
46+ let entry = & unsafe { * self . entries . add ( self . current ) } ;
47+ self . current += 1 ;
48+
49+ if !entry. me_value . is_null ( ) {
50+ return Some ( entry. me_value ) ;
51+ }
52+ }
5653
57- impl BorrowedDictIter {
58- pub fn new ( dict : * mut PyObject ) -> Self {
59- let len = dict_len ( dict) ;
60- BorrowedDictIter { dict, ppos : 0 , len }
54+ None
6155 }
6256}
6357
@@ -66,8 +60,7 @@ fn get_instance_dict_fast(obj: *mut PyObject) -> Option<*mut PyObject> {
6660 let dict_ptr = ( * obj) . ob_type . as_ref ( ) ?. tp_dictoffset ;
6761
6862 if dict_ptr != 0 {
69- let dict_ptr_addr =
70- ( obj as * mut u8 ) . offset ( dict_ptr as isize ) as * mut * mut ffi:: PyObject ;
63+ let dict_ptr_addr = ( obj as * mut u8 ) . offset ( dict_ptr) as * mut * mut ffi:: PyObject ;
7164 let dict = * dict_ptr_addr;
7265
7366 if !dict. is_null ( ) {
@@ -78,12 +71,35 @@ fn get_instance_dict_fast(obj: *mut PyObject) -> Option<*mut PyObject> {
7871 }
7972}
8073
81- fn isinstance_of_ast ( obj : * mut PyObject , all_ast_classes : & Set ) -> bool {
82- let el = unsafe { ffi:: Py_TYPE ( obj) } ;
83- all_ast_classes. contains ( & ( el as usize ) )
74+ unsafe fn is_subtype (
75+ subtype : * mut pyo3:: ffi:: PyTypeObject ,
76+ base : * mut pyo3:: ffi:: PyTypeObject ,
77+ ) -> bool {
78+ // If they're the same type, it's trivially a subtype
79+ if subtype == base {
80+ return true ;
81+ }
82+
83+ // Walk up the inheritance chain via tp_base, max 3 jumps
84+ let mut current = subtype;
85+ for _ in 0 ..3 {
86+ current = unsafe { ( * current) . tp_base } ;
87+ if current. is_null ( ) {
88+ return false ;
89+ }
90+ if current == base {
91+ return true ;
92+ }
93+ }
94+
95+ false
96+ }
97+
98+ fn isinstance_of_ast ( obj : * mut PyObject , base_ast_type : & * mut PyTypeObject ) -> bool {
99+ unsafe { is_subtype ( ffi:: Py_TYPE ( obj) , * base_ast_type) }
84100}
85101
86- fn isinstance_of_list ( obj : * mut PyObject , py_list_type : * mut PyTypeObject ) -> bool {
102+ fn is_list ( obj : * mut PyObject , py_list_type : * mut PyTypeObject ) -> bool {
87103 unsafe { ffi:: Py_TYPE ( obj) == py_list_type }
88104}
89105
@@ -95,9 +111,9 @@ fn get_item_of_list(obj: *mut PyObject, index: Py_ssize_t) -> *mut PyObject {
95111 unsafe { ffi:: PyList_GET_ITEM ( obj, index) }
96112}
97113
98- fn walk_node < ' py > (
114+ fn walk_node (
99115 node : * mut PyObject ,
100- all_ast_classes : & Set ,
116+ base_ast_type : & * mut PyTypeObject ,
101117 py_list_type : * mut PyTypeObject ,
102118 result_list : & mut Vec < * mut PyObject > ,
103119) -> PyResult < ( ) > {
@@ -107,16 +123,15 @@ fn walk_node<'py>(
107123 let Some ( dict) = get_instance_dict_fast ( node) else {
108124 return Ok ( ( ) ) ;
109125 } ;
110- let values = BorrowedDictIter :: new ( dict) ;
111- for item_ptr in values {
112- if isinstance_of_ast ( item_ptr, all_ast_classes) {
113- walk_node ( item_ptr, all_ast_classes, py_list_type, result_list) ?;
114- } else if isinstance_of_list ( item_ptr, py_list_type) {
126+ for item_ptr in unsafe { DictValuesIter :: new ( dict. cast :: < PyDictObject > ( ) ) } {
127+ if isinstance_of_ast ( item_ptr, base_ast_type) {
128+ walk_node ( item_ptr, base_ast_type, py_list_type, result_list) ?;
129+ } else if is_list ( item_ptr, py_list_type) {
115130 let length = get_length_of_list ( item_ptr) ;
116131 for i in 0 ..length {
117132 let item_ptr = get_item_of_list ( item_ptr, i) ;
118- if isinstance_of_ast ( item_ptr, all_ast_classes ) {
119- walk_node ( item_ptr, all_ast_classes , py_list_type, result_list) ?;
133+ if isinstance_of_ast ( item_ptr, base_ast_type ) {
134+ walk_node ( item_ptr, base_ast_type , py_list_type, result_list) ?;
120135 }
121136 }
122137 }
@@ -126,46 +141,36 @@ fn walk_node<'py>(
126141}
127142
128143thread_local ! {
129- static AST_CLASSES : RefCell <Option <Set >> = RefCell :: new( None ) ;
144+ static BASE_AST_TYPE : RefCell <Option <* mut PyTypeObject >> = const { RefCell :: new( None ) } ;
130145}
131146
132147#[ inline( never) ]
133- fn compute_ast_classes < ' py > ( py : Python < ' py > ) -> PyResult < Set > {
134- let mut classes = Vec :: new ( ) ;
148+ fn get_base_ast_type < ' py > ( py : Python < ' py > ) -> PyResult < * mut PyTypeObject > {
135149 let ast_module = py. import ( "ast" ) ?;
136150 let ast_class = ast_module. getattr ( "AST" ) ?. cast_into :: < PyType > ( ) ?;
137151
138- for field in ast_module. dir ( ) ? {
139- let class = ast_module. getattr ( field. cast_exact :: < PyString > ( ) ?) ?;
140- if let Ok ( class) = class. cast_into :: < PyType > ( ) {
141- if class. is_subclass ( ast_class. cast :: < PyType > ( ) ?) ? {
142- classes. push ( class. as_type_ptr ( ) as usize ) ;
143- }
144- }
145- }
146-
147- Ok ( Set :: from ( classes) )
152+ Ok ( ast_class. as_type_ptr ( ) )
148153}
149154
150155#[ pyfunction]
151156fn walk < ' py > ( py : Python , node : Bound < ' py , PyAny > ) -> PyResult < Py < PyList > > {
152157 let mut result_list = Vec :: new ( ) ;
153158
154159 // Initialize if needed (separate step with mutable borrow)
155- AST_CLASSES . with ( |cache| {
160+ BASE_AST_TYPE . with ( |cache| {
156161 if cache. borrow ( ) . is_none ( ) {
157- * cache. borrow_mut ( ) = Some ( compute_ast_classes ( py) ?) ;
162+ * cache. borrow_mut ( ) = Some ( get_base_ast_type ( py) ?) ;
158163 }
159164 Ok :: < ( ) , PyErr > ( ( ) )
160165 } ) ?;
161166
162167 // Now use immutable borrow for the actual work
163- AST_CLASSES . with ( |cache| {
168+ BASE_AST_TYPE . with ( |cache| {
164169 let cache_ref = cache. borrow ( ) ;
165- let all_ast_classes = cache_ref. as_ref ( ) . unwrap ( ) ;
170+ let base_ast_type = cache_ref. as_ref ( ) . unwrap ( ) ;
166171 walk_node (
167172 node. as_ptr ( ) ,
168- all_ast_classes ,
173+ base_ast_type ,
169174 PyList :: type_object_raw ( py) ,
170175 & mut result_list,
171176 ) ?;
@@ -186,3 +191,38 @@ fn fast_walk(m: &Bound<'_, PyModule>) -> PyResult<()> {
186191 m. add_function ( wrap_pyfunction ! ( walk, m) ?) ?;
187192 Ok ( ( ) )
188193}
194+
195+ #[ cfg( test) ]
196+ mod tests {
197+ use pyo3:: types:: PyDict ;
198+
199+ use super :: * ;
200+
201+ #[ test]
202+ fn test_empty_dict_no_values ( ) {
203+ Python :: initialize ( ) ;
204+
205+ Python :: attach ( |py| {
206+ let dict = PyDict :: new ( py) ;
207+ let dict_ptr = dict. as_ptr ( ) as * mut pyo3:: ffi:: PyDictObject ;
208+ let values = unsafe { DictValuesIter :: new ( dict_ptr) } . collect :: < Vec < _ > > ( ) ;
209+ assert_eq ! ( values. len( ) , 0 ) ;
210+ } ) ;
211+ }
212+
213+ #[ test]
214+ fn test_string_keys_dict_values ( ) {
215+ Python :: initialize ( ) ;
216+
217+ Python :: attach ( |py| {
218+ let dict = PyDict :: new ( py) ;
219+ dict. set_item ( "a" , 1 ) . unwrap ( ) ;
220+ dict. set_item ( "b" , 2 ) . unwrap ( ) ;
221+ dict. set_item ( "c" , 3 ) . unwrap ( ) ;
222+
223+ let dict_ptr = dict. as_ptr ( ) as * mut pyo3:: ffi:: PyDictObject ;
224+ let values = unsafe { DictValuesIter :: new ( dict_ptr) } . collect :: < Vec < _ > > ( ) ;
225+ assert_eq ! ( values. len( ) , 3 ) ;
226+ } ) ;
227+ }
228+ }
0 commit comments