@@ -45,6 +45,8 @@ struct UnsafetyVisitor<'a, 'tcx> {
4545 /// Flag to ensure that we only suggest wrapping the entire function body in
4646 /// an unsafe block once.
4747 suggest_unsafe_block : bool ,
48+ /// Controls how union field accesses are checked
49+ union_field_access_mode : UnionFieldAccessMode ,
4850}
4951
5052impl < ' tcx > UnsafetyVisitor < ' _ , ' tcx > {
@@ -218,6 +220,7 @@ impl<'tcx> UnsafetyVisitor<'_, 'tcx> {
218220 inside_adt : false ,
219221 warnings : self . warnings ,
220222 suggest_unsafe_block : self . suggest_unsafe_block ,
223+ union_field_access_mode : UnionFieldAccessMode :: Normal ,
221224 } ;
222225 // params in THIR may be unsafe, e.g. a union pattern.
223226 for param in & inner_thir. params {
@@ -658,18 +661,25 @@ impl<'a, 'tcx> Visitor<'a, 'tcx> for UnsafetyVisitor<'a, 'tcx> {
658661 } else if adt_def. is_union ( ) {
659662 // Check if this field access is part of a raw borrow operation
660663 // If so, we've already handled it above and shouldn't reach here
661- if let Some ( assigned_ty) = self . assignment_info {
662- if assigned_ty. needs_drop ( self . tcx , self . typing_env ) {
663- // This would be unsafe, but should be outright impossible since we
664- // reject such unions.
665- assert ! (
666- self . tcx. dcx( ) . has_errors( ) . is_some( ) ,
667- "union fields that need dropping should be impossible: {assigned_ty}"
668- ) ;
664+ match self . union_field_access_mode {
665+ UnionFieldAccessMode :: SuppressUnionFieldAccessError => {
666+ // Suppress AccessToUnionField error for union fields chains
667+ }
668+ UnionFieldAccessMode :: Normal => {
669+ if let Some ( assigned_ty) = self . assignment_info {
670+ if assigned_ty. needs_drop ( self . tcx , self . typing_env ) {
671+ // This would be unsafe, but should be outright impossible since we
672+ // reject such unions.
673+ assert ! (
674+ self . tcx. dcx( ) . has_errors( ) . is_some( ) ,
675+ "union fields that need dropping should be impossible: {assigned_ty}"
676+ ) ;
677+ }
678+ } else {
679+ // Only require unsafe if this is not a raw borrow operation
680+ self . requires_unsafe ( expr. span , AccessToUnionField ) ;
681+ }
669682 }
670- } else {
671- // Only require unsafe if this is not a raw borrow operation
672- self . requires_unsafe ( expr. span , AccessToUnionField ) ;
673683 }
674684 }
675685 }
@@ -728,7 +738,7 @@ impl<'a, 'tcx> UnsafetyVisitor<'a, 'tcx> {
728738 match self . thir [ expr_id] . kind {
729739 ExprKind :: Field { lhs, .. } => {
730740 let lhs = & self . thir [ lhs] ;
731- if let ty:: Adt ( adt_def, _) = lhs . ty . kind ( ) { adt_def. is_union ( ) } else { false }
741+ matches ! ( lhs . ty . kind ( ) , ty:: Adt ( adt_def, _) if adt_def. is_union( ) )
732742 }
733743 _ => false ,
734744 }
@@ -737,28 +747,28 @@ impl<'a, 'tcx> UnsafetyVisitor<'a, 'tcx> {
737747 /// Visit a union field access in the context of a raw borrow operation
738748 /// This ensures we still check safety of nested operations while allowing
739749 /// the raw pointer creation itself
740- fn visit_union_field_for_raw_borrow ( & mut self , expr_id : ExprId ) {
741- match self . thir [ expr_id] . kind {
742- ExprKind :: Field { lhs, variant_index, name } => {
743- let lhs_expr = & self . thir [ lhs] ;
744- if let ty:: Adt ( adt_def, _) = lhs_expr. ty . kind ( ) {
745- // Check for unsafe fields but skip the union access check
746- if adt_def. variant ( variant_index) . fields [ name] . safety . is_unsafe ( ) {
747- self . requires_unsafe ( self . thir [ expr_id] . span , UseOfUnsafeField ) ;
748- }
749- // For unions, we don't require unsafe for raw pointer creation
750- // But we still need to check the LHS for safety
751- self . visit_expr ( lhs_expr) ;
752- } else {
753- // Not a union, use normal visiting
754- visit:: walk_expr ( self , & self . thir [ expr_id] ) ;
750+ fn visit_union_field_for_raw_borrow ( & mut self , mut expr_id : ExprId ) {
751+ let prev = self . union_field_access_mode ;
752+ self . union_field_access_mode = UnionFieldAccessMode :: SuppressUnionFieldAccessError ;
753+ // Walk through the chain of union field accesses using while let
754+ while let ExprKind :: Field { lhs, variant_index, name } = self . thir [ expr_id] . kind {
755+ let lhs_expr = & self . thir [ lhs] ;
756+ if let ty:: Adt ( adt_def, _) = lhs_expr. ty . kind ( ) {
757+ // Check for unsafe fields but skip the union access check
758+ if adt_def. variant ( variant_index) . fields [ name] . safety . is_unsafe ( ) {
759+ self . requires_unsafe ( self . thir [ expr_id] . span , UseOfUnsafeField ) ;
755760 }
756- }
757- _ => {
758- // Not a field access, use normal visiting
761+ // If the LHS is also a union field access, keep walking
762+ expr_id = lhs;
763+ } else {
764+ // Not a union, use normal visiting
759765 visit:: walk_expr ( self , & self . thir [ expr_id] ) ;
766+ return ;
760767 }
761768 }
769+ // Visit the base expression for any nested safety checks
770+ self . visit_expr ( & self . thir [ expr_id] ) ;
771+ self . union_field_access_mode = prev;
762772 }
763773}
764774
@@ -770,6 +780,13 @@ enum SafetyContext {
770780 UnsafeBlock { span : Span , hir_id : HirId , used : bool , nested_used_blocks : Vec < NestedUsedBlock > } ,
771781}
772782
783+ /// Controls how union field accesses are checked
784+ #[ derive( Clone , Copy ) ]
785+ enum UnionFieldAccessMode {
786+ Normal ,
787+ SuppressUnionFieldAccessError ,
788+ }
789+
773790#[ derive( Clone , Copy ) ]
774791struct NestedUsedBlock {
775792 hir_id : HirId ,
@@ -1244,6 +1261,7 @@ pub(crate) fn check_unsafety(tcx: TyCtxt<'_>, def: LocalDefId) {
12441261 inside_adt : false ,
12451262 warnings : & mut warnings,
12461263 suggest_unsafe_block : true ,
1264+ union_field_access_mode : UnionFieldAccessMode :: Normal ,
12471265 } ;
12481266 // params in THIR may be unsafe, e.g. a union pattern.
12491267 for param in & thir. params {
0 commit comments