Skip to content

Commit 71a6b22

Browse files
committed
feat: introduce deep models chain detection
prefetchs like `relation__child__childchild` are supported where `child` is a child relation from `relation` and when we access to relation.child, the prefetch is detected for the relation, no warnings are triggered, same case for child.childchild
1 parent 333bab1 commit 71a6b22

File tree

4 files changed

+141
-71
lines changed

4 files changed

+141
-71
lines changed

README.md

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -178,17 +178,13 @@ for user in users:
178178

179179
* Interprocedural analysis requires type hints on QuerySet parameters
180180
* Limited understanding of:
181-
* `Prefetch` objects
181+
* `Prefetch` objects with basic access support (e.g. `Prefetch("related_model")`)
182182
* `annotate`, `aggregate`
183183
* complex custom managers
184-
* No complete capture of implicit prefetchs in django chains when has many. e.g.
185-
`prefetch_releated(chat.users__profile)`, then iterate over `chat.users` and
186-
access to `profile`. This will raise a warning.
187184

188185
## Roadmap
189186

190-
* Prefetch object support
191-
* Capture of implicits prefetchs in a root instance
187+
* Complete Prefetch object support
192188
* Custom queryset method summaries
193189
* Templates integration
194190

Lines changed: 70 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
//! Binding-level intermediate representation for tracking variable states.
22
3-
use std::collections::HashSet;
3+
use std::{collections::HashMap, str::Split};
4+
5+
use crate::{ModelGraph, ir::model::ModelDef};
46

57
#[derive(Debug, Clone, Copy, Eq, Hash, PartialEq)]
68
pub struct DjangoSymbolId(pub u32);
@@ -14,21 +16,21 @@ impl DjangoSymbolId {
1416
#[derive(Debug, Clone, Eq, PartialEq)]
1517
pub struct QuerySetState {
1618
pub model_name: String,
17-
pub prefetched_relations: HashSet<String>,
19+
pub prefetched_relations: HashMap<String, QuerySetState>,
1820
pub is_values_query: bool,
1921
}
2022

2123
impl QuerySetState {
2224
pub fn new(model_name: String) -> Self {
2325
Self {
2426
model_name,
25-
prefetched_relations: HashSet::new(),
27+
prefetched_relations: HashMap::new(),
2628
is_values_query: false,
2729
}
2830
}
2931

3032
pub fn is_access_safe(&self, relation: &str) -> bool {
31-
self.is_values_query || self.prefetched_relations.contains(relation)
33+
self.is_values_query || self.prefetched_relations.contains_key(relation)
3234
}
3335
}
3436

@@ -50,46 +52,78 @@ pub enum DjangoSymbolKind {
5052
Unknown,
5153
}
5254

53-
/// Parse Django relation field syntax (e.g., "author__profile" -> ["author", "profile"])
54-
pub fn parse_relation_fields(fields: &[String]) -> Vec<String> {
55-
fields
56-
.iter()
57-
.flat_map(|field| field.split("__"))
58-
.map(|s| s.to_string())
59-
.collect()
55+
/// Parse Django relation field syntax (e.g., "author__profile" ->
56+
/// [PrefetchedRelation(QuerySetState)] with `profile` prefetched)
57+
pub fn parse_relation_fields(
58+
model: &ModelDef,
59+
model_graph: &ModelGraph,
60+
fields: &[String],
61+
) -> HashMap<String, QuerySetState> {
62+
let mut relations = HashMap::new();
63+
for field in fields.iter() {
64+
let parts = field.split("__");
65+
if let Some((literal, relation)) = bind_relation(model, model_graph, parts) {
66+
relations.insert(literal, relation);
67+
}
68+
}
69+
70+
relations
71+
}
72+
73+
fn bind_relation(
74+
model: &ModelDef,
75+
model_graph: &ModelGraph,
76+
mut parts: Split<'_, &str>,
77+
) -> Option<(String, QuerySetState)> {
78+
if let Some(base) = parts.next()
79+
&& let Some(relation) = model_graph.get_relation(&model.name, base)
80+
&& let Some(related_model) = model_graph.get(relation)
81+
{
82+
let mut qs = QuerySetState::new(related_model.name.to_string());
83+
84+
// Insert child relation
85+
if let Some((literal, relation)) = bind_relation(related_model, model_graph, parts) {
86+
qs.prefetched_relations.insert(literal, relation);
87+
};
88+
return Some((base.to_string(), qs));
89+
}
90+
91+
None
6092
}
6193

6294
#[cfg(test)]
6395
mod tests {
96+
use crate::Parser;
97+
6498
use super::*;
6599

66-
#[test]
67-
fn parse_single_field() {
68-
let fields = vec!["ticker".to_string()];
69-
assert_eq!(parse_relation_fields(&fields), vec!["ticker"]);
70-
}
100+
fn get_graph() -> ModelGraph {
101+
let source = r#"
102+
class User(Model):
103+
pass
71104
72-
#[test]
73-
fn parse_nested_field() {
74-
let fields = vec!["theoanalysis__ticker".to_string()];
75-
assert_eq!(
76-
parse_relation_fields(&fields),
77-
vec!["theoanalysis", "ticker"]
78-
);
79-
}
105+
class Photo(Model):
106+
user = models.ForeignKey(User, related_name="photos")
80107
81-
#[test]
82-
fn parse_deeply_nested() {
83-
let fields = vec!["a__b__c__d".to_string()];
84-
assert_eq!(parse_relation_fields(&fields), vec!["a", "b", "c", "d"]);
85-
}
86108
87-
#[test]
88-
fn parse_multiple_fields() {
89-
let fields = vec!["ticker__sector".to_string(), "analysis__report".to_string()];
90-
assert_eq!(
91-
parse_relation_fields(&fields),
92-
vec!["ticker", "sector", "analysis", "report"]
93-
);
109+
class Order(Model):
110+
user = models.ForeignKey(User, related_name="orders")
111+
112+
class Order(Model):
113+
user = models.ForeignKey(User, related_name="orders")
114+
115+
class Sale(Model):
116+
order = models.ForeignKey(Sale, related_name="sales")
117+
118+
class Transaction(Model):
119+
sale = models.ForeignKey(Sale, related_name="transactions")
120+
121+
122+
users = User.objects.all().prefetch_related(Prefetch("orders"))
123+
print([user.orders for user in users])
124+
"#;
125+
126+
let parser = Parser::new();
127+
parser.build_graph(source, "test.py").unwrap()
94128
}
95129
}

crates/django-check_semantic/src/parser.rs

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -90,29 +90,33 @@ impl Parser {
9090

9191
for entry in Self::python_files(dir) {
9292
let filename = Self::relative_path(dir, entry.path());
93-
let source = match fs::read_to_string(entry.path()) {
93+
let source = fs::read_to_string(entry.path())?;
94+
let graph = match self.build_graph(&source, &filename) {
9495
Ok(source) => source,
9596
Err(err) => {
9697
warn!(%err, source=%filename, "parsing graph");
9798
continue;
9899
}
99100
};
100-
let parsed = match self.parse_module(&source) {
101-
Ok(parsed) => parsed,
102-
Err(err) => {
103-
warn!(%err, source=%filename, "parsing graph");
104-
continue;
105-
}
106-
};
107-
108-
let mut pass = ModelGraphPass::new(&filename, &source);
109-
let graph = pass.run(parsed.syntax());
110101
combined_graph.merge(graph);
111102
}
112103

113104
Ok(combined_graph)
114105
}
115106

107+
pub fn build_graph(
108+
&self,
109+
source: &str,
110+
filename: &str,
111+
) -> Result<ModelGraph, SourceParseError> {
112+
let parsed = self
113+
.parse_module(source)
114+
.map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidInput, "parsing module"))?;
115+
116+
let mut pass = ModelGraphPass::new(&filename, &source);
117+
Ok(pass.run(parsed.syntax()))
118+
}
119+
116120
pub fn extract_functions(&self, dir: &Path) -> Result<Vec<QueryFunction>, SourceParseError> {
117121
let mut functions = Vec::new();
118122
for entry in Self::python_files(dir) {

crates/django-check_semantic/src/passes/n_plus_one.rs

Lines changed: 54 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -154,21 +154,25 @@ impl<'a> QuerySetResolver<'a> {
154154
// Accessing `user.orders` -> Returns a QuerySet for the related model
155155
// Accessing an attribute on a queryset (e.g., `orders.users`)
156156
// If it's a relation on the model, return a QuerySet of the related model
157-
if let Some(related_model) = self
158-
.model_graph
159-
.get_relation(&state.model_name, &attr.attr.id)
160-
{
161-
let mut new_state = QuerySetState::new(related_model.to_string());
162-
// FIXME: Currently we use the current relations to be inherited by the related
163-
// model, probably a better idea is to detect dynamically the appropiates
164-
// prefetched relations but for now this is ok.
165-
new_state
166-
.prefetched_relations
167-
.extend(state.prefetched_relations);
168-
return Some(DjangoSymbol::QuerySet(new_state));
157+
// firs try to get the prefetched query set state, if it is not prefetched, create
158+
// an empty query set state
159+
match state.prefetched_relations.get(&attr.attr.id.to_string()) {
160+
Some(related_state) => {
161+
return Some(DjangoSymbol::QuerySet(related_state.clone()));
162+
}
163+
None => {
164+
if let Some(related_model) = self
165+
.model_graph
166+
.get_relation(&state.model_name, &attr.attr.id)
167+
{
168+
// Empty state, no prefetch related detected before for this model
169+
let new_state = QuerySetState::new(related_model.to_string());
170+
return Some(DjangoSymbol::QuerySet(new_state));
171+
}
172+
// Could be accessing a regular field, return None
173+
return None;
174+
}
169175
}
170-
// Could be accessing a regular field, return None
171-
return None;
172176
}
173177
}
174178
}
@@ -223,9 +227,15 @@ impl<'a> QuerySetResolver<'a> {
223227
}
224228
}
225229

226-
// Helper from binding.rs
227-
let relations = crate::ir::binding::parse_relation_fields(&literal_fields);
228-
state.prefetched_relations.extend(relations);
230+
// Resolve prefetched relations
231+
if let Some(model) = self.model_graph.get(&state.model_name) {
232+
let relations = crate::ir::binding::parse_relation_fields(
233+
model,
234+
self.model_graph,
235+
&literal_fields,
236+
);
237+
state.prefetched_relations.extend(relations);
238+
}
229239
}
230240
}
231241

@@ -838,8 +848,11 @@ for item in qs:
838848
#[test]
839849
fn safe_with_prefetch() {
840850
let source = r#"
851+
class App(Model):
852+
pass
853+
841854
class User(Model):
842-
related_field = models.ForeignKey("some.App")
855+
related_field = models.ForeignKey("App")
843856
844857
qs = User.objects.filter(active=True).prefetch_related('related_field')
845858
for item in qs:
@@ -986,6 +999,29 @@ for p in performances:
986999
assert_eq!(diags.len(), 1);
9871000
}
9881001

1002+
#[test]
1003+
fn allow_using_deep_chained_prefetch() {
1004+
let source = r#"
1005+
class Analysis(Model):
1006+
pattern = models.ForeignKey("Pattern", related_name="analyses")
1007+
1008+
class Performance(Model):
1009+
analysis = models.ForeignKey("Analysis", related_name="performances")
1010+
1011+
class Pattern(Model):
1012+
name = models.CharField(max_length=20)
1013+
1014+
1015+
patterns = Pattern.objects.prefetch_related("analyses__performances").all()
1016+
1017+
for p in patterns:
1018+
for a in p.analyses:
1019+
print(a.performances) # safe
1020+
"#;
1021+
let diags = run_pass(source);
1022+
assert!(diags.is_empty());
1023+
}
1024+
9891025
#[test]
9901026
fn second_level_related_access() {
9911027
let source = r#"

0 commit comments

Comments
 (0)