Skip to content

Commit e31a03c

Browse files
committed
optimize things to use zero python calls
1 parent 377678b commit e31a03c

File tree

4 files changed

+222
-83
lines changed

4 files changed

+222
-83
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "fast-walk"
3-
version = "0.1.6"
3+
version = "0.1.7"
44
edition = "2024"
55
license = "MIT"
66
readme = "README.md"

src/lib.rs

Lines changed: 121 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,57 @@
1+
mod pydict;
2+
13
use std::cell::RefCell;
24

3-
use fastset::Set;
4-
use pyo3::ffi::{self, Py_ssize_t, PyObject, PyTypeObject};
5-
use pyo3::types::{PyList, PyModule, PyString, PyType};
5+
use pyo3::ffi::{self, Py_ssize_t, PyDictObject, PyObject, PyTypeObject};
6+
use pyo3::types::{PyList, PyModule, PyType};
67
use pyo3::{PyTypeInfo, prelude::*};
78

8-
pub struct BorrowedDictIter {
9-
dict: *mut PyObject,
10-
ppos: ffi::Py_ssize_t,
11-
len: ffi::Py_ssize_t,
9+
pub struct DictValuesIter {
10+
entries: *const pydict::PyDictUnicodeEntry,
11+
current: usize,
12+
end: usize,
1213
}
1314

14-
impl Iterator for BorrowedDictIter {
15-
type Item = *mut PyObject;
16-
17-
#[inline]
18-
fn next(&mut self) -> Option<Self::Item> {
19-
let mut value: *mut PyObject = std::ptr::null_mut();
20-
21-
// Safety: self.dict lives sufficiently long that the pointer is not dangling
22-
if unsafe { ffi::PyDict_Next(self.dict, &mut self.ppos, std::ptr::null_mut(), &mut value) }
23-
!= 0
24-
{
25-
self.len -= 1;
26-
Some(value)
27-
} else {
28-
None
15+
impl DictValuesIter {
16+
/// Creates a new iterator over dictionary values
17+
///
18+
/// # Safety
19+
///
20+
/// The caller must ensure that:
21+
/// - `obj` is a valid pointer to a `PyDictObject`
22+
/// - The dictionary remains valid for the lifetime of the iterator
23+
/// - The dictionary is not modified while iterating
24+
pub unsafe fn new(obj: *mut PyDictObject) -> Self {
25+
unsafe {
26+
let dict = &*obj;
27+
let keys = &*dict.ma_keys.cast::<pydict::PyDictKeysObject>();
28+
let entries = keys.unicode_entries();
29+
let n = keys.dk_nentries as usize;
30+
31+
Self {
32+
entries,
33+
current: 0,
34+
end: n,
35+
}
2936
}
3037
}
31-
32-
#[inline]
33-
fn size_hint(&self) -> (usize, Option<usize>) {
34-
let len = self.len();
35-
(len, Some(len))
36-
}
37-
38-
#[inline]
39-
fn count(self) -> usize
40-
where
41-
Self: Sized,
42-
{
43-
self.len()
44-
}
4538
}
4639

47-
impl ExactSizeIterator for BorrowedDictIter {
48-
fn len(&self) -> usize {
49-
self.len as usize
50-
}
51-
}
40+
impl Iterator for DictValuesIter {
41+
type Item = *mut PyObject;
5242

53-
fn dict_len(dict: *mut PyObject) -> Py_ssize_t {
54-
unsafe { ffi::PyDict_Size(dict) }
55-
}
43+
fn next(&mut self) -> Option<Self::Item> {
44+
// Skip null entries until we find a valid one or reach the end
45+
while self.current < self.end {
46+
let entry = &unsafe { *self.entries.add(self.current) };
47+
self.current += 1;
48+
49+
if !entry.me_value.is_null() {
50+
return Some(entry.me_value);
51+
}
52+
}
5653

57-
impl BorrowedDictIter {
58-
pub fn new(dict: *mut PyObject) -> Self {
59-
let len = dict_len(dict);
60-
BorrowedDictIter { dict, ppos: 0, len }
54+
None
6155
}
6256
}
6357

@@ -66,8 +60,7 @@ fn get_instance_dict_fast(obj: *mut PyObject) -> Option<*mut PyObject> {
6660
let dict_ptr = (*obj).ob_type.as_ref()?.tp_dictoffset;
6761

6862
if dict_ptr != 0 {
69-
let dict_ptr_addr =
70-
(obj as *mut u8).offset(dict_ptr as isize) as *mut *mut ffi::PyObject;
63+
let dict_ptr_addr = (obj as *mut u8).offset(dict_ptr) as *mut *mut ffi::PyObject;
7164
let dict = *dict_ptr_addr;
7265

7366
if !dict.is_null() {
@@ -78,12 +71,35 @@ fn get_instance_dict_fast(obj: *mut PyObject) -> Option<*mut PyObject> {
7871
}
7972
}
8073

81-
fn isinstance_of_ast(obj: *mut PyObject, all_ast_classes: &Set) -> bool {
82-
let el = unsafe { ffi::Py_TYPE(obj) };
83-
all_ast_classes.contains(&(el as usize))
74+
unsafe fn is_subtype(
75+
subtype: *mut pyo3::ffi::PyTypeObject,
76+
base: *mut pyo3::ffi::PyTypeObject,
77+
) -> bool {
78+
// If they're the same type, it's trivially a subtype
79+
if subtype == base {
80+
return true;
81+
}
82+
83+
// Walk up the inheritance chain via tp_base, max 3 jumps
84+
let mut current = subtype;
85+
for _ in 0..3 {
86+
current = unsafe { (*current).tp_base };
87+
if current.is_null() {
88+
return false;
89+
}
90+
if current == base {
91+
return true;
92+
}
93+
}
94+
95+
false
96+
}
97+
98+
fn isinstance_of_ast(obj: *mut PyObject, base_ast_type: &*mut PyTypeObject) -> bool {
99+
unsafe { is_subtype(ffi::Py_TYPE(obj), *base_ast_type) }
84100
}
85101

86-
fn isinstance_of_list(obj: *mut PyObject, py_list_type: *mut PyTypeObject) -> bool {
102+
fn is_list(obj: *mut PyObject, py_list_type: *mut PyTypeObject) -> bool {
87103
unsafe { ffi::Py_TYPE(obj) == py_list_type }
88104
}
89105

@@ -95,9 +111,9 @@ fn get_item_of_list(obj: *mut PyObject, index: Py_ssize_t) -> *mut PyObject {
95111
unsafe { ffi::PyList_GET_ITEM(obj, index) }
96112
}
97113

98-
fn walk_node<'py>(
114+
fn walk_node(
99115
node: *mut PyObject,
100-
all_ast_classes: &Set,
116+
base_ast_type: &*mut PyTypeObject,
101117
py_list_type: *mut PyTypeObject,
102118
result_list: &mut Vec<*mut PyObject>,
103119
) -> PyResult<()> {
@@ -107,16 +123,15 @@ fn walk_node<'py>(
107123
let Some(dict) = get_instance_dict_fast(node) else {
108124
return Ok(());
109125
};
110-
let values = BorrowedDictIter::new(dict);
111-
for item_ptr in values {
112-
if isinstance_of_ast(item_ptr, all_ast_classes) {
113-
walk_node(item_ptr, all_ast_classes, py_list_type, result_list)?;
114-
} else if isinstance_of_list(item_ptr, py_list_type) {
126+
for item_ptr in unsafe { DictValuesIter::new(dict.cast::<PyDictObject>()) } {
127+
if isinstance_of_ast(item_ptr, base_ast_type) {
128+
walk_node(item_ptr, base_ast_type, py_list_type, result_list)?;
129+
} else if is_list(item_ptr, py_list_type) {
115130
let length = get_length_of_list(item_ptr);
116131
for i in 0..length {
117132
let item_ptr = get_item_of_list(item_ptr, i);
118-
if isinstance_of_ast(item_ptr, all_ast_classes) {
119-
walk_node(item_ptr, all_ast_classes, py_list_type, result_list)?;
133+
if isinstance_of_ast(item_ptr, base_ast_type) {
134+
walk_node(item_ptr, base_ast_type, py_list_type, result_list)?;
120135
}
121136
}
122137
}
@@ -126,46 +141,36 @@ fn walk_node<'py>(
126141
}
127142

128143
thread_local! {
129-
static AST_CLASSES: RefCell<Option<Set>> = RefCell::new(None);
144+
static BASE_AST_TYPE: RefCell<Option<*mut PyTypeObject>> = const { RefCell::new(None) };
130145
}
131146

132147
#[inline(never)]
133-
fn compute_ast_classes<'py>(py: Python<'py>) -> PyResult<Set> {
134-
let mut classes = Vec::new();
148+
fn get_base_ast_type<'py>(py: Python<'py>) -> PyResult<*mut PyTypeObject> {
135149
let ast_module = py.import("ast")?;
136150
let ast_class = ast_module.getattr("AST")?.cast_into::<PyType>()?;
137151

138-
for field in ast_module.dir()? {
139-
let class = ast_module.getattr(field.cast_exact::<PyString>()?)?;
140-
if let Ok(class) = class.cast_into::<PyType>() {
141-
if class.is_subclass(ast_class.cast::<PyType>()?)? {
142-
classes.push(class.as_type_ptr() as usize);
143-
}
144-
}
145-
}
146-
147-
Ok(Set::from(classes))
152+
Ok(ast_class.as_type_ptr())
148153
}
149154

150155
#[pyfunction]
151156
fn walk<'py>(py: Python, node: Bound<'py, PyAny>) -> PyResult<Py<PyList>> {
152157
let mut result_list = Vec::new();
153158

154159
// Initialize if needed (separate step with mutable borrow)
155-
AST_CLASSES.with(|cache| {
160+
BASE_AST_TYPE.with(|cache| {
156161
if cache.borrow().is_none() {
157-
*cache.borrow_mut() = Some(compute_ast_classes(py)?);
162+
*cache.borrow_mut() = Some(get_base_ast_type(py)?);
158163
}
159164
Ok::<(), PyErr>(())
160165
})?;
161166

162167
// Now use immutable borrow for the actual work
163-
AST_CLASSES.with(|cache| {
168+
BASE_AST_TYPE.with(|cache| {
164169
let cache_ref = cache.borrow();
165-
let all_ast_classes = cache_ref.as_ref().unwrap();
170+
let base_ast_type = cache_ref.as_ref().unwrap();
166171
walk_node(
167172
node.as_ptr(),
168-
all_ast_classes,
173+
base_ast_type,
169174
PyList::type_object_raw(py),
170175
&mut result_list,
171176
)?;
@@ -186,3 +191,38 @@ fn fast_walk(m: &Bound<'_, PyModule>) -> PyResult<()> {
186191
m.add_function(wrap_pyfunction!(walk, m)?)?;
187192
Ok(())
188193
}
194+
195+
#[cfg(test)]
196+
mod tests {
197+
use pyo3::types::PyDict;
198+
199+
use super::*;
200+
201+
#[test]
202+
fn test_empty_dict_no_values() {
203+
Python::initialize();
204+
205+
Python::attach(|py| {
206+
let dict = PyDict::new(py);
207+
let dict_ptr = dict.as_ptr() as *mut pyo3::ffi::PyDictObject;
208+
let values = unsafe { DictValuesIter::new(dict_ptr) }.collect::<Vec<_>>();
209+
assert_eq!(values.len(), 0);
210+
});
211+
}
212+
213+
#[test]
214+
fn test_string_keys_dict_values() {
215+
Python::initialize();
216+
217+
Python::attach(|py| {
218+
let dict = PyDict::new(py);
219+
dict.set_item("a", 1).unwrap();
220+
dict.set_item("b", 2).unwrap();
221+
dict.set_item("c", 3).unwrap();
222+
223+
let dict_ptr = dict.as_ptr() as *mut pyo3::ffi::PyDictObject;
224+
let values = unsafe { DictValuesIter::new(dict_ptr) }.collect::<Vec<_>>();
225+
assert_eq!(values.len(), 3);
226+
});
227+
}
228+
}

0 commit comments

Comments
 (0)