@@ -8,6 +8,7 @@ use crate::deriving::generic::ty::*;
88use crate :: deriving:: generic:: * ;
99use crate :: deriving:: { path_local, path_std} ;
1010
11+ /// Expands a `#[derive(PartialEq)]` attribute into an implementation for the target item.
1112pub ( crate ) fn expand_deriving_partial_eq (
1213 cx : & ExtCtxt < ' _ > ,
1314 span : Span ,
@@ -16,62 +17,6 @@ pub(crate) fn expand_deriving_partial_eq(
1617 push : & mut dyn FnMut ( Annotatable ) ,
1718 is_const : bool ,
1819) {
19- fn cs_eq ( cx : & ExtCtxt < ' _ > , span : Span , substr : & Substructure < ' _ > ) -> BlockOrExpr {
20- let base = true ;
21- let expr = cs_fold (
22- true , // use foldl
23- cx,
24- span,
25- substr,
26- |cx, fold| match fold {
27- CsFold :: Single ( field) => {
28- let [ other_expr] = & field. other_selflike_exprs [ ..] else {
29- cx. dcx ( )
30- . span_bug ( field. span , "not exactly 2 arguments in `derive(PartialEq)`" ) ;
31- } ;
32-
33- // We received arguments of type `&T`. Convert them to type `T` by stripping
34- // any leading `&`. This isn't necessary for type checking, but
35- // it results in better error messages if something goes wrong.
36- //
37- // Note: for arguments that look like `&{ x }`, which occur with packed
38- // structs, this would cause expressions like `{ self.x } == { other.x }`,
39- // which isn't valid Rust syntax. This wouldn't break compilation because these
40- // AST nodes are constructed within the compiler. But it would mean that code
41- // printed by `-Zunpretty=expanded` (or `cargo expand`) would have invalid
42- // syntax, which would be suboptimal. So we wrap these in parens, giving
43- // `({ self.x }) == ({ other.x })`, which is valid syntax.
44- let convert = |expr : & P < Expr > | {
45- if let ExprKind :: AddrOf ( BorrowKind :: Ref , Mutability :: Not , inner) =
46- & expr. kind
47- {
48- if let ExprKind :: Block ( ..) = & inner. kind {
49- // `&{ x }` form: remove the `&`, add parens.
50- cx. expr_paren ( field. span , inner. clone ( ) )
51- } else {
52- // `&x` form: remove the `&`.
53- inner. clone ( )
54- }
55- } else {
56- expr. clone ( )
57- }
58- } ;
59- cx. expr_binary (
60- field. span ,
61- BinOpKind :: Eq ,
62- convert ( & field. self_expr ) ,
63- convert ( other_expr) ,
64- )
65- }
66- CsFold :: Combine ( span, expr1, expr2) => {
67- cx. expr_binary ( span, BinOpKind :: And , expr1, expr2)
68- }
69- CsFold :: Fieldless => cx. expr_bool ( span, base) ,
70- } ,
71- ) ;
72- BlockOrExpr :: new_expr ( expr)
73- }
74-
7520 let structural_trait_def = TraitDef {
7621 span,
7722 path : path_std ! ( marker:: StructuralPartialEq ) ,
@@ -97,7 +42,9 @@ pub(crate) fn expand_deriving_partial_eq(
9742 ret_ty: Path ( path_local!( bool ) ) ,
9843 attributes: thin_vec![ cx. attr_word( sym:: inline, span) ] ,
9944 fieldless_variants_strategy: FieldlessVariantsStrategy :: Unify ,
100- combine_substructure: combine_substructure( Box :: new( |a, b, c| cs_eq( a, b, c) ) ) ,
45+ combine_substructure: combine_substructure( Box :: new( |a, b, c| {
46+ BlockOrExpr :: new_expr( get_substructure_equality_expr( a, b, c) )
47+ } ) ) ,
10148 } ] ;
10249
10350 let trait_def = TraitDef {
@@ -113,3 +60,142 @@ pub(crate) fn expand_deriving_partial_eq(
11360 } ;
11461 trait_def. expand ( cx, mitem, item, push)
11562}
63+
64+ /// Generates the equality expression for a struct or enum variant when deriving `PartialEq`.
65+ ///
66+ /// This function generates an expression that checks if all fields of a struct or enum variant are equal.
67+ /// - Scalar fields are compared first for efficiency, followed by compound fields.
68+ /// - If there are no fields, returns `true` (fieldless types are always equal).
69+ ///
70+ /// Whether a field is considered "scalar" is determined by comparing the symbol of its type
71+ /// to a set of known scalar type symbols (e.g., `i32`, `u8`, etc). This check is based on
72+ /// the type's symbol.
73+ ///
74+ /// ### Example 1
75+ /// ```
76+ /// #[derive(PartialEq)]
77+ /// struct i32;
78+ ///
79+ /// // Here, `field_2` is of type `i32`, but since it's a user-defined type (not the primitive),
80+ /// // it will not be treated as scalar. The function will still check equality of `field_2` first
81+ /// // because the symbol matches `i32`.
82+ /// #[derive(PartialEq)]
83+ /// struct Struct {
84+ /// field_1: &'static str,
85+ /// field_2: i32,
86+ /// }
87+ /// ```
88+ ///
89+ /// ### Example 2
90+ /// ```
91+ /// mod ty {
92+ /// pub type i32 = i32;
93+ /// }
94+ ///
95+ /// // Here, `field_2` is of type `ty::i32`, which is a type alias for `i32`.
96+ /// // However, the function will not reorder the fields because the symbol for `ty::i32`
97+ /// // does not match the symbol for the primitive `i32` ("ty::i32" != "i32").
98+ /// #[derive(PartialEq)]
99+ /// struct Struct {
100+ /// field_1: &'static str,
101+ /// field_2: ty::i32,
102+ /// }
103+ /// ```
104+ ///
105+ /// For enums, the discriminant is compared first, then the rest of the fields.
106+ ///
107+ /// # Panics
108+ ///
109+ /// If called on static or all-fieldless enums/structs, which should not occur during derive expansion.
110+ fn get_substructure_equality_expr (
111+ cx : & ExtCtxt < ' _ > ,
112+ span : Span ,
113+ substructure : & Substructure < ' _ > ,
114+ ) -> P < Expr > {
115+ use SubstructureFields :: * ;
116+
117+ match substructure. fields {
118+ EnumMatching ( .., fields) | Struct ( .., fields) => {
119+ let combine = move |acc, field| {
120+ let rhs = get_field_equality_expr ( cx, field) ;
121+ if let Some ( lhs) = acc {
122+ // Combine the previous comparison with the current field using logical AND.
123+ return Some ( cx. expr_binary ( field. span , BinOpKind :: And , lhs, rhs) ) ;
124+ }
125+ // Start the chain with the first field's comparison.
126+ Some ( rhs)
127+ } ;
128+
129+ // First compare scalar fields, then compound fields, combining all with logical AND.
130+ return fields
131+ . iter ( )
132+ . filter ( |field| !field. maybe_scalar )
133+ . fold ( fields. iter ( ) . filter ( |field| field. maybe_scalar ) . fold ( None , combine) , combine)
134+ // If there are no fields, treat as always equal.
135+ . unwrap_or_else ( || cx. expr_bool ( span, true ) ) ;
136+ }
137+ EnumDiscr ( disc, match_expr) => {
138+ let lhs = get_field_equality_expr ( cx, disc) ;
139+ let Some ( match_expr) = match_expr else {
140+ return lhs;
141+ } ;
142+ // Compare the discriminant first (cheaper), then the rest of the fields.
143+ return cx. expr_binary ( disc. span , BinOpKind :: And , lhs, match_expr. clone ( ) ) ;
144+ }
145+ StaticEnum ( ..) => cx. dcx ( ) . span_bug (
146+ span,
147+ "unexpected static enum encountered during `derive(PartialEq)` expansion" ,
148+ ) ,
149+ StaticStruct ( ..) => cx. dcx ( ) . span_bug (
150+ span,
151+ "unexpected static struct encountered during `derive(PartialEq)` expansion" ,
152+ ) ,
153+ AllFieldlessEnum ( ..) => cx. dcx ( ) . span_bug (
154+ span,
155+ "unexpected all-fieldless enum encountered during `derive(PartialEq)` expansion" ,
156+ ) ,
157+ }
158+ }
159+
160+ /// Generates an equality comparison expression for a single struct or enum field.
161+ ///
162+ /// This function produces an AST expression that compares the `self` and `other` values for a field using `==`.
163+ /// It removes any leading references from both sides for readability.
164+ /// If the field is a block expression, it is wrapped in parentheses to ensure valid syntax.
165+ ///
166+ /// # Panics
167+ ///
168+ /// Panics if there are not exactly two arguments to compare (should be `self` and `other`).
169+ fn get_field_equality_expr ( cx : & ExtCtxt < ' _ > , field : & FieldInfo ) -> P < Expr > {
170+ let [ rhs] = & field. other_selflike_exprs [ ..] else {
171+ cx. dcx ( ) . span_bug ( field. span , "not exactly 2 arguments in `derive(PartialEq)`" ) ;
172+ } ;
173+
174+ cx. expr_binary (
175+ field. span ,
176+ BinOpKind :: Eq ,
177+ wrap_block_expr ( cx, peel_refs ( & field. self_expr ) ) ,
178+ wrap_block_expr ( cx, peel_refs ( rhs) ) ,
179+ )
180+ }
181+
182+ /// Removes all leading immutable references from an expression.
183+ ///
184+ /// This is used to strip away any number of leading `&` from an expression (e.g., `&&&T` becomes `T`).
185+ /// Only removes immutable references; mutable references are preserved.
186+ fn peel_refs ( mut expr : & P < Expr > ) -> P < Expr > {
187+ while let ExprKind :: AddrOf ( BorrowKind :: Ref , Mutability :: Not , inner) = & expr. kind {
188+ expr = & inner;
189+ }
190+ expr. clone ( )
191+ }
192+
193+ /// Wraps a block expression in parentheses to ensure valid AST in macro expansion output.
194+ ///
195+ /// If the given expression is a block, it is wrapped in parentheses; otherwise, it is returned unchanged.
196+ fn wrap_block_expr ( cx : & ExtCtxt < ' _ > , expr : P < Expr > ) -> P < Expr > {
197+ if matches ! ( & expr. kind, ExprKind :: Block ( ..) ) {
198+ return cx. expr_paren ( expr. span , expr) ;
199+ }
200+ expr
201+ }
0 commit comments