Skip to content

Commit 377678b

Browse files
committed
optimize things so much more it's not even funny
1 parent 6ca15bd commit 377678b

File tree

7 files changed

+399
-46
lines changed

7 files changed

+399
-46
lines changed

Cargo.lock

Lines changed: 48 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "fast-walk"
3-
version = "0.1.5"
3+
version = "0.1.6"
44
edition = "2024"
55
license = "MIT"
66
readme = "README.md"
@@ -17,4 +17,5 @@ lto = true # Link-time optimization.
1717
codegen-units = 1 # Slower compilation but faster code.
1818

1919
[dependencies]
20+
fastset = "0.5.2"
2021
pyo3 = "0.27.1"

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,8 @@ cache-keys = [
3232
]
3333

3434
[dependency-groups]
35-
dev = ["pytest>=8.4.2", "pytest-codspeed>=4.2.0"]
35+
dev = [
36+
"py-spy>=0.4.1",
37+
"pytest>=8.4.2",
38+
"pytest-codspeed>=4.2.0",
39+
]

src/lib.rs

Lines changed: 161 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,184 @@
1-
use pyo3::types::{PyList, PyModule, PyTuple};
2-
use pyo3::{intern, prelude::*};
3-
4-
fn getattr<'py>(
5-
obj: &Bound<'py, PyAny>,
6-
attr_name: &Bound<'py, PyAny>,
7-
) -> Option<Bound<'py, PyAny>> {
8-
let py = obj.py();
9-
10-
let mut resp_ptr: *mut pyo3::ffi::PyObject = std::ptr::null_mut();
11-
let attr_ptr = unsafe {
12-
pyo3::ffi::PyObject_GetOptionalAttr(obj.as_ptr(), attr_name.as_ptr(), &mut resp_ptr)
13-
};
1+
use std::cell::RefCell;
2+
3+
use fastset::Set;
4+
use pyo3::ffi::{self, Py_ssize_t, PyObject, PyTypeObject};
5+
use pyo3::types::{PyList, PyModule, PyString, PyType};
6+
use pyo3::{PyTypeInfo, prelude::*};
7+
8+
pub struct BorrowedDictIter {
9+
dict: *mut PyObject,
10+
ppos: ffi::Py_ssize_t,
11+
len: ffi::Py_ssize_t,
12+
}
13+
14+
impl Iterator for BorrowedDictIter {
15+
type Item = *mut PyObject;
16+
17+
#[inline]
18+
fn next(&mut self) -> Option<Self::Item> {
19+
let mut value: *mut PyObject = std::ptr::null_mut();
20+
21+
// Safety: self.dict lives sufficiently long that the pointer is not dangling
22+
if unsafe { ffi::PyDict_Next(self.dict, &mut self.ppos, std::ptr::null_mut(), &mut value) }
23+
!= 0
24+
{
25+
self.len -= 1;
26+
Some(value)
27+
} else {
28+
None
29+
}
30+
}
31+
32+
#[inline]
33+
fn size_hint(&self) -> (usize, Option<usize>) {
34+
let len = self.len();
35+
(len, Some(len))
36+
}
37+
38+
#[inline]
39+
fn count(self) -> usize
40+
where
41+
Self: Sized,
42+
{
43+
self.len()
44+
}
45+
}
46+
47+
impl ExactSizeIterator for BorrowedDictIter {
48+
fn len(&self) -> usize {
49+
self.len as usize
50+
}
51+
}
52+
53+
fn dict_len(dict: *mut PyObject) -> Py_ssize_t {
54+
unsafe { ffi::PyDict_Size(dict) }
55+
}
56+
57+
impl BorrowedDictIter {
58+
pub fn new(dict: *mut PyObject) -> Self {
59+
let len = dict_len(dict);
60+
BorrowedDictIter { dict, ppos: 0, len }
61+
}
62+
}
63+
64+
fn get_instance_dict_fast(obj: *mut PyObject) -> Option<*mut PyObject> {
65+
unsafe {
66+
let dict_ptr = (*obj).ob_type.as_ref()?.tp_dictoffset;
67+
68+
if dict_ptr != 0 {
69+
let dict_ptr_addr =
70+
(obj as *mut u8).offset(dict_ptr as isize) as *mut *mut ffi::PyObject;
71+
let dict = *dict_ptr_addr;
1472

15-
if attr_ptr == 1 {
16-
Some(unsafe { Bound::from_owned_ptr(py, resp_ptr) })
17-
} else {
73+
if !dict.is_null() {
74+
return Some(dict);
75+
}
76+
}
1877
None
1978
}
2079
}
2180

81+
fn isinstance_of_ast(obj: *mut PyObject, all_ast_classes: &Set) -> bool {
82+
let el = unsafe { ffi::Py_TYPE(obj) };
83+
all_ast_classes.contains(&(el as usize))
84+
}
85+
86+
fn isinstance_of_list(obj: *mut PyObject, py_list_type: *mut PyTypeObject) -> bool {
87+
unsafe { ffi::Py_TYPE(obj) == py_list_type }
88+
}
89+
90+
fn get_length_of_list(obj: *mut PyObject) -> Py_ssize_t {
91+
unsafe { ffi::PyList_GET_SIZE(obj) }
92+
}
93+
94+
fn get_item_of_list(obj: *mut PyObject, index: Py_ssize_t) -> *mut PyObject {
95+
unsafe { ffi::PyList_GET_ITEM(obj, index) }
96+
}
97+
2298
fn walk_node<'py>(
23-
node: Bound<'py, PyAny>,
24-
field_names: Bound<'py, PyTuple>,
25-
result_list: &mut Vec<Bound<'py, PyAny>>,
99+
node: *mut PyObject,
100+
all_ast_classes: &Set,
101+
py_list_type: *mut PyTypeObject,
102+
result_list: &mut Vec<*mut PyObject>,
26103
) -> PyResult<()> {
27-
result_list.push(node.clone());
104+
result_list.push(node);
28105

29106
// Recursively walk through child nodes
30-
for field in field_names {
31-
if let Some(child) = getattr(&node, &field) {
32-
if child.is_exact_instance_of::<PyList>() {
33-
for item in unsafe { child.cast_unchecked::<PyList>() } {
34-
if let Some(subfields) = getattr(&item, intern!(item.py(), "_fields")) {
35-
walk_node(
36-
item,
37-
unsafe { subfields.cast_into_unchecked::<PyTuple>() },
38-
result_list,
39-
)?;
40-
}
107+
let Some(dict) = get_instance_dict_fast(node) else {
108+
return Ok(());
109+
};
110+
let values = BorrowedDictIter::new(dict);
111+
for item_ptr in values {
112+
if isinstance_of_ast(item_ptr, all_ast_classes) {
113+
walk_node(item_ptr, all_ast_classes, py_list_type, result_list)?;
114+
} else if isinstance_of_list(item_ptr, py_list_type) {
115+
let length = get_length_of_list(item_ptr);
116+
for i in 0..length {
117+
let item_ptr = get_item_of_list(item_ptr, i);
118+
if isinstance_of_ast(item_ptr, all_ast_classes) {
119+
walk_node(item_ptr, all_ast_classes, py_list_type, result_list)?;
41120
}
42-
} else if let Some(subfields) = getattr(&child, intern!(child.py(), "_fields")) {
43-
walk_node(
44-
child,
45-
unsafe { subfields.cast_into_unchecked::<PyTuple>() },
46-
result_list,
47-
)?;
48121
}
49122
}
50123
}
51124

52125
Ok(())
53126
}
54127

128+
thread_local! {
129+
static AST_CLASSES: RefCell<Option<Set>> = RefCell::new(None);
130+
}
131+
132+
#[inline(never)]
133+
fn compute_ast_classes<'py>(py: Python<'py>) -> PyResult<Set> {
134+
let mut classes = Vec::new();
135+
let ast_module = py.import("ast")?;
136+
let ast_class = ast_module.getattr("AST")?.cast_into::<PyType>()?;
137+
138+
for field in ast_module.dir()? {
139+
let class = ast_module.getattr(field.cast_exact::<PyString>()?)?;
140+
if let Ok(class) = class.cast_into::<PyType>() {
141+
if class.is_subclass(ast_class.cast::<PyType>()?)? {
142+
classes.push(class.as_type_ptr() as usize);
143+
}
144+
}
145+
}
146+
147+
Ok(Set::from(classes))
148+
}
149+
55150
#[pyfunction]
56151
fn walk<'py>(py: Python, node: Bound<'py, PyAny>) -> PyResult<Py<PyList>> {
57152
let mut result_list = Vec::new();
58-
let fields = node.getattr(intern!(py, "_fields"))?;
59-
let fields = unsafe { fields.cast_into_unchecked::<PyTuple>() };
60-
walk_node(node, fields, &mut result_list)?;
61-
Ok(PyList::new(py, result_list)?.into())
153+
154+
// Initialize if needed (separate step with mutable borrow)
155+
AST_CLASSES.with(|cache| {
156+
if cache.borrow().is_none() {
157+
*cache.borrow_mut() = Some(compute_ast_classes(py)?);
158+
}
159+
Ok::<(), PyErr>(())
160+
})?;
161+
162+
// Now use immutable borrow for the actual work
163+
AST_CLASSES.with(|cache| {
164+
let cache_ref = cache.borrow();
165+
let all_ast_classes = cache_ref.as_ref().unwrap();
166+
walk_node(
167+
node.as_ptr(),
168+
all_ast_classes,
169+
PyList::type_object_raw(py),
170+
&mut result_list,
171+
)?;
172+
Ok::<(), PyErr>(())
173+
})?;
174+
175+
Ok(PyList::new(
176+
py,
177+
result_list
178+
.into_iter()
179+
.map(|ptr| unsafe { Bound::from_borrowed_ptr(py, ptr) }),
180+
)?
181+
.into())
62182
}
63183

64184
#[pymodule]

0 commit comments

Comments
 (0)