Skip to content

Commit e5fcee1

Browse files
committed
Guard HIR lowered contracts with contract_checks
Refactor contract HIR lowering to ensure no contract code is executed when contract-checks are disabled. The call to contract_checks is moved to inside the lowered fn body, and contract closures are built conditionally, ensuring no side-effects present in contracts occur when those are disabled.
1 parent ae12bc2 commit e5fcee1

File tree

18 files changed

+459
-120
lines changed

18 files changed

+459
-120
lines changed
Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
use crate::LoweringContext;
2+
3+
impl<'a, 'hir> LoweringContext<'a, 'hir> {
4+
pub(super) fn lower_contract(
5+
&mut self,
6+
body: impl FnOnce(&mut Self) -> rustc_hir::Expr<'hir>,
7+
contract: &rustc_ast::FnContract,
8+
) -> rustc_hir::Expr<'hir> {
9+
match (&contract.requires, &contract.ensures) {
10+
(Some(req), Some(ens)) => {
11+
// Lower the fn contract, which turns:
12+
//
13+
// { body }
14+
//
15+
// into:
16+
//
17+
// {
18+
// let __postcond = if contracts_checks() {
19+
// contract_check_requires(PRECOND);
20+
// Some(|ret_val| POSTCOND)
21+
// } else {
22+
// None
23+
// };
24+
// contract_check_ensures(__postcond, { body })
25+
// }
26+
27+
let precond = self.lower_precond(req);
28+
let postcond_checker = self.lower_postcond_checker(ens);
29+
30+
let contract_check =
31+
self.lower_contract_check_with_postcond(Some(precond), postcond_checker);
32+
33+
let wrapped_body =
34+
self.wrap_body_with_contract_check(body, contract_check, postcond_checker.span);
35+
self.expr_block(wrapped_body)
36+
}
37+
(None, Some(ens)) => {
38+
// Lower the fn contract, which turns:
39+
//
40+
// { body }
41+
//
42+
// into:
43+
//
44+
// {
45+
// let __postcond = if contracts_check() {
46+
// Some(|ret_val| POSTCOND)
47+
// } else {
48+
// None
49+
// };
50+
// __postcond({ body })
51+
// }
52+
53+
let postcond_checker = self.lower_postcond_checker(ens);
54+
let contract_check =
55+
self.lower_contract_check_with_postcond(None, postcond_checker);
56+
57+
let wrapped_body =
58+
self.wrap_body_with_contract_check(body, contract_check, postcond_checker.span);
59+
self.expr_block(wrapped_body)
60+
}
61+
(Some(req), None) => {
62+
// Lower the fn contract, which turns:
63+
//
64+
// { body }
65+
//
66+
// into:
67+
//
68+
// {
69+
// if contracts_check() {
70+
// contract_requires(PRECOND);
71+
// }
72+
// body
73+
// }
74+
let precond = self.lower_precond(req);
75+
let precond_check = self.lower_contract_check_just_precond(precond);
76+
77+
let body = self.arena.alloc(body(self));
78+
79+
// Flatten the body into precond check, then body.
80+
let wrapped_body = self.block_all(
81+
body.span,
82+
self.arena.alloc_from_iter([precond_check].into_iter()),
83+
Some(body),
84+
);
85+
self.expr_block(wrapped_body)
86+
}
87+
(None, None) => body(self),
88+
}
89+
}
90+
91+
/// Lower the precondition check intrinsic.
92+
fn lower_precond(&mut self, req: &Box<rustc_ast::Expr>) -> rustc_hir::Stmt<'hir> {
93+
let lowered_req = self.lower_expr_mut(&req);
94+
let req_span = self.mark_span_with_reason(
95+
rustc_span::DesugaringKind::Contract,
96+
lowered_req.span,
97+
None,
98+
);
99+
let precond = self.expr_call_lang_item_fn_mut(
100+
req_span,
101+
rustc_hir::LangItem::ContractCheckRequires,
102+
&*arena_vec![self; lowered_req],
103+
);
104+
self.stmt_expr(req.span, precond)
105+
}
106+
107+
fn lower_postcond_checker(
108+
&mut self,
109+
ens: &Box<rustc_ast::Expr>,
110+
) -> &'hir rustc_hir::Expr<'hir> {
111+
let ens_span = self.lower_span(ens.span);
112+
let ens_span =
113+
self.mark_span_with_reason(rustc_span::DesugaringKind::Contract, ens_span, None);
114+
let lowered_ens = self.lower_expr_mut(&ens);
115+
self.expr_call_lang_item_fn(
116+
ens_span,
117+
rustc_hir::LangItem::ContractBuildCheckEnsures,
118+
&*arena_vec![self; lowered_ens],
119+
)
120+
}
121+
122+
fn lower_contract_check_just_precond(
123+
&mut self,
124+
precond: rustc_hir::Stmt<'hir>,
125+
) -> rustc_hir::Stmt<'hir> {
126+
let stmts = self.arena.alloc_from_iter([precond].into_iter());
127+
128+
let then_block_stmts = self.block_all(precond.span, stmts, None);
129+
let then_block = self.arena.alloc(self.expr_block(&then_block_stmts));
130+
131+
let precond_check = rustc_hir::ExprKind::If(
132+
self.expr_call_lang_item_fn(
133+
precond.span,
134+
rustc_hir::LangItem::ContractChecks,
135+
Default::default(),
136+
),
137+
then_block,
138+
None,
139+
);
140+
141+
let precond_check = self.expr(precond.span, precond_check);
142+
self.stmt_expr(precond.span, precond_check)
143+
}
144+
145+
fn lower_contract_check_with_postcond(
146+
&mut self,
147+
precond: Option<rustc_hir::Stmt<'hir>>,
148+
postcond_checker: &'hir rustc_hir::Expr<'hir>,
149+
) -> &'hir rustc_hir::Expr<'hir> {
150+
let stmts = self.arena.alloc_from_iter(precond.into_iter());
151+
let span = match precond {
152+
Some(precond) => precond.span,
153+
None => postcond_checker.span,
154+
};
155+
156+
let postcond_checker = self.arena.alloc(self.expr_enum_variant_lang_item(
157+
postcond_checker.span,
158+
rustc_hir::lang_items::LangItem::OptionSome,
159+
&*arena_vec![self; *postcond_checker],
160+
));
161+
let then_block_stmts = self.block_all(span, stmts, Some(postcond_checker));
162+
let then_block = self.arena.alloc(self.expr_block(&then_block_stmts));
163+
164+
let none_expr = self.arena.alloc(self.expr_enum_variant_lang_item(
165+
postcond_checker.span,
166+
rustc_hir::lang_items::LangItem::OptionNone,
167+
Default::default(),
168+
));
169+
let else_block = self.block_expr(none_expr);
170+
let else_block = self.arena.alloc(self.expr_block(else_block));
171+
172+
let contract_check = rustc_hir::ExprKind::If(
173+
self.expr_call_lang_item_fn(
174+
span,
175+
rustc_hir::LangItem::ContractChecks,
176+
Default::default(),
177+
),
178+
then_block,
179+
Some(else_block),
180+
);
181+
self.arena.alloc(self.expr(span, contract_check))
182+
}
183+
184+
fn wrap_body_with_contract_check(
185+
&mut self,
186+
body: impl FnOnce(&mut Self) -> rustc_hir::Expr<'hir>,
187+
contract_check: &'hir rustc_hir::Expr<'hir>,
188+
postcond_span: rustc_span::Span,
189+
) -> &'hir rustc_hir::Block<'hir> {
190+
let check_ident: rustc_span::Ident =
191+
rustc_span::Ident::from_str_and_span("__ensures_checker", postcond_span);
192+
let (check_hir_id, postcond_decl) = {
193+
// Set up the postcondition `let` statement.
194+
let (checker_pat, check_hir_id) = self.pat_ident_binding_mode_mut(
195+
postcond_span,
196+
check_ident,
197+
rustc_hir::BindingMode::NONE,
198+
);
199+
(
200+
check_hir_id,
201+
self.stmt_let_pat(
202+
None,
203+
postcond_span,
204+
Some(contract_check),
205+
self.arena.alloc(checker_pat),
206+
rustc_hir::LocalSource::Contract,
207+
),
208+
)
209+
};
210+
211+
// Install contract_ensures so we will intercept `return` statements,
212+
// then lower the body.
213+
self.contract_ensures = Some((postcond_span, check_ident, check_hir_id));
214+
let body = self.arena.alloc(body(self));
215+
216+
// Finally, inject an ensures check on the implicit return of the body.
217+
let body = self.inject_ensures_check(body, postcond_span, check_ident, check_hir_id);
218+
219+
// Flatten the body into precond, then postcond, then wrapped body.
220+
let wrapped_body = self.block_all(
221+
body.span,
222+
self.arena.alloc_from_iter([postcond_decl].into_iter()),
223+
Some(body),
224+
);
225+
wrapped_body
226+
}
227+
228+
/// Create an `ExprKind::Ret` that is optionally wrapped by a call to check
229+
/// a contract ensures clause, if it exists.
230+
pub(super) fn checked_return(
231+
&mut self,
232+
opt_expr: Option<&'hir rustc_hir::Expr<'hir>>,
233+
) -> rustc_hir::ExprKind<'hir> {
234+
let checked_ret =
235+
if let Some((check_span, check_ident, check_hir_id)) = self.contract_ensures {
236+
let expr = opt_expr.unwrap_or_else(|| self.expr_unit(check_span));
237+
Some(self.inject_ensures_check(expr, check_span, check_ident, check_hir_id))
238+
} else {
239+
opt_expr
240+
};
241+
rustc_hir::ExprKind::Ret(checked_ret)
242+
}
243+
244+
/// Wraps an expression with a call to the ensures check before it gets returned.
245+
pub(super) fn inject_ensures_check(
246+
&mut self,
247+
expr: &'hir rustc_hir::Expr<'hir>,
248+
span: rustc_span::Span,
249+
cond_ident: rustc_span::Ident,
250+
cond_hir_id: rustc_hir::HirId,
251+
) -> &'hir rustc_hir::Expr<'hir> {
252+
let cond_fn = self.expr_ident(span, cond_ident, cond_hir_id);
253+
let call_expr = self.expr_call_lang_item_fn_mut(
254+
span,
255+
rustc_hir::LangItem::ContractCheckEnsures,
256+
arena_vec![self; *cond_fn, *expr],
257+
);
258+
self.arena.alloc(call_expr)
259+
}
260+
}

compiler/rustc_ast_lowering/src/expr.rs

Lines changed: 53 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -383,36 +383,6 @@ impl<'hir> LoweringContext<'_, 'hir> {
383383
})
384384
}
385385

386-
/// Create an `ExprKind::Ret` that is optionally wrapped by a call to check
387-
/// a contract ensures clause, if it exists.
388-
fn checked_return(&mut self, opt_expr: Option<&'hir hir::Expr<'hir>>) -> hir::ExprKind<'hir> {
389-
let checked_ret =
390-
if let Some((check_span, check_ident, check_hir_id)) = self.contract_ensures {
391-
let expr = opt_expr.unwrap_or_else(|| self.expr_unit(check_span));
392-
Some(self.inject_ensures_check(expr, check_span, check_ident, check_hir_id))
393-
} else {
394-
opt_expr
395-
};
396-
hir::ExprKind::Ret(checked_ret)
397-
}
398-
399-
/// Wraps an expression with a call to the ensures check before it gets returned.
400-
pub(crate) fn inject_ensures_check(
401-
&mut self,
402-
expr: &'hir hir::Expr<'hir>,
403-
span: Span,
404-
cond_ident: Ident,
405-
cond_hir_id: HirId,
406-
) -> &'hir hir::Expr<'hir> {
407-
let cond_fn = self.expr_ident(span, cond_ident, cond_hir_id);
408-
let call_expr = self.expr_call_lang_item_fn_mut(
409-
span,
410-
hir::LangItem::ContractCheckEnsures,
411-
arena_vec![self; *cond_fn, *expr],
412-
);
413-
self.arena.alloc(call_expr)
414-
}
415-
416386
pub(crate) fn lower_const_block(&mut self, c: &AnonConst) -> hir::ConstBlock {
417387
self.with_new_scopes(c.value.span, |this| {
418388
let def_id = this.local_def_id(c.id);
@@ -2120,7 +2090,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
21202090
self.expr(span, hir::ExprKind::AddrOf(hir::BorrowKind::Ref, hir::Mutability::Mut, e))
21212091
}
21222092

2123-
fn expr_unit(&mut self, sp: Span) -> &'hir hir::Expr<'hir> {
2093+
pub(super) fn expr_unit(&mut self, sp: Span) -> &'hir hir::Expr<'hir> {
21242094
self.arena.alloc(self.expr(sp, hir::ExprKind::Tup(&[])))
21252095
}
21262096

@@ -2161,6 +2131,43 @@ impl<'hir> LoweringContext<'_, 'hir> {
21612131
self.expr(span, hir::ExprKind::Call(e, args))
21622132
}
21632133

2134+
pub(super) fn expr_struct(
2135+
&mut self,
2136+
span: Span,
2137+
path: &'hir hir::QPath<'hir>,
2138+
fields: &'hir [hir::ExprField<'hir>],
2139+
) -> hir::Expr<'hir> {
2140+
self.expr(span, hir::ExprKind::Struct(path, fields, rustc_hir::StructTailExpr::None))
2141+
}
2142+
2143+
pub(super) fn expr_enum_variant(
2144+
&mut self,
2145+
span: Span,
2146+
path: &'hir hir::QPath<'hir>,
2147+
fields: &'hir [hir::Expr<'hir>],
2148+
) -> hir::Expr<'hir> {
2149+
let fields = self.arena.alloc_from_iter(fields.into_iter().enumerate().map(|(i, f)| {
2150+
hir::ExprField {
2151+
hir_id: self.next_id(),
2152+
ident: Ident::from_str(&i.to_string()),
2153+
expr: f,
2154+
span: f.span,
2155+
is_shorthand: false,
2156+
}
2157+
}));
2158+
self.expr_struct(span, path, fields)
2159+
}
2160+
2161+
pub(super) fn expr_enum_variant_lang_item(
2162+
&mut self,
2163+
span: Span,
2164+
lang_item: hir::LangItem,
2165+
fields: &'hir [hir::Expr<'hir>],
2166+
) -> hir::Expr<'hir> {
2167+
let path = self.arena.alloc(self.lang_item_path(span, lang_item));
2168+
self.expr_enum_variant(span, path, fields)
2169+
}
2170+
21642171
pub(super) fn expr_call(
21652172
&mut self,
21662173
span: Span,
@@ -2189,8 +2196,21 @@ impl<'hir> LoweringContext<'_, 'hir> {
21892196
self.arena.alloc(self.expr_call_lang_item_fn_mut(span, lang_item, args))
21902197
}
21912198

2192-
fn expr_lang_item_path(&mut self, span: Span, lang_item: hir::LangItem) -> hir::Expr<'hir> {
2193-
self.expr(span, hir::ExprKind::Path(hir::QPath::LangItem(lang_item, self.lower_span(span))))
2199+
pub(super) fn expr_lang_item_path(
2200+
&mut self,
2201+
span: Span,
2202+
lang_item: hir::LangItem,
2203+
) -> hir::Expr<'hir> {
2204+
let path = self.lang_item_path(span, lang_item);
2205+
self.expr(span, hir::ExprKind::Path(path))
2206+
}
2207+
2208+
pub(super) fn lang_item_path(
2209+
&mut self,
2210+
span: Span,
2211+
lang_item: hir::LangItem,
2212+
) -> hir::QPath<'hir> {
2213+
hir::QPath::LangItem(lang_item, self.lower_span(span))
21942214
}
21952215

21962216
/// `<LangItem>::name`

0 commit comments

Comments
 (0)