Skip to content

Commit 3db681f

Browse files
committed
fixed unsafe suggestion
1 parent d51b6f9 commit 3db681f

File tree

4 files changed

+170
-22
lines changed

4 files changed

+170
-22
lines changed

compiler/rustc_mir_build/src/check_unsafety.rs

Lines changed: 85 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ struct UnsafetyVisitor<'a, 'tcx> {
4545
/// Flag to ensure that we only suggest wrapping the entire function body in
4646
/// an unsafe block once.
4747
suggest_unsafe_block: bool,
48+
/// Controls how union field accesses are checked
49+
union_field_access_mode: UnionFieldAccessMode,
4850
}
4951

5052
impl<'tcx> UnsafetyVisitor<'_, 'tcx> {
@@ -223,6 +225,7 @@ impl<'tcx> UnsafetyVisitor<'_, 'tcx> {
223225
inside_adt: false,
224226
warnings: self.warnings,
225227
suggest_unsafe_block: self.suggest_unsafe_block,
228+
union_field_access_mode: UnionFieldAccessMode::Normal,
226229
};
227230
// params in THIR may be unsafe, e.g. a union pattern.
228231
for param in &inner_thir.params {
@@ -545,6 +548,20 @@ impl<'a, 'tcx> Visitor<'a, 'tcx> for UnsafetyVisitor<'a, 'tcx> {
545548
}
546549
}
547550
ExprKind::RawBorrow { arg, .. } => {
551+
// Handle the case where we're taking a raw pointer to a union field
552+
if let ExprKind::Scope { value: arg, .. } = self.thir[arg].kind {
553+
if self.is_union_field_access(arg) {
554+
// Taking a raw pointer to a union field is safe - just check the base expression
555+
// but skip the union field safety check
556+
self.visit_union_field_for_raw_borrow(arg);
557+
return;
558+
}
559+
} else if self.is_union_field_access(arg) {
560+
// Direct raw borrow of union field
561+
self.visit_union_field_for_raw_borrow(arg);
562+
return;
563+
}
564+
548565
if let ExprKind::Scope { value: arg, .. } = self.thir[arg].kind
549566
&& let ExprKind::Deref { arg } = self.thir[arg].kind
550567
{
@@ -649,17 +666,27 @@ impl<'a, 'tcx> Visitor<'a, 'tcx> for UnsafetyVisitor<'a, 'tcx> {
649666
if adt_def.variant(variant_index).fields[name].safety.is_unsafe() {
650667
self.requires_unsafe(expr.span, UseOfUnsafeField);
651668
} else if adt_def.is_union() {
652-
if let Some(assigned_ty) = self.assignment_info {
653-
if assigned_ty.needs_drop(self.tcx, self.typing_env) {
654-
// This would be unsafe, but should be outright impossible since we
655-
// reject such unions.
656-
assert!(
657-
self.tcx.dcx().has_errors().is_some(),
658-
"union fields that need dropping should be impossible: {assigned_ty}"
659-
);
669+
// Check if this field access is part of a raw borrow operation
670+
// If so, we've already handled it above and shouldn't reach here
671+
match self.union_field_access_mode {
672+
UnionFieldAccessMode::SuppressUnionFieldAccessError => {
673+
// Suppress AccessToUnionField error for union fields chains
674+
}
675+
UnionFieldAccessMode::Normal => {
676+
if let Some(assigned_ty) = self.assignment_info {
677+
if assigned_ty.needs_drop(self.tcx, self.typing_env) {
678+
// This would be unsafe, but should be outright impossible since we
679+
// reject such unions.
680+
assert!(
681+
self.tcx.dcx().has_errors().is_some(),
682+
"union fields that need dropping should be impossible: {assigned_ty}"
683+
);
684+
}
685+
} else {
686+
// Only require unsafe if this is not a raw borrow operation
687+
self.requires_unsafe(expr.span, AccessToUnionField);
688+
}
660689
}
661-
} else {
662-
self.requires_unsafe(expr.span, AccessToUnionField);
663690
}
664691
}
665692
}
@@ -712,6 +739,46 @@ impl<'a, 'tcx> Visitor<'a, 'tcx> for UnsafetyVisitor<'a, 'tcx> {
712739
}
713740
}
714741

742+
impl<'a, 'tcx> UnsafetyVisitor<'a, 'tcx> {
743+
/// Check if an expression is a union field access
744+
fn is_union_field_access(&self, expr_id: ExprId) -> bool {
745+
match self.thir[expr_id].kind {
746+
ExprKind::Field { lhs, .. } => {
747+
let lhs = &self.thir[lhs];
748+
matches!(lhs.ty.kind(), ty::Adt(adt_def, _) if adt_def.is_union())
749+
}
750+
_ => false,
751+
}
752+
}
753+
754+
/// Visit a union field access in the context of a raw borrow operation
755+
/// This ensures we still check safety of nested operations while allowing
756+
/// the raw pointer creation itself
757+
fn visit_union_field_for_raw_borrow(&mut self, mut expr_id: ExprId) {
758+
let prev = self.union_field_access_mode;
759+
self.union_field_access_mode = UnionFieldAccessMode::SuppressUnionFieldAccessError;
760+
// Walk through the chain of union field accesses using while let
761+
while let ExprKind::Field { lhs, variant_index, name } = self.thir[expr_id].kind {
762+
let lhs_expr = &self.thir[lhs];
763+
if let ty::Adt(adt_def, _) = lhs_expr.ty.kind() {
764+
// Check for unsafe fields but skip the union access check
765+
if adt_def.variant(variant_index).fields[name].safety.is_unsafe() {
766+
self.requires_unsafe(self.thir[expr_id].span, UseOfUnsafeField);
767+
}
768+
// If the LHS is also a union field access, keep walking
769+
expr_id = lhs;
770+
} else {
771+
// Not a union, use normal visiting
772+
visit::walk_expr(self, &self.thir[expr_id]);
773+
return;
774+
}
775+
}
776+
// Visit the base expression for any nested safety checks
777+
self.visit_expr(&self.thir[expr_id]);
778+
self.union_field_access_mode = prev;
779+
}
780+
}
781+
715782
#[derive(Clone)]
716783
enum SafetyContext {
717784
Safe,
@@ -720,6 +787,13 @@ enum SafetyContext {
720787
UnsafeBlock { span: Span, hir_id: HirId, used: bool, nested_used_blocks: Vec<NestedUsedBlock> },
721788
}
722789

790+
/// Controls how union field accesses are checked
791+
#[derive(Clone, Copy)]
792+
enum UnionFieldAccessMode {
793+
Normal,
794+
SuppressUnionFieldAccessError,
795+
}
796+
723797
#[derive(Clone, Copy)]
724798
struct NestedUsedBlock {
725799
hir_id: HirId,
@@ -1199,6 +1273,7 @@ pub(crate) fn check_unsafety(tcx: TyCtxt<'_>, def: LocalDefId) {
11991273
inside_adt: false,
12001274
warnings: &mut warnings,
12011275
suggest_unsafe_block: true,
1276+
union_field_access_mode: UnionFieldAccessMode::Normal,
12021277
};
12031278
// params in THIR may be unsafe, e.g. a union pattern.
12041279
for param in &thir.params {

src/tools/miri/tests/pass/both_borrows/smallvec.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ impl<T, const N: usize> RawSmallVec<T, N> {
2525
}
2626

2727
const fn as_mut_ptr_inline(&mut self) -> *mut T {
28-
(unsafe { &raw mut self.inline }) as *mut T
28+
&raw mut self.inline as *mut T
2929
}
3030

3131
const unsafe fn as_mut_ptr_heap(&mut self) -> *mut T {

tests/ui/union/union-unsafe.rs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ union U4<T: Copy> {
1717
a: T,
1818
}
1919

20+
union U5 {
21+
a: usize,
22+
}
23+
2024
union URef {
2125
p: &'static mut i32,
2226
}
@@ -31,6 +35,11 @@ fn deref_union_field(mut u: URef) {
3135
*(u.p) = 13; //~ ERROR access to union field is unsafe
3236
}
3337

38+
fn raw_deref_union_field(mut u: URef) {
39+
// This is unsafe because we first dereference u.p (reading uninitialized memory)
40+
let _p = &raw const *(u.p); //~ ERROR access to union field is unsafe
41+
}
42+
3443
fn assign_noncopy_union_field(mut u: URefCell) {
3544
u.a = (ManuallyDrop::new(RefCell::new(0)), 1); // OK (assignment does not drop)
3645
u.a.0 = ManuallyDrop::new(RefCell::new(0)); // OK (assignment does not drop)
@@ -57,6 +66,20 @@ fn main() {
5766
let a = u1.a; //~ ERROR access to union field is unsafe
5867
u1.a = 11; // OK
5968

69+
let mut u2 = U1 { a: 10 };
70+
let a = &raw mut u2.a; // OK
71+
unsafe { *a = 3 };
72+
73+
let mut u3 = U1 { a: 10 };
74+
let a = std::ptr::addr_of_mut!(u3.a); // OK
75+
unsafe { *a = 14 };
76+
77+
let u4 = U5 { a: 2 };
78+
let vec = vec![1, 2, 3];
79+
// This is unsafe because we read u4.a (potentially uninitialized memory)
80+
// to use as an array index
81+
let _a = &raw const vec[u4.a]; //~ ERROR access to union field is unsafe
82+
6083
let U1 { a } = u1; //~ ERROR access to union field is unsafe
6184
if let U1 { a: 12 } = u1 {} //~ ERROR access to union field is unsafe
6285
if let Some(U1 { a: 13 }) = Some(u1) {} //~ ERROR access to union field is unsafe
@@ -73,4 +96,38 @@ fn main() {
7396
let mut u3 = U3 { a: ManuallyDrop::new(String::from("old")) }; // OK
7497
u3.a = ManuallyDrop::new(String::from("new")); // OK (assignment does not drop)
7598
*u3.a = String::from("new"); //~ ERROR access to union field is unsafe
99+
100+
let mut unions = [U1 { a: 1 }, U1 { a: 2 }];
101+
102+
// Array indexing + union field raw borrow - should be OK
103+
let ptr = &raw mut unions[0].a; // OK
104+
let ptr2 = &raw const unions[1].a; // OK
105+
106+
// Test for union fields chain, this should be allowed
107+
#[derive(Copy, Clone)]
108+
union Inner {
109+
a: u8,
110+
}
111+
112+
union MoreInner {
113+
moreinner: ManuallyDrop<Inner>,
114+
}
115+
116+
union LessOuter {
117+
lessouter: ManuallyDrop<MoreInner>,
118+
}
119+
120+
union Outer {
121+
outer: ManuallyDrop<LessOuter>,
122+
}
123+
124+
let super_outer = Outer {
125+
outer: ManuallyDrop::new(LessOuter {
126+
lessouter: ManuallyDrop::new(MoreInner {
127+
moreinner: ManuallyDrop::new(Inner { a: 42 }),
128+
}),
129+
}),
130+
};
131+
132+
let ptr = &raw const super_outer.outer.lessouter.moreinner.a;
76133
}

tests/ui/union/union-unsafe.stderr

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,83 +1,99 @@
11
error[E0133]: access to union field is unsafe and requires unsafe function or block
2-
--> $DIR/union-unsafe.rs:31:6
2+
--> $DIR/union-unsafe.rs:35:6
33
|
44
LL | *(u.p) = 13;
55
| ^^^^^ access to union field
66
|
77
= note: the field may not be properly initialized: using uninitialized data will cause undefined behavior
88

99
error[E0133]: access to union field is unsafe and requires unsafe function or block
10-
--> $DIR/union-unsafe.rs:43:6
10+
--> $DIR/union-unsafe.rs:40:26
11+
|
12+
LL | let _p = &raw const *(u.p);
13+
| ^^^^^ access to union field
14+
|
15+
= note: the field may not be properly initialized: using uninitialized data will cause undefined behavior
16+
17+
error[E0133]: access to union field is unsafe and requires unsafe function or block
18+
--> $DIR/union-unsafe.rs:52:6
1119
|
1220
LL | *u3.a = T::default();
1321
| ^^^^ access to union field
1422
|
1523
= note: the field may not be properly initialized: using uninitialized data will cause undefined behavior
1624

1725
error[E0133]: access to union field is unsafe and requires unsafe function or block
18-
--> $DIR/union-unsafe.rs:49:6
26+
--> $DIR/union-unsafe.rs:58:6
1927
|
2028
LL | *u3.a = T::default();
2129
| ^^^^ access to union field
2230
|
2331
= note: the field may not be properly initialized: using uninitialized data will cause undefined behavior
2432

2533
error[E0133]: access to union field is unsafe and requires unsafe function or block
26-
--> $DIR/union-unsafe.rs:57:13
34+
--> $DIR/union-unsafe.rs:66:13
2735
|
2836
LL | let a = u1.a;
2937
| ^^^^ access to union field
3038
|
3139
= note: the field may not be properly initialized: using uninitialized data will cause undefined behavior
3240

3341
error[E0133]: access to union field is unsafe and requires unsafe function or block
34-
--> $DIR/union-unsafe.rs:60:14
42+
--> $DIR/union-unsafe.rs:81:29
43+
|
44+
LL | let _a = &raw const vec[u4.a];
45+
| ^^^^ access to union field
46+
|
47+
= note: the field may not be properly initialized: using uninitialized data will cause undefined behavior
48+
49+
error[E0133]: access to union field is unsafe and requires unsafe function or block
50+
--> $DIR/union-unsafe.rs:83:14
3551
|
3652
LL | let U1 { a } = u1;
3753
| ^ access to union field
3854
|
3955
= note: the field may not be properly initialized: using uninitialized data will cause undefined behavior
4056

4157
error[E0133]: access to union field is unsafe and requires unsafe function or block
42-
--> $DIR/union-unsafe.rs:61:20
58+
--> $DIR/union-unsafe.rs:84:20
4359
|
4460
LL | if let U1 { a: 12 } = u1 {}
4561
| ^^ access to union field
4662
|
4763
= note: the field may not be properly initialized: using uninitialized data will cause undefined behavior
4864

4965
error[E0133]: access to union field is unsafe and requires unsafe function or block
50-
--> $DIR/union-unsafe.rs:62:25
66+
--> $DIR/union-unsafe.rs:85:25
5167
|
5268
LL | if let Some(U1 { a: 13 }) = Some(u1) {}
5369
| ^^ access to union field
5470
|
5571
= note: the field may not be properly initialized: using uninitialized data will cause undefined behavior
5672

5773
error[E0133]: access to union field is unsafe and requires unsafe function or block
58-
--> $DIR/union-unsafe.rs:67:6
74+
--> $DIR/union-unsafe.rs:90:6
5975
|
6076
LL | *u2.a = String::from("new");
6177
| ^^^^ access to union field
6278
|
6379
= note: the field may not be properly initialized: using uninitialized data will cause undefined behavior
6480

6581
error[E0133]: access to union field is unsafe and requires unsafe function or block
66-
--> $DIR/union-unsafe.rs:71:6
82+
--> $DIR/union-unsafe.rs:94:6
6783
|
6884
LL | *u3.a = 1;
6985
| ^^^^ access to union field
7086
|
7187
= note: the field may not be properly initialized: using uninitialized data will cause undefined behavior
7288

7389
error[E0133]: access to union field is unsafe and requires unsafe function or block
74-
--> $DIR/union-unsafe.rs:75:6
90+
--> $DIR/union-unsafe.rs:98:6
7591
|
7692
LL | *u3.a = String::from("new");
7793
| ^^^^ access to union field
7894
|
7995
= note: the field may not be properly initialized: using uninitialized data will cause undefined behavior
8096

81-
error: aborting due to 10 previous errors
97+
error: aborting due to 12 previous errors
8298

8399
For more information about this error, try `rustc --explain E0133`.

0 commit comments

Comments
 (0)