Skip to content

Commit a969bf9

Browse files
committed
Extend the enum check to pointer and union reads
This change extends the previously added enum discriminant check to enums read through a union or a raw pointer. Until now, the check was only inserted when transmuting to an enum. Although I had hoped it would be, the check isn't yet inserted for calls to `MaybeUninit::assume_init`: the pass runs on polymorphic MIR and therefore doesn't yet have the information needed to know whether the type being read is an enum.
1 parent 90b6588 commit a969bf9

File tree

10 files changed

+366
-65
lines changed

10 files changed

+366
-65
lines changed

compiler/rustc_mir_transform/src/check_enums.rs

Lines changed: 211 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
use rustc_abi::{Scalar, Size, TagEncoding, Variants, WrappingRange};
1+
use rustc_abi::{HasDataLayout, Scalar, Size, TagEncoding, Variants, WrappingRange};
22
use rustc_hir::LangItem;
33
use rustc_index::IndexVec;
44
use rustc_middle::bug;
5-
use rustc_middle::mir::visit::Visitor;
5+
use rustc_middle::mir::visit::{NonMutatingUseContext, PlaceContext, Visitor};
66
use rustc_middle::mir::*;
7-
use rustc_middle::ty::layout::PrimitiveExt;
8-
use rustc_middle::ty::{self, Ty, TyCtxt, TypingEnv};
7+
use rustc_middle::ty::layout::{IntegerExt, PrimitiveExt};
8+
use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt, TypingEnv};
99
use rustc_session::Session;
1010
use tracing::debug;
1111

@@ -163,79 +163,202 @@ impl<'a, 'tcx> EnumFinder<'a, 'tcx> {
163163
fn into_found_enums(self) -> Vec<EnumCheckType<'tcx>> {
164164
self.enums
165165
}
166+
167+
/// Registers a new enum check in the finder.
168+
fn register_new_check(
169+
&mut self,
170+
enum_ty: Ty<'tcx>,
171+
enum_def: AdtDef<'tcx>,
172+
source_op: Operand<'tcx>,
173+
) {
174+
let Ok(enum_layout) = self.tcx.layout_of(self.typing_env.as_query_input(enum_ty)) else {
175+
return;
176+
};
177+
// If the operand is a pointer, we want to pass on the size of the pointee to the check,
178+
// as we will dereference the pointer and look at the value directly.
179+
let Ok(op_layout) = (if let ty::RawPtr(pointee_ty, _) =
180+
source_op.ty(self.local_decls, self.tcx).kind()
181+
{
182+
self.tcx.layout_of(self.typing_env.as_query_input(*pointee_ty))
183+
} else {
184+
self.tcx
185+
.layout_of(self.typing_env.as_query_input(source_op.ty(self.local_decls, self.tcx)))
186+
}) else {
187+
return;
188+
};
189+
190+
match enum_layout.variants {
191+
Variants::Empty if op_layout.is_uninhabited() => return,
192+
// Constructing an empty (uninhabited) enum from an inhabited value
193+
// is never correct.
194+
Variants::Empty => {
195+
// The enum layout is uninhabited but we construct it from something inhabited.
196+
// This is always UB.
197+
self.enums.push(EnumCheckType::Uninhabited);
198+
}
199+
// Construction of Single value enums is always fine.
200+
Variants::Single { .. } => {}
201+
// Construction of an enum with multiple variants but no niche optimizations.
202+
Variants::Multiple {
203+
tag_encoding: TagEncoding::Direct,
204+
tag: Scalar::Initialized { value, .. },
205+
..
206+
} => {
207+
let valid_discrs =
208+
enum_def.discriminants(self.tcx).map(|(_, discr)| discr.val).collect();
209+
210+
let discr = TyAndSize { ty: value.to_ty(self.tcx), size: value.size(&self.tcx) };
211+
self.enums.push(EnumCheckType::Direct {
212+
source_op: source_op.to_copy(),
213+
discr,
214+
op_size: op_layout.size,
215+
valid_discrs,
216+
});
217+
}
218+
// Construction of an enum with multiple variants and niche optimizations.
219+
Variants::Multiple {
220+
tag_encoding: TagEncoding::Niche { .. },
221+
tag: Scalar::Initialized { value, valid_range, .. },
222+
tag_field,
223+
..
224+
} => {
225+
let discr = TyAndSize { ty: value.to_ty(self.tcx), size: value.size(&self.tcx) };
226+
self.enums.push(EnumCheckType::WithNiche {
227+
source_op: source_op.to_copy(),
228+
discr,
229+
op_size: op_layout.size,
230+
offset: enum_layout.fields.offset(tag_field.as_usize()),
231+
valid_range,
232+
});
233+
}
234+
_ => return,
235+
}
236+
}
166237
}
167238

168239
impl<'a, 'tcx> Visitor<'tcx> for EnumFinder<'a, 'tcx> {
169-
fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
170-
if let Rvalue::Cast(CastKind::Transmute, op, ty) = rvalue {
171-
let ty::Adt(adt_def, _) = ty.kind() else {
172-
return;
173-
};
174-
if !adt_def.is_enum() {
240+
fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
241+
self.super_place(place, context, location);
242+
// We only want to emit this check on pointer reads.
243+
match context {
244+
PlaceContext::NonMutatingUse(
245+
NonMutatingUseContext::Copy
246+
| NonMutatingUseContext::Move
247+
| NonMutatingUseContext::SharedBorrow,
248+
) => {}
249+
_ => {
175250
return;
176251
}
252+
}
177253

178-
let Ok(enum_layout) = self.tcx.layout_of(self.typing_env.as_query_input(*ty)) else {
254+
if !place.is_indirect() {
255+
return;
256+
}
257+
// Get the place and type we visit.
258+
let pointer = Place::from(place.local);
259+
let pointer_ty = pointer.ty(self.local_decls, self.tcx).ty;
260+
261+
// We only want to check places based on raw pointers to enums or ManuallyDrop<Enum>.
262+
let &ty::RawPtr(pointee_ty, _) = pointer_ty.kind() else {
263+
return;
264+
};
265+
let ty::Adt(enum_adt_def, _) = pointee_ty.kind() else {
266+
return;
267+
};
268+
269+
let (enum_ty, enum_adt_def) = if enum_adt_def.is_enum() {
270+
(pointee_ty, enum_adt_def)
271+
} else if enum_adt_def.is_manually_drop() {
272+
// Find the type contained in the ManuallyDrop and check whether it is an enum.
273+
let Some((manual_drop_arg, adt_def)) =
274+
pointee_ty.walk().skip(1).next().map_or(None, |arg| {
275+
if let Some(ty) = arg.as_type()
276+
&& let ty::Adt(adt_def, _) = ty.kind()
277+
{
278+
Some((ty, adt_def))
279+
} else {
280+
None
281+
}
282+
})
283+
else {
179284
return;
180285
};
181-
let Ok(op_layout) = self
182-
.tcx
183-
.layout_of(self.typing_env.as_query_input(op.ty(self.local_decls, self.tcx)))
286+
287+
(manual_drop_arg, adt_def)
288+
} else {
289+
return;
290+
};
291+
// Exclude c_void.
292+
if enum_ty.is_c_void(self.tcx) {
293+
return;
294+
}
295+
296+
self.register_new_check(enum_ty, *enum_adt_def, Operand::Copy(*place));
297+
}
298+
299+
fn visit_projection_elem(
300+
&mut self,
301+
place_ref: PlaceRef<'tcx>,
302+
elem: PlaceElem<'tcx>,
303+
context: visit::PlaceContext,
304+
location: Location,
305+
) {
306+
self.super_projection_elem(place_ref, elem, context, location);
307+
// Check whether we are reading an enum or a ManuallyDrop<Enum> from a union.
308+
let ty::Adt(union_adt_def, _) = place_ref.ty(self.local_decls, self.tcx).ty.kind() else {
309+
return;
310+
};
311+
if !union_adt_def.is_union() {
312+
return;
313+
}
314+
let PlaceElem::Field(_, extracted_ty) = elem else {
315+
return;
316+
};
317+
let ty::Adt(enum_adt_def, _) = extracted_ty.kind() else {
318+
return;
319+
};
320+
let (enum_ty, enum_adt_def) = if enum_adt_def.is_enum() {
321+
(extracted_ty, enum_adt_def)
322+
} else if enum_adt_def.is_manually_drop() {
323+
// Find the type contained in the ManuallyDrop and check whether it is an enum.
324+
let Some((manual_drop_arg, adt_def)) =
325+
extracted_ty.walk().skip(1).next().map_or(None, |arg| {
326+
if let Some(ty) = arg.as_type()
327+
&& let ty::Adt(adt_def, _) = ty.kind()
328+
{
329+
Some((ty, adt_def))
330+
} else {
331+
None
332+
}
333+
})
184334
else {
185335
return;
186336
};
187337

188-
match enum_layout.variants {
189-
Variants::Empty if op_layout.is_uninhabited() => return,
190-
// Constructing an empty (uninhabited) enum from an inhabited value
191-
// is never correct.
192-
Variants::Empty => {
193-
// The enum layout is uninhabited but we construct it from something inhabited.
194-
// This is always UB.
195-
self.enums.push(EnumCheckType::Uninhabited);
196-
}
197-
// Construction of Single value enums is always fine.
198-
Variants::Single { .. } => {}
199-
// Construction of an enum with multiple variants but no niche optimizations.
200-
Variants::Multiple {
201-
tag_encoding: TagEncoding::Direct,
202-
tag: Scalar::Initialized { value, .. },
203-
..
204-
} => {
205-
let valid_discrs =
206-
adt_def.discriminants(self.tcx).map(|(_, discr)| discr.val).collect();
207-
208-
let discr =
209-
TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) };
210-
self.enums.push(EnumCheckType::Direct {
211-
source_op: op.to_copy(),
212-
discr,
213-
op_size: op_layout.size,
214-
valid_discrs,
215-
});
216-
}
217-
// Construction of an enum with multiple variants and niche optimizations.
218-
Variants::Multiple {
219-
tag_encoding: TagEncoding::Niche { .. },
220-
tag: Scalar::Initialized { value, valid_range, .. },
221-
tag_field,
222-
..
223-
} => {
224-
let discr =
225-
TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) };
226-
self.enums.push(EnumCheckType::WithNiche {
227-
source_op: op.to_copy(),
228-
discr,
229-
op_size: op_layout.size,
230-
offset: enum_layout.fields.offset(tag_field.as_usize()),
231-
valid_range,
232-
});
233-
}
234-
_ => return,
338+
(manual_drop_arg, adt_def)
339+
} else {
340+
return;
341+
};
342+
343+
self.register_new_check(
344+
enum_ty,
345+
*enum_adt_def,
346+
Operand::Copy(place_ref.to_place(self.tcx)),
347+
);
348+
}
349+
350+
fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
351+
if let Rvalue::Cast(CastKind::Transmute, op, ty) = rvalue {
352+
let ty::Adt(adt_def, _) = ty.kind() else {
353+
return;
354+
};
355+
if !adt_def.is_enum() {
356+
return;
235357
}
236358

237-
self.super_rvalue(rvalue, location);
359+
self.register_new_check(*ty, *adt_def, op.to_copy());
238360
}
361+
self.super_rvalue(rvalue, location);
239362
}
240363
}
241364

@@ -261,7 +384,7 @@ fn insert_discr_cast_to_u128<'tcx>(
261384
local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
262385
block_data: &mut BasicBlockData<'tcx>,
263386
source_op: Operand<'tcx>,
264-
discr: TyAndSize<'tcx>,
387+
mut discr: TyAndSize<'tcx>,
265388
op_size: Size,
266389
offset: Option<Size>,
267390
source_info: SourceInfo,
@@ -277,6 +400,29 @@ fn insert_discr_cast_to_u128<'tcx>(
277400
}
278401
};
279402

403+
// If the enum is behind a pointer, cast it to a *[const|mut] MaybeUninit<T> and then extract the discriminant through that.
404+
let source_op = if let ty::RawPtr(pointee_ty, mutbl) = source_op.ty(local_decls, tcx).kind()
405+
&& !discr.ty.is_raw_ptr()
406+
{
407+
let mu_ptr_ty = Ty::new_ptr(tcx, Ty::new_maybe_uninit(tcx, *pointee_ty), *mutbl);
408+
let mu_ptr_decl =
409+
local_decls.push(LocalDecl::with_source_info(mu_ptr_ty, source_info)).into();
410+
let rvalue = Rvalue::Cast(CastKind::Transmute, source_op, mu_ptr_ty);
411+
block_data.statements.push(Statement::new(
412+
source_info,
413+
StatementKind::Assign(Box::new((mu_ptr_decl, rvalue))),
414+
));
415+
416+
Operand::Copy(mu_ptr_decl.project_deeper(&[ProjectionElem::Deref], tcx))
417+
} else {
418+
source_op
419+
};
420+
421+
// If the discriminant type is a raw pointer, replace it with the pointer-sized integer type so that our casts to the discriminant type keep working.
422+
if discr.ty.is_raw_ptr() {
423+
discr.ty = tcx.data_layout().ptr_sized_integer().to_ty(tcx, false);
424+
}
425+
280426
let (cast_kind, discr_ty_bits) = if discr.size.bytes() < op_size.bytes() {
281427
// The discriminant is less wide than the operand, cast the operand into
282428
// [MaybeUninit; N] and then index into it.
@@ -350,7 +496,8 @@ fn insert_direct_enum_check<'tcx>(
350496
new_block: BasicBlock,
351497
) {
352498
// Insert a new target block that is branched to in case of an invalid discriminant.
353-
let invalid_discr_block_data = BasicBlockData::new(None, false);
499+
let invalid_discr_block_data =
500+
BasicBlockData::new(None, basic_blocks[current_block].is_cleanup);
354501
let invalid_discr_block = basic_blocks.push(invalid_discr_block_data);
355502
let block_data = &mut basic_blocks[current_block];
356503
let discr_place = insert_discr_cast_to_u128(

compiler/rustc_mir_transform/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -694,9 +694,9 @@ pub(crate) fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'
694694
body,
695695
&[
696696
// Add some UB checks before any UB gets optimized away.
697+
&check_enums::CheckEnums,
697698
&check_alignment::CheckAlignment,
698699
&check_null::CheckNull,
699-
&check_enums::CheckEnums,
700700
// Before inlining: trim down MIR with passes to reduce inlining work.
701701

702702
// Has to be done before inlining, otherwise actual call will be almost always inlined.
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//@ run-crash
2+
//@ compile-flags: -C debug-assertions
3+
//@ error-pattern: trying to construct an enum from an invalid value 0x1
4+
5+
#[allow(dead_code)]
6+
#[repr(u16)]
7+
#[derive(Copy, Clone)]
8+
enum Single {
9+
A,
10+
}
11+
12+
fn main() {
13+
let illegal_val: u16 = 1;
14+
let illegal_val_ptr = &raw const illegal_val;
15+
let foo: *const std::mem::ManuallyDrop<Single> =
16+
unsafe { std::mem::transmute(illegal_val_ptr) };
17+
18+
let val: Single = unsafe { std::mem::ManuallyDrop::into_inner(*foo) };
19+
println!("{}", val as u16);
20+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
//@ run-pass
2+
//@ compile-flags: -C debug-assertions
3+
4+
#[allow(dead_code)]
5+
#[repr(u16)]
6+
enum Single {
7+
A,
8+
}
9+
10+
fn main() {
11+
let illegal_val: u16 = 0;
12+
let illegal_val_ptr = &raw const illegal_val;
13+
let foo: *const std::mem::ManuallyDrop<Single> =
14+
unsafe { std::mem::transmute(illegal_val_ptr) };
15+
16+
let val: Single = unsafe { foo.cast::<Single>().read() };
17+
println!("{}", val as u16);
18+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
//@ run-crash
2+
//@ compile-flags: -C debug-assertions
3+
//@ error-pattern: trying to construct an enum from an invalid value 0x1
4+
5+
#[allow(dead_code)]
6+
#[repr(u16)]
7+
#[derive(Copy, Clone)]
8+
enum Single {
9+
A,
10+
}
11+
12+
fn main() {
13+
let illegal_val: u16 = 1;
14+
let illegal_val_ptr = &raw const illegal_val;
15+
let foo: *const Single = unsafe { std::mem::transmute(illegal_val_ptr) };
16+
17+
let val: Single = unsafe { *foo };
18+
println!("{}", val as u16);
19+
}

0 commit comments

Comments
 (0)