Skip to content

Commit 8ae0303

Browse files
committed
Extend the enum check to pointer and union reads
This change extends the previously added enum discriminant check to enums read through a union or pointer. At the moment we only insert the check when transmuting to an enum. Although I hoped for it, this check isn't yet inserted for calls to `MaybeUninit::assume_init`, because the pass is running on polymorphic MIR and thus doesn't have the information yet to know whether the type that is read is an enum.
1 parent fc5af18 commit 8ae0303

File tree

10 files changed

+366
-65
lines changed

10 files changed

+366
-65
lines changed

compiler/rustc_mir_transform/src/check_enums.rs

Lines changed: 211 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
use rustc_abi::{Scalar, Size, TagEncoding, Variants, WrappingRange};
1+
use rustc_abi::{HasDataLayout, Scalar, Size, TagEncoding, Variants, WrappingRange};
22
use rustc_hir::LangItem;
33
use rustc_index::IndexVec;
44
use rustc_middle::bug;
5-
use rustc_middle::mir::visit::Visitor;
5+
use rustc_middle::mir::visit::{NonMutatingUseContext, PlaceContext, Visitor};
66
use rustc_middle::mir::*;
7-
use rustc_middle::ty::layout::PrimitiveExt;
8-
use rustc_middle::ty::{self, Ty, TyCtxt, TypingEnv};
7+
use rustc_middle::ty::layout::{IntegerExt, PrimitiveExt};
8+
use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt, TypingEnv};
99
use rustc_session::Session;
1010
use tracing::debug;
1111

@@ -148,79 +148,202 @@ impl<'a, 'tcx> EnumFinder<'a, 'tcx> {
148148
fn into_found_enums(self) -> Vec<EnumCheckType<'tcx>> {
149149
self.enums
150150
}
151+
152+
/// Registers a new enum check in the finder.
153+
fn register_new_check(
154+
&mut self,
155+
enum_ty: Ty<'tcx>,
156+
enum_def: AdtDef<'tcx>,
157+
source_op: Operand<'tcx>,
158+
) {
159+
let Ok(enum_layout) = self.tcx.layout_of(self.typing_env.as_query_input(enum_ty)) else {
160+
return;
161+
};
162+
// If the operand is a pointer, we want to pass on the size of the operand to the check,
163+
// as we will dereference the pointer and look at the value directly.
164+
let Ok(op_layout) = (if let ty::RawPtr(pointee_ty, _) =
165+
source_op.ty(self.local_decls, self.tcx).kind()
166+
{
167+
self.tcx.layout_of(self.typing_env.as_query_input(*pointee_ty))
168+
} else {
169+
self.tcx
170+
.layout_of(self.typing_env.as_query_input(source_op.ty(self.local_decls, self.tcx)))
171+
}) else {
172+
return;
173+
};
174+
175+
match enum_layout.variants {
176+
Variants::Empty if op_layout.is_uninhabited() => return,
177+
// An empty enum that tries to be constructed from an inhabited value, this
178+
// is never correct.
179+
Variants::Empty => {
180+
// The enum layout is uninhabited but we construct it from sth inhabited.
181+
// This is always UB.
182+
self.enums.push(EnumCheckType::Uninhabited);
183+
}
184+
// Construction of Single value enums is always fine.
185+
Variants::Single { .. } => {}
186+
// Construction of an enum with multiple variants but no niche optimizations.
187+
Variants::Multiple {
188+
tag_encoding: TagEncoding::Direct,
189+
tag: Scalar::Initialized { value, .. },
190+
..
191+
} => {
192+
let valid_discrs =
193+
enum_def.discriminants(self.tcx).map(|(_, discr)| discr.val).collect();
194+
195+
let discr = TyAndSize { ty: value.to_ty(self.tcx), size: value.size(&self.tcx) };
196+
self.enums.push(EnumCheckType::Direct {
197+
source_op: source_op.to_copy(),
198+
discr,
199+
op_size: op_layout.size,
200+
valid_discrs,
201+
});
202+
}
203+
// Construction of an enum with multiple variants and niche optimizations.
204+
Variants::Multiple {
205+
tag_encoding: TagEncoding::Niche { .. },
206+
tag: Scalar::Initialized { value, valid_range, .. },
207+
tag_field,
208+
..
209+
} => {
210+
let discr = TyAndSize { ty: value.to_ty(self.tcx), size: value.size(&self.tcx) };
211+
self.enums.push(EnumCheckType::WithNiche {
212+
source_op: source_op.to_copy(),
213+
discr,
214+
op_size: op_layout.size,
215+
offset: enum_layout.fields.offset(tag_field.as_usize()),
216+
valid_range,
217+
});
218+
}
219+
_ => return,
220+
}
221+
}
151222
}
152223

153224
impl<'a, 'tcx> Visitor<'tcx> for EnumFinder<'a, 'tcx> {
154-
fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
155-
if let Rvalue::Cast(CastKind::Transmute, op, ty) = rvalue {
156-
let ty::Adt(adt_def, _) = ty.kind() else {
157-
return;
158-
};
159-
if !adt_def.is_enum() {
225+
fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
226+
self.super_place(place, context, location);
227+
// We only want to emit this check on pointer reads.
228+
match context {
229+
PlaceContext::NonMutatingUse(
230+
NonMutatingUseContext::Copy
231+
| NonMutatingUseContext::Move
232+
| NonMutatingUseContext::SharedBorrow,
233+
) => {}
234+
_ => {
160235
return;
161236
}
237+
}
162238

163-
let Ok(enum_layout) = self.tcx.layout_of(self.typing_env.as_query_input(*ty)) else {
239+
if !place.is_indirect() {
240+
return;
241+
}
242+
// Get the place and type we visit.
243+
let pointer = Place::from(place.local);
244+
let pointer_ty = pointer.ty(self.local_decls, self.tcx).ty;
245+
246+
// We only want to check places based on raw pointers to enums or ManuallyDrop<Enum>.
247+
let &ty::RawPtr(pointee_ty, _) = pointer_ty.kind() else {
248+
return;
249+
};
250+
let ty::Adt(enum_adt_def, _) = pointee_ty.kind() else {
251+
return;
252+
};
253+
254+
let (enum_ty, enum_adt_def) = if enum_adt_def.is_enum() {
255+
(pointee_ty, enum_adt_def)
256+
} else if enum_adt_def.is_manually_drop() {
257+
// Find the type contained in the ManuallyDrop and check whether it is an enum.
258+
let Some((manual_drop_arg, adt_def)) =
259+
pointee_ty.walk().skip(1).next().map_or(None, |arg| {
260+
if let Some(ty) = arg.as_type()
261+
&& let ty::Adt(adt_def, _) = ty.kind()
262+
{
263+
Some((ty, adt_def))
264+
} else {
265+
None
266+
}
267+
})
268+
else {
164269
return;
165270
};
166-
let Ok(op_layout) = self
167-
.tcx
168-
.layout_of(self.typing_env.as_query_input(op.ty(self.local_decls, self.tcx)))
271+
272+
(manual_drop_arg, adt_def)
273+
} else {
274+
return;
275+
};
276+
// Exclude c_void.
277+
if enum_ty.is_c_void(self.tcx) {
278+
return;
279+
}
280+
281+
self.register_new_check(enum_ty, *enum_adt_def, Operand::Copy(*place));
282+
}
283+
284+
fn visit_projection_elem(
285+
&mut self,
286+
place_ref: PlaceRef<'tcx>,
287+
elem: PlaceElem<'tcx>,
288+
context: visit::PlaceContext,
289+
location: Location,
290+
) {
291+
self.super_projection_elem(place_ref, elem, context, location);
292+
// Check whether we are reading an enum or a ManuallyDrop<Enum> from a union.
293+
let ty::Adt(union_adt_def, _) = place_ref.ty(self.local_decls, self.tcx).ty.kind() else {
294+
return;
295+
};
296+
if !union_adt_def.is_union() {
297+
return;
298+
}
299+
let PlaceElem::Field(_, extracted_ty) = elem else {
300+
return;
301+
};
302+
let ty::Adt(enum_adt_def, _) = extracted_ty.kind() else {
303+
return;
304+
};
305+
let (enum_ty, enum_adt_def) = if enum_adt_def.is_enum() {
306+
(extracted_ty, enum_adt_def)
307+
} else if enum_adt_def.is_manually_drop() {
308+
// Find the type contained in the ManuallyDrop and check whether it is an enum.
309+
let Some((manual_drop_arg, adt_def)) =
310+
extracted_ty.walk().skip(1).next().map_or(None, |arg| {
311+
if let Some(ty) = arg.as_type()
312+
&& let ty::Adt(adt_def, _) = ty.kind()
313+
{
314+
Some((ty, adt_def))
315+
} else {
316+
None
317+
}
318+
})
169319
else {
170320
return;
171321
};
172322

173-
match enum_layout.variants {
174-
Variants::Empty if op_layout.is_uninhabited() => return,
175-
// An empty enum that tries to be constructed from an inhabited value, this
176-
// is never correct.
177-
Variants::Empty => {
178-
// The enum layout is uninhabited but we construct it from sth inhabited.
179-
// This is always UB.
180-
self.enums.push(EnumCheckType::Uninhabited);
181-
}
182-
// Construction of Single value enums is always fine.
183-
Variants::Single { .. } => {}
184-
// Construction of an enum with multiple variants but no niche optimizations.
185-
Variants::Multiple {
186-
tag_encoding: TagEncoding::Direct,
187-
tag: Scalar::Initialized { value, .. },
188-
..
189-
} => {
190-
let valid_discrs =
191-
adt_def.discriminants(self.tcx).map(|(_, discr)| discr.val).collect();
192-
193-
let discr =
194-
TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) };
195-
self.enums.push(EnumCheckType::Direct {
196-
source_op: op.to_copy(),
197-
discr,
198-
op_size: op_layout.size,
199-
valid_discrs,
200-
});
201-
}
202-
// Construction of an enum with multiple variants and niche optimizations.
203-
Variants::Multiple {
204-
tag_encoding: TagEncoding::Niche { .. },
205-
tag: Scalar::Initialized { value, valid_range, .. },
206-
tag_field,
207-
..
208-
} => {
209-
let discr =
210-
TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) };
211-
self.enums.push(EnumCheckType::WithNiche {
212-
source_op: op.to_copy(),
213-
discr,
214-
op_size: op_layout.size,
215-
offset: enum_layout.fields.offset(tag_field.as_usize()),
216-
valid_range,
217-
});
218-
}
219-
_ => return,
323+
(manual_drop_arg, adt_def)
324+
} else {
325+
return;
326+
};
327+
328+
self.register_new_check(
329+
enum_ty,
330+
*enum_adt_def,
331+
Operand::Copy(place_ref.to_place(self.tcx)),
332+
);
333+
}
334+
335+
fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
336+
if let Rvalue::Cast(CastKind::Transmute, op, ty) = rvalue {
337+
let ty::Adt(adt_def, _) = ty.kind() else {
338+
return;
339+
};
340+
if !adt_def.is_enum() {
341+
return;
220342
}
221343

222-
self.super_rvalue(rvalue, location);
344+
self.register_new_check(*ty, *adt_def, op.to_copy());
223345
}
346+
self.super_rvalue(rvalue, location);
224347
}
225348
}
226349

@@ -246,7 +369,7 @@ fn insert_discr_cast_to_u128<'tcx>(
246369
local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
247370
block_data: &mut BasicBlockData<'tcx>,
248371
source_op: Operand<'tcx>,
249-
discr: TyAndSize<'tcx>,
372+
mut discr: TyAndSize<'tcx>,
250373
op_size: Size,
251374
offset: Option<Size>,
252375
source_info: SourceInfo,
@@ -262,6 +385,29 @@ fn insert_discr_cast_to_u128<'tcx>(
262385
}
263386
};
264387

388+
// If the enum is behind a pointer, cast it to a *[const|mut] MaybeUninit<T> and then extract the discriminant through that.
389+
let source_op = if let ty::RawPtr(pointee_ty, mutbl) = source_op.ty(local_decls, tcx).kind()
390+
&& !discr.ty.is_raw_ptr()
391+
{
392+
let mu_ptr_ty = Ty::new_ptr(tcx, Ty::new_maybe_uninit(tcx, *pointee_ty), *mutbl);
393+
let mu_ptr_decl =
394+
local_decls.push(LocalDecl::with_source_info(mu_ptr_ty, source_info)).into();
395+
let rvalue = Rvalue::Cast(CastKind::Transmute, source_op, mu_ptr_ty);
396+
block_data.statements.push(Statement::new(
397+
source_info,
398+
StatementKind::Assign(Box::new((mu_ptr_decl, rvalue))),
399+
));
400+
401+
Operand::Copy(mu_ptr_decl.project_deeper(&[ProjectionElem::Deref], tcx))
402+
} else {
403+
source_op
404+
};
405+
406+
// Correct the discriminant ty to an integer, to not screw up our casts to the discriminant ty.
407+
if discr.ty.is_raw_ptr() {
408+
discr.ty = tcx.data_layout().ptr_sized_integer().to_ty(tcx, false);
409+
}
410+
265411
let (cast_kind, discr_ty_bits) = if discr.size.bytes() < op_size.bytes() {
266412
// The discriminant is less wide than the operand, cast the operand into
267413
// [MaybeUninit; N] and then index into it.
@@ -335,7 +481,8 @@ fn insert_direct_enum_check<'tcx>(
335481
new_block: BasicBlock,
336482
) {
337483
// Insert a new target block that is branched to in case of an invalid discriminant.
338-
let invalid_discr_block_data = BasicBlockData::new(None, false);
484+
let invalid_discr_block_data =
485+
BasicBlockData::new(None, basic_blocks[current_block].is_cleanup);
339486
let invalid_discr_block = basic_blocks.push(invalid_discr_block_data);
340487
let block_data = &mut basic_blocks[current_block];
341488
let discr_place = insert_discr_cast_to_u128(

compiler/rustc_mir_transform/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -665,9 +665,9 @@ pub(crate) fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'
665665
body,
666666
&[
667667
// Add some UB checks before any UB gets optimized away.
668+
&check_enums::CheckEnums,
668669
&check_alignment::CheckAlignment,
669670
&check_null::CheckNull,
670-
&check_enums::CheckEnums,
671671
// Before inlining: trim down MIR with passes to reduce inlining work.
672672

673673
// Has to be done before inlining, otherwise actual call will be almost always inlined.
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//@ run-crash
2+
//@ compile-flags: -C debug-assertions
3+
//@ error-pattern: trying to construct an enum from an invalid value 0x1
4+
5+
#[allow(dead_code)]
6+
#[repr(u16)]
7+
#[derive(Copy, Clone)]
8+
enum Single {
9+
A,
10+
}
11+
12+
fn main() {
13+
let illegal_val: u16 = 1;
14+
let illegal_val_ptr = &raw const illegal_val;
15+
let foo: *const std::mem::ManuallyDrop<Single> =
16+
unsafe { std::mem::transmute(illegal_val_ptr) };
17+
18+
let val: Single = unsafe { std::mem::ManuallyDrop::into_inner(*foo) };
19+
println!("{}", val as u16);
20+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
//@ run-pass
2+
//@ compile-flags: -C debug-assertions
3+
4+
#[allow(dead_code)]
5+
#[repr(u16)]
6+
enum Single {
7+
A,
8+
}
9+
10+
fn main() {
11+
let illegal_val: u16 = 0;
12+
let illegal_val_ptr = &raw const illegal_val;
13+
let foo: *const std::mem::ManuallyDrop<Single> =
14+
unsafe { std::mem::transmute(illegal_val_ptr) };
15+
16+
let val: Single = unsafe { foo.cast::<Single>().read() };
17+
println!("{}", val as u16);
18+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
//@ run-crash
2+
//@ compile-flags: -C debug-assertions
3+
//@ error-pattern: trying to construct an enum from an invalid value 0x1
4+
5+
#[allow(dead_code)]
6+
#[repr(u16)]
7+
#[derive(Copy, Clone)]
8+
enum Single {
9+
A,
10+
}
11+
12+
fn main() {
13+
let illegal_val: u16 = 1;
14+
let illegal_val_ptr = &raw const illegal_val;
15+
let foo: *const Single = unsafe { std::mem::transmute(illegal_val_ptr) };
16+
17+
let val: Single = unsafe { *foo };
18+
println!("{}", val as u16);
19+
}

0 commit comments

Comments
 (0)