Skip to content

Commit a5f01ee

Browse files
committed
feat: enhance cwd resolution and model duplicates
Be permissive when access to a field when has a duplicated model from another app, to handle those cases we need to track each app separately increasing the complexity. For now is good enough to be permissive when two models has the same name but in different apps or modules. Avoid considering `Base` as a django model base model, that is wrong because `Base` is a common name for a bunch of applications and can override the original django model incorrectly. Aditionally prioritize cwd using `manage.py` as a marker.
1 parent 527c413 commit a5f01ee

File tree

8 files changed

+249
-38
lines changed

8 files changed

+249
-38
lines changed

.github/workflows/pypi-release.yml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -137,12 +137,6 @@ jobs:
137137
runs-on: ubuntu-latest
138138
steps:
139139
- uses: actions/checkout@v6
140-
- name: Replace symlinks with actual files
141-
run: |
142-
cd crates/django-check
143-
rm -f LICENSE README.md
144-
cp ../../LICENSE LICENSE
145-
cp ../../README.md README.md
146140
- name: Build sdist
147141
uses: PyO3/maturin-action@v1
148142
with:

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/django-check/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "django-check"
3-
version = "0.1.7"
3+
version = "0.1.8"
44
rust-version = { workspace = true }
55
license = { workspace = true }
66
authors = { workspace = true }

crates/django-check/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "maturin"
44

55
[project]
66
name = "djch"
7-
version = "0.1.7"
7+
version = "0.1.8"
88
description = "Static N+1 query detector for Django ORM – fast Rust-powered checker"
99
# For PyPI long description (shows formatted on project page)
1010
readme = {file = "README.md", content-type = "text/markdown"}

crates/django-check/src/main.rs

Lines changed: 73 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
mod cli;
2-
use std::path::Path;
2+
use std::path::{Path, PathBuf};
33

44
use django_check_semantic::Parser;
55
use django_check_server::serve;
@@ -12,29 +12,97 @@ use tracing_subscriber::{EnvFilter, filter::LevelFilter};
1212
#[tokio::main]
1313
async fn main() -> Result<(), Box<dyn std::error::Error>> {
1414
let cwd = current_dir().expect("should get the cwd");
15+
let root = resolve_project_root(&cwd);
1516
let parser = Parser::new();
1617

17-
let model_graph = parser.extract_model_graph(&cwd)?;
18-
let functions = parser.extract_functions(&cwd)?;
18+
let model_graph = parser.extract_model_graph(&root)?;
19+
let functions = parser.extract_functions(&root)?;
1920

2021
let cli = Cli::parse();
2122

2223
match cli.cmd {
2324
Cmd::Check => {
2425
initialize_logger(false);
25-
if let Err(e) = parser.analyze_directory(&cwd, &model_graph, &functions) {
26+
if let Err(e) = parser.analyze_directory(&root, &model_graph, &functions) {
2627
eprintln!("Error: {}", e);
2728
}
2829
}
2930
Cmd::Server => {
3031
initialize_logger(true);
31-
serve(&cwd, model_graph, functions).await;
32+
serve(&root, model_graph, functions).await;
3233
}
3334
}
3435

3536
Ok(())
3637
}
3738

39+
fn resolve_project_root(cwd: &Path) -> PathBuf {
40+
// Prefer the current directory when it already looks like a Django project.
41+
if cwd.join("manage.py").is_file() {
42+
return cwd.to_path_buf();
43+
}
44+
45+
// If we're one level above the app folder, pick that folder automatically.
46+
let nested_candidates = std::fs::read_dir(cwd)
47+
.ok()
48+
.into_iter()
49+
.flat_map(|entries| entries.filter_map(Result::ok))
50+
.map(|entry| entry.path())
51+
.filter(|path| path.is_dir() && path.join("manage.py").is_file())
52+
.collect::<Vec<_>>();
53+
54+
if nested_candidates.len() == 1 {
55+
return nested_candidates[0].clone();
56+
}
57+
58+
cwd.to_path_buf()
59+
}
60+
61+
#[cfg(test)]
62+
mod tests {
63+
use super::resolve_project_root;
64+
use std::path::Path;
65+
use std::time::{SystemTime, UNIX_EPOCH};
66+
67+
fn unique_suffix() -> u128 {
68+
SystemTime::now()
69+
.duration_since(UNIX_EPOCH)
70+
.expect("clock should be after unix epoch")
71+
.as_nanos()
72+
}
73+
74+
#[test]
75+
fn prefers_current_directory_when_manage_py_exists() {
76+
let base = std::env::temp_dir().join(format!(
77+
"djch-root-current-{}-{}",
78+
std::process::id(),
79+
unique_suffix()
80+
));
81+
std::fs::create_dir_all(&base).expect("create temp dir");
82+
std::fs::write(base.join("manage.py"), "").expect("create manage.py");
83+
84+
assert_eq!(resolve_project_root(&base), base);
85+
86+
std::fs::remove_dir_all(&base).expect("cleanup temp dir");
87+
}
88+
89+
#[test]
90+
fn picks_single_nested_django_directory() {
91+
let base = std::env::temp_dir().join(format!(
92+
"djch-root-nested-{}-{}",
93+
std::process::id(),
94+
unique_suffix()
95+
));
96+
let app = base.join("app");
97+
std::fs::create_dir_all(&app).expect("create app dir");
98+
std::fs::write(app.join("manage.py"), "").expect("create nested manage.py");
99+
100+
assert_eq!(resolve_project_root(Path::new(&base)), app);
101+
102+
std::fs::remove_dir_all(&base).expect("cleanup temp dir");
103+
}
104+
}
105+
38106
fn initialize_logger(with_file: bool) {
39107
let env_filter = EnvFilter::from_default_env();
40108

crates/django-check_semantic/src/parser.rs

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,12 @@ impl Parser {
3636
fn analyze_n_plus_one_file(
3737
&self,
3838
file: &Path,
39+
display_file: &str,
3940
model_graph: &ModelGraph,
4041
functions: &[QueryFunction],
4142
) -> Result<Vec<NPlusOneDiagnostic>, SourceParseError> {
4243
let source = fs::read_to_string(file)?;
43-
self.analyze_source(
44-
&source,
45-
file.to_str().expect("conver to str"),
46-
model_graph,
47-
functions,
48-
)
44+
self.analyze_source(&source, display_file, model_graph, functions)
4945
}
5046

5147
/// Run N+1 detection on a file or a directory, this function figure out
@@ -61,11 +57,10 @@ impl Parser {
6157

6258
if path.is_dir() {
6359
for entry in Self::python_files(path) {
60+
let display_path = Self::relative_path(path, entry.path());
6461
let diagnostics = match self.analyze_n_plus_one_file(
65-
entry
66-
.path()
67-
.strip_prefix(path)
68-
.expect("path derived from prefix"),
62+
entry.path(),
63+
&display_path,
6964
model_graph,
7065
functions,
7166
) {
@@ -78,7 +73,9 @@ impl Parser {
7873
all_diagnostics.extend(diagnostics);
7974
}
8075
} else if path.is_file() {
81-
all_diagnostics = self.analyze_n_plus_one_file(path, model_graph, functions)?;
76+
let display_path = path.to_string_lossy().to_string();
77+
all_diagnostics =
78+
self.analyze_n_plus_one_file(path, &display_path, model_graph, functions)?;
8279
}
8380

8481
Ok(all_diagnostics)
@@ -167,7 +164,8 @@ impl Parser {
167164
model_graph: &ModelGraph,
168165
functions: &[QueryFunction],
169166
) -> Result<Vec<NPlusOneDiagnostic>, SourceParseError> {
170-
self.analyze_n_plus_one_file(file, model_graph, functions)
167+
let display_path = file.to_string_lossy().to_string();
168+
self.analyze_n_plus_one_file(file, &display_path, model_graph, functions)
171169
}
172170

173171
/// Run all analyses in a source code and return diagnostics

crates/django-check_semantic/src/passes/model_graph.rs

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,9 @@ impl<'a> ModelGraphPass<'a> {
7272
// Direct Model import or custom base class
7373
Expr::Name(name) => {
7474
let n = name.id.as_str();
75-
// Common patterns: Model, BaseModel, AbstractModel, etc.
76-
if n == "Model" || n.ends_with("Model") || n.ends_with("Base") {
75+
// Common Django patterns: Model, BaseModel, AbstractModel, etc.
76+
// Avoid generic `Base` classes (e.g. SQLAlchemy declarative base).
77+
if n == "Model" || n.ends_with("Model") {
7778
return true;
7879
}
7980
}
@@ -226,16 +227,30 @@ fn extract_related_name(call: &ruff_python_ast::ExprCall) -> Option<String> {
226227

227228
/// Analyze an Stmt and capture the relation if it exists
228229
fn extract_relation_from_stmt(stmt: &Stmt, model_name: &str) -> Option<Relation> {
229-
let Stmt::Assign(assign) = stmt else {
230-
return None;
231-
};
232-
233-
let Some(Expr::Name(target)) = assign.targets.first() else {
234-
return None;
235-
};
236-
let field_name = target.id.to_string();
230+
match stmt {
231+
Stmt::Assign(assign) => {
232+
let Some(Expr::Name(target)) = assign.targets.first() else {
233+
return None;
234+
};
235+
extract_relation_from_value(&target.id, assign.value.as_ref(), model_name)
236+
}
237+
Stmt::AnnAssign(assign) => {
238+
let Expr::Name(target) = assign.target.as_ref() else {
239+
return None;
240+
};
241+
let value = assign.value.as_ref()?;
242+
extract_relation_from_value(&target.id, value, model_name)
243+
}
244+
_ => None,
245+
}
246+
}
237247

238-
let Expr::Call(call) = assign.value.as_ref() else {
248+
fn extract_relation_from_value(
249+
field_name: &str,
250+
value: &Expr,
251+
model_name: &str,
252+
) -> Option<Relation> {
253+
let Expr::Call(call) = value else {
239254
return None;
240255
};
241256

@@ -245,7 +260,7 @@ fn extract_relation_from_stmt(stmt: &Stmt, model_name: &str) -> Option<Relation>
245260

246261
Some(Relation::new(
247262
model_name,
248-
field_name,
263+
field_name.to_string(),
249264
target_model,
250265
relation_type,
251266
related_name,
@@ -382,6 +397,21 @@ class Article(models.Model):
382397
assert_eq!(article.relations[1].relation_type, RelationType::ManyToMany);
383398
}
384399

400+
#[test]
401+
fn extract_annotated_relation_fields() {
402+
let source = r#"
403+
class Profile(models.Model):
404+
user: User = models.ForeignKey(User, on_delete=models.CASCADE)
405+
"#;
406+
let graph = run_pass(source);
407+
let profile = graph.get("Profile").expect("Profile should exist");
408+
409+
assert_eq!(profile.relations.len(), 1);
410+
assert_eq!(profile.relations[0].field_name, "user");
411+
assert_eq!(profile.relations[0].target_model, "User");
412+
assert_eq!(profile.relations[0].relation_type, RelationType::ForeignKey);
413+
}
414+
385415
#[test]
386416
fn dependents_query() {
387417
let source = r#"
@@ -513,4 +543,18 @@ class CreatedOrder(Order):
513543
&& r.related_name("CreatedOrder") == "createdorders")
514544
);
515545
}
546+
547+
#[test]
548+
fn doesnt_extract_sqlalchemy_base_models() {
549+
let source = r#"
550+
from sqlalchemy.ext.declarative import declarative_base
551+
552+
Base = declarative_base()
553+
554+
class User(Base):
555+
pass
556+
"#;
557+
let graph = run_pass(source);
558+
assert_eq!(graph.model_count(), 0);
559+
}
516560
}

0 commit comments

Comments
 (0)