Skip to content

Commit 940e130

Browse files
Merge pull request #20793 from A4-Tacks/diag-paren-missing-unsafe
Fix missing parentheses for missing_unsafe
2 parents 4ae99f0 + a068ef8 commit 940e130

File tree

1 file changed

+38
-1
lines changed

1 file changed

+38
-1
lines changed

crates/ide-diagnostics/src/handlers/missing_unsafe.rs

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,12 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::MissingUnsafe) -> Option<Vec<Ass
5050

5151
let node_to_add_unsafe_block = pick_best_node_to_add_unsafe_block(&expr)?;
5252

53-
let replacement = format!("unsafe {{ {} }}", node_to_add_unsafe_block.text());
53+
let mut replacement = format!("unsafe {{ {} }}", node_to_add_unsafe_block.text());
54+
if let Some(expr) = ast::Expr::cast(node_to_add_unsafe_block.clone())
55+
&& needs_parentheses(&expr)
56+
{
57+
replacement = format!("({replacement})");
58+
}
5459
let edit = TextEdit::replace(node_to_add_unsafe_block.text_range(), replacement);
5560
let source_change = SourceChange::from_text_edit(
5661
d.node.file_id.original_file(ctx.sema.db).file_id(ctx.sema.db),
@@ -112,6 +117,17 @@ fn pick_best_node_to_add_unsafe_block(unsafe_expr: &ast::Expr) -> Option<SyntaxN
112117
None
113118
}
114119

120+
fn needs_parentheses(expr: &ast::Expr) -> bool {
121+
let node = expr.syntax();
122+
node.ancestors()
123+
.skip(1)
124+
.take_while(|it| it.text_range().start() == node.text_range().start())
125+
.map_while(ast::Expr::cast)
126+
.last()
127+
.and_then(|it| Some(it.syntax().parent()?.kind()))
128+
.is_some_and(|kind| ast::ExprStmt::can_cast(kind) || ast::StmtList::can_cast(kind))
129+
}
130+
115131
#[cfg(test)]
116132
mod tests {
117133
use crate::tests::{check_diagnostics, check_fix, check_no_fix};
@@ -570,6 +586,27 @@ fn main() {
570586
)
571587
}
572588

589+
#[test]
590+
fn needs_parentheses_for_unambiguous() {
591+
check_fix(
592+
r#"
593+
//- minicore: copy
594+
static mut STATIC_MUT: u8 = 0;
595+
596+
fn foo() -> u8 {
597+
STATIC_MUT$0 * 2
598+
}
599+
"#,
600+
r#"
601+
static mut STATIC_MUT: u8 = 0;
602+
603+
fn foo() -> u8 {
604+
(unsafe { STATIC_MUT }) * 2
605+
}
606+
"#,
607+
)
608+
}
609+
573610
#[test]
574611
fn ref_to_unsafe_expr() {
575612
check_fix(

0 commit comments

Comments
 (0)