|
1 | | -use pyo3::types::{PyList, PyModule, PyTuple}; |
2 | | -use pyo3::{intern, prelude::*}; |
3 | | - |
4 | | -fn getattr<'py>( |
5 | | - obj: &Bound<'py, PyAny>, |
6 | | - attr_name: &Bound<'py, PyAny>, |
7 | | -) -> Option<Bound<'py, PyAny>> { |
8 | | - let py = obj.py(); |
9 | | - |
10 | | - let mut resp_ptr: *mut pyo3::ffi::PyObject = std::ptr::null_mut(); |
11 | | - let attr_ptr = unsafe { |
12 | | - pyo3::ffi::PyObject_GetOptionalAttr(obj.as_ptr(), attr_name.as_ptr(), &mut resp_ptr) |
13 | | - }; |
| 1 | +use std::cell::RefCell; |
| 2 | + |
| 3 | +use fastset::Set; |
| 4 | +use pyo3::ffi::{self, Py_ssize_t, PyObject, PyTypeObject}; |
| 5 | +use pyo3::types::{PyList, PyModule, PyString, PyType}; |
| 6 | +use pyo3::{PyTypeInfo, prelude::*}; |
| 7 | + |
| 8 | +pub struct BorrowedDictIter { |
| 9 | + dict: *mut PyObject, |
| 10 | + ppos: ffi::Py_ssize_t, |
| 11 | + len: ffi::Py_ssize_t, |
| 12 | +} |
| 13 | + |
| 14 | +impl Iterator for BorrowedDictIter { |
| 15 | + type Item = *mut PyObject; |
| 16 | + |
| 17 | + #[inline] |
| 18 | + fn next(&mut self) -> Option<Self::Item> { |
| 19 | + let mut value: *mut PyObject = std::ptr::null_mut(); |
| 20 | + |
| 21 | + // Safety: self.dict lives sufficiently long that the pointer is not dangling |
| 22 | + if unsafe { ffi::PyDict_Next(self.dict, &mut self.ppos, std::ptr::null_mut(), &mut value) } |
| 23 | + != 0 |
| 24 | + { |
| 25 | + self.len -= 1; |
| 26 | + Some(value) |
| 27 | + } else { |
| 28 | + None |
| 29 | + } |
| 30 | + } |
| 31 | + |
| 32 | + #[inline] |
| 33 | + fn size_hint(&self) -> (usize, Option<usize>) { |
| 34 | + let len = self.len(); |
| 35 | + (len, Some(len)) |
| 36 | + } |
| 37 | + |
| 38 | + #[inline] |
| 39 | + fn count(self) -> usize |
| 40 | + where |
| 41 | + Self: Sized, |
| 42 | + { |
| 43 | + self.len() |
| 44 | + } |
| 45 | +} |
| 46 | + |
| 47 | +impl ExactSizeIterator for BorrowedDictIter { |
| 48 | + fn len(&self) -> usize { |
| 49 | + self.len as usize |
| 50 | + } |
| 51 | +} |
| 52 | + |
| 53 | +fn dict_len(dict: *mut PyObject) -> Py_ssize_t { |
| 54 | + unsafe { ffi::PyDict_Size(dict) } |
| 55 | +} |
| 56 | + |
| 57 | +impl BorrowedDictIter { |
| 58 | + pub fn new(dict: *mut PyObject) -> Self { |
| 59 | + let len = dict_len(dict); |
| 60 | + BorrowedDictIter { dict, ppos: 0, len } |
| 61 | + } |
| 62 | +} |
| 63 | + |
| 64 | +fn get_instance_dict_fast(obj: *mut PyObject) -> Option<*mut PyObject> { |
| 65 | + unsafe { |
| 66 | + let dict_ptr = (*obj).ob_type.as_ref()?.tp_dictoffset; |
| 67 | + |
| 68 | + if dict_ptr != 0 { |
| 69 | + let dict_ptr_addr = |
| 70 | + (obj as *mut u8).offset(dict_ptr as isize) as *mut *mut ffi::PyObject; |
| 71 | + let dict = *dict_ptr_addr; |
14 | 72 |
|
15 | | - if attr_ptr == 1 { |
16 | | - Some(unsafe { Bound::from_owned_ptr(py, resp_ptr) }) |
17 | | - } else { |
| 73 | + if !dict.is_null() { |
| 74 | + return Some(dict); |
| 75 | + } |
| 76 | + } |
18 | 77 | None |
19 | 78 | } |
20 | 79 | } |
21 | 80 |
|
| 81 | +fn isinstance_of_ast(obj: *mut PyObject, all_ast_classes: &Set) -> bool { |
| 82 | + let el = unsafe { ffi::Py_TYPE(obj) }; |
| 83 | + all_ast_classes.contains(&(el as usize)) |
| 84 | +} |
| 85 | + |
| 86 | +fn isinstance_of_list(obj: *mut PyObject, py_list_type: *mut PyTypeObject) -> bool { |
| 87 | + unsafe { ffi::Py_TYPE(obj) == py_list_type } |
| 88 | +} |
| 89 | + |
| 90 | +fn get_length_of_list(obj: *mut PyObject) -> Py_ssize_t { |
| 91 | + unsafe { ffi::PyList_GET_SIZE(obj) } |
| 92 | +} |
| 93 | + |
| 94 | +fn get_item_of_list(obj: *mut PyObject, index: Py_ssize_t) -> *mut PyObject { |
| 95 | + unsafe { ffi::PyList_GET_ITEM(obj, index) } |
| 96 | +} |
| 97 | + |
22 | 98 | fn walk_node<'py>( |
23 | | - node: Bound<'py, PyAny>, |
24 | | - field_names: Bound<'py, PyTuple>, |
25 | | - result_list: &mut Vec<Bound<'py, PyAny>>, |
| 99 | + node: *mut PyObject, |
| 100 | + all_ast_classes: &Set, |
| 101 | + py_list_type: *mut PyTypeObject, |
| 102 | + result_list: &mut Vec<*mut PyObject>, |
26 | 103 | ) -> PyResult<()> { |
27 | | - result_list.push(node.clone()); |
| 104 | + result_list.push(node); |
28 | 105 |
|
29 | 106 | // Recursively walk through child nodes |
30 | | - for field in field_names { |
31 | | - if let Some(child) = getattr(&node, &field) { |
32 | | - if child.is_exact_instance_of::<PyList>() { |
33 | | - for item in unsafe { child.cast_unchecked::<PyList>() } { |
34 | | - if let Some(subfields) = getattr(&item, intern!(item.py(), "_fields")) { |
35 | | - walk_node( |
36 | | - item, |
37 | | - unsafe { subfields.cast_into_unchecked::<PyTuple>() }, |
38 | | - result_list, |
39 | | - )?; |
40 | | - } |
| 107 | + let Some(dict) = get_instance_dict_fast(node) else { |
| 108 | + return Ok(()); |
| 109 | + }; |
| 110 | + let values = BorrowedDictIter::new(dict); |
| 111 | + for item_ptr in values { |
| 112 | + if isinstance_of_ast(item_ptr, all_ast_classes) { |
| 113 | + walk_node(item_ptr, all_ast_classes, py_list_type, result_list)?; |
| 114 | + } else if isinstance_of_list(item_ptr, py_list_type) { |
| 115 | + let length = get_length_of_list(item_ptr); |
| 116 | + for i in 0..length { |
| 117 | + let item_ptr = get_item_of_list(item_ptr, i); |
| 118 | + if isinstance_of_ast(item_ptr, all_ast_classes) { |
| 119 | + walk_node(item_ptr, all_ast_classes, py_list_type, result_list)?; |
41 | 120 | } |
42 | | - } else if let Some(subfields) = getattr(&child, intern!(child.py(), "_fields")) { |
43 | | - walk_node( |
44 | | - child, |
45 | | - unsafe { subfields.cast_into_unchecked::<PyTuple>() }, |
46 | | - result_list, |
47 | | - )?; |
48 | 121 | } |
49 | 122 | } |
50 | 123 | } |
51 | 124 |
|
52 | 125 | Ok(()) |
53 | 126 | } |
54 | 127 |
|
| 128 | +thread_local! { |
| 129 | + static AST_CLASSES: RefCell<Option<Set>> = RefCell::new(None); |
| 130 | +} |
| 131 | + |
| 132 | +#[inline(never)] |
| 133 | +fn compute_ast_classes<'py>(py: Python<'py>) -> PyResult<Set> { |
| 134 | + let mut classes = Vec::new(); |
| 135 | + let ast_module = py.import("ast")?; |
| 136 | + let ast_class = ast_module.getattr("AST")?.cast_into::<PyType>()?; |
| 137 | + |
| 138 | + for field in ast_module.dir()? { |
| 139 | + let class = ast_module.getattr(field.cast_exact::<PyString>()?)?; |
| 140 | + if let Ok(class) = class.cast_into::<PyType>() { |
| 141 | + if class.is_subclass(ast_class.cast::<PyType>()?)? { |
| 142 | + classes.push(class.as_type_ptr() as usize); |
| 143 | + } |
| 144 | + } |
| 145 | + } |
| 146 | + |
| 147 | + Ok(Set::from(classes)) |
| 148 | +} |
| 149 | + |
55 | 150 | #[pyfunction] |
56 | 151 | fn walk<'py>(py: Python, node: Bound<'py, PyAny>) -> PyResult<Py<PyList>> { |
57 | 152 | let mut result_list = Vec::new(); |
58 | | - let fields = node.getattr(intern!(py, "_fields"))?; |
59 | | - let fields = unsafe { fields.cast_into_unchecked::<PyTuple>() }; |
60 | | - walk_node(node, fields, &mut result_list)?; |
61 | | - Ok(PyList::new(py, result_list)?.into()) |
| 153 | + |
| 154 | + // Initialize if needed (separate step with mutable borrow) |
| 155 | + AST_CLASSES.with(|cache| { |
| 156 | + if cache.borrow().is_none() { |
| 157 | + *cache.borrow_mut() = Some(compute_ast_classes(py)?); |
| 158 | + } |
| 159 | + Ok::<(), PyErr>(()) |
| 160 | + })?; |
| 161 | + |
| 162 | + // Now use immutable borrow for the actual work |
| 163 | + AST_CLASSES.with(|cache| { |
| 164 | + let cache_ref = cache.borrow(); |
| 165 | + let all_ast_classes = cache_ref.as_ref().unwrap(); |
| 166 | + walk_node( |
| 167 | + node.as_ptr(), |
| 168 | + all_ast_classes, |
| 169 | + PyList::type_object_raw(py), |
| 170 | + &mut result_list, |
| 171 | + )?; |
| 172 | + Ok::<(), PyErr>(()) |
| 173 | + })?; |
| 174 | + |
| 175 | + Ok(PyList::new( |
| 176 | + py, |
| 177 | + result_list |
| 178 | + .into_iter() |
| 179 | + .map(|ptr| unsafe { Bound::from_borrowed_ptr(py, ptr) }), |
| 180 | + )? |
| 181 | + .into()) |
62 | 182 | } |
63 | 183 |
|
64 | 184 | #[pymodule] |
|
0 commit comments