diff --git a/crates/cairo-lang-lowering/src/analysis/equality_analysis.rs b/crates/cairo-lang-lowering/src/analysis/equality_analysis.rs index 80cbb6aac45..7996f11b525 100644 --- a/crates/cairo-lang-lowering/src/analysis/equality_analysis.rs +++ b/crates/cairo-lang-lowering/src/analysis/equality_analysis.rs @@ -19,6 +19,33 @@ use crate::{ BlockEnd, BlockId, Lowered, MatchArm, MatchExternInfo, MatchInfo, Statement, VariableId, }; +/// A struct field variable: either a real variable or a unique placeholder for an unknown field. +/// Placeholders are created during merge when a field has no intersection representative. +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] +enum FieldVar { + Var(VariableId), + /// A globally unique placeholder representing an unknown field. + Placeholder(usize), +} + +impl FieldVar { + /// Returns the real variable if this is a `Var`, or `None` if it's a `Placeholder`. + fn as_var(self) -> Option { + match self { + FieldVar::Var(v) => Some(v), + FieldVar::Placeholder(_) => None, + } + } + + /// Path-compresses the variable inside a `Var`, leaves `Placeholder` unchanged. + fn find_rep(self, info: &mut EqualityState<'_>) -> Self { + match self { + FieldVar::Var(v) => FieldVar::Var(info.find(v)), + p @ FieldVar::Placeholder(_) => p, + } + } +} + /// The kind of relationship between equivalence classes. #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] enum RelationKind { @@ -35,7 +62,7 @@ enum Relation<'db> { Box(VariableId), Snapshot(VariableId), EnumConstruct(ConcreteVariant<'db>, VariableId), - StructConstruct(TypeId<'db>, Vec), + StructConstruct(TypeId<'db>, Vec), } impl<'db> Relation<'db> { @@ -56,13 +83,13 @@ impl<'db> Relation<'db> { } } - /// Returns an iterator over all variables referenced by this relation. + /// Returns an iterator over all real variables referenced by this relation. fn referenced_vars(&self) -> impl Iterator + '_ { - let fields: &[VariableId] = match self { + let fields: &[FieldVar] = match self { Relation::StructConstruct(_, vs) => vs, _ => &[], }; - self.single_var().into_iter().chain(fields.iter().copied()) + self.single_var().into_iter().chain(fields.iter().filter_map(|f| f.as_var())) } /// Creates a new relation of the same kind with the single variable replaced. @@ -138,7 +165,33 @@ impl<'db> ClassInfo<'db> { { *e.get_mut() = e.get().with_var(union_fn(self_var, other_var)); } - // StructConstruct: keep whichever side we already have (no field-level merge). + // For StructConstruct, merge element-wise: fill in placeholders. + if let ( + Relation::StructConstruct(ty, self_fields), + Relation::StructConstruct(_, other_fields), + ) = (e.get().clone(), &other_rel) + && self_fields.len() == other_fields.len() + { + let merged: Vec = self_fields + .iter() + .zip(other_fields.iter()) + .map(|(sf, of)| match (sf, of) { + (FieldVar::Var(a), FieldVar::Var(b)) => { + if a == b { + FieldVar::Var(*a) + } else { + FieldVar::Var(union_fn(*a, *b)) + } + } + (FieldVar::Var(a), FieldVar::Placeholder(_)) => FieldVar::Var(*a), + (FieldVar::Placeholder(_), FieldVar::Var(b)) => FieldVar::Var(*b), + (FieldVar::Placeholder(p), FieldVar::Placeholder(_)) => { + FieldVar::Placeholder(*p) + } + }) + .collect(); + *e.get_mut() = Relation::StructConstruct(ty, merged); + } } Entry::Vacant(e) => { e.insert(other_rel); @@ -316,10 +369,7 @@ impl<'db> EqualityState<'db> { } /// Looks up the struct construct info for a representative (immutable). - fn get_struct_construct_immut( - &self, - rep: VariableId, - ) -> Option<(TypeId<'db>, Vec)> { + fn get_struct_construct_immut(&self, rep: VariableId) -> Option<(TypeId<'db>, Vec)> { match self.class_info.get(&rep)?.reverse_relationship.get(&RelationKind::StructConstruct)? { Relation::StructConstruct(ty, fields) => Some((*ty, fields.clone())), _ => unreachable!(), @@ -327,10 +377,7 @@ impl<'db> EqualityState<'db> { } /// Looks up the struct construct info for a variable (mutable, uses find for path compression). - fn get_struct_construct( - &mut self, - var: VariableId, - ) -> Option<(TypeId<'db>, Vec)> { + fn get_struct_construct(&mut self, var: VariableId) -> Option<(TypeId<'db>, Vec)> { let rep = self.find(var); self.get_struct_construct_immut(rep) } @@ -346,14 +393,15 @@ impl<'db> EqualityState<'db> { } } - /// Records a struct construct: output = StructType(inputs...). + /// Records a struct construct: output = StructType(fields...). + /// Fields may contain placeholders from merge operations. fn set_struct_construct( &mut self, ty: TypeId<'db>, - input_reps: Vec, + fields: Vec, output: VariableId, ) { - self.set_construct(Relation::StructConstruct(ty, input_reps), output); + self.set_construct(Relation::StructConstruct(ty, fields), output); } } @@ -379,7 +427,14 @@ impl<'db> DebugWithDb<'db> for EqualityState<'db> { } Relation::StructConstruct(ty, inputs) => { let type_name = ty.format(db); - let fields = inputs.iter().map(|&id| v(id)).collect::>().join(", "); + let fields = inputs + .iter() + .map(|f| match f { + FieldVar::Var(id) => v(*id), + FieldVar::Placeholder(_) => "?".to_string(), + }) + .collect::>() + .join(", "); lines.push(format!("{type_name}({fields}) = {}", v(output))); } // Box/Snapshot never appear in hashcons — they use class_info relationships. @@ -405,6 +460,8 @@ impl<'db> DebugWithDb<'db> for EqualityState<'db> { pub struct EqualityAnalysis<'a, 'db> { db: &'db dyn Database, lowered: &'a Lowered<'db>, + /// Counter for allocating globally unique placeholder IDs. + next_placeholder: usize, /// The `array_new` extern function id. array_new: ExternFunctionId<'db>, /// The `array_append` extern function id. @@ -426,6 +483,7 @@ impl<'a, 'db> EqualityAnalysis<'a, 'db> { Self { db, lowered, + next_placeholder: 0, array_new: array_module.extern_function_id("array_new"), array_append: array_module.extern_function_id("array_append"), array_pop_front: array_module.extern_function_id("array_pop_front"), @@ -435,6 +493,13 @@ impl<'a, 'db> EqualityAnalysis<'a, 'db> { } } + /// Allocates a fresh, globally unique placeholder ID. + fn alloc_placeholder(&mut self) -> FieldVar { + let id = self.next_placeholder; + self.next_placeholder += 1; + FieldVar::Placeholder(id) + } + /// Runs equality analysis on a lowered function. /// Returns the equality state at the exit of each block. pub fn analyze( @@ -467,11 +532,14 @@ impl<'a, 'db> EqualityAnalysis<'a, 'db> { [remaining_arr, boxed_elem] => { let input_arr = extern_info.inputs[0].var_id; if let Some((ty, elems)) = info.get_struct_construct(input_arr) - && let Some((&first, rest)) = elems.split_first() + && let Some((first, rest)) = elems.split_first() { - info.set_box_relationship(first, boxed_elem); - let rest_reps: Vec<_> = rest.iter().map(|&v| info.find(v)).collect(); - info.set_struct_construct(ty, rest_reps, remaining_arr); + if let FieldVar::Var(first_var) = first { + info.set_box_relationship(*first_var, boxed_elem); + } + let rest_fields: Vec = + rest.iter().map(|f| f.find_rep(info)).collect(); + info.set_struct_construct(ty, rest_fields, remaining_arr); } } // None arm: union output with input. @@ -500,19 +568,30 @@ impl<'a, 'db> EqualityAnalysis<'a, 'db> { }) .or_else(|| info.get_struct_construct_immut(snap_rep)); - if let Some((_orig_ty, elems)) = elems_opt { - let pop_front = id == self.array_snapshot_pop_front; - let (elem, rest) = if pop_front { - let Some((&first, tail)) = elems.split_first() else { return }; - (first, tail.to_vec()) - } else { - let Some((&last, init)) = elems.split_last() else { return }; - (last, init.to_vec()) + let snap_ty = self.lowered.variables[remaining_snap_arr].ty; + let Some((_orig_ty, elems)) = elems_opt else { + return; + }; + let pop_front = id == self.array_snapshot_pop_front; + let (elem, rest) = if pop_front { + let Some((first, tail)) = elems.split_first() else { + // Empty array — record the empty remaining array and return. + info.set_struct_construct(snap_ty, vec![], remaining_snap_arr); + return; + }; + (*first, tail.to_vec()) + } else { + let Some((last, init)) = elems.split_last() else { + info.set_struct_construct(snap_ty, vec![], remaining_snap_arr); + return; }; + (*last, init.to_vec()) + }; - // The popped element is `Box<@T>`. Record the box relationship against - // the snapshot class of `elem` if it exists. - let elem_rep = info.find(elem); + // The popped element is `Box<@T>`. Record the box relationship against + // the snapshot class of `elem` if it exists. + if let FieldVar::Var(elem_var) = elem { + let elem_rep = info.find(elem_var); if let Some(&snap_of_elem) = info .class_info .get(&elem_rep) @@ -520,11 +599,11 @@ impl<'a, 'db> EqualityAnalysis<'a, 'db> { { info.set_box_relationship(snap_of_elem, boxed_snap_elem); } - - let snap_ty = self.lowered.variables[remaining_snap_arr].ty; - let rest_reps: Vec<_> = rest.iter().map(|&v| info.find(v)).collect(); - info.set_struct_construct(snap_ty, rest_reps, remaining_snap_arr); } + + let rest_fields: Vec = + rest.iter().map(|f| f.find_rep(info)).collect(); + info.set_struct_construct(snap_ty, rest_fields, remaining_snap_arr); } // None arm: union output with input. [original_snap_arr] => { @@ -611,12 +690,14 @@ fn find_intersection_rep_opt( } /// Preserves construct entries (enum, struct) that exist in both branches. -/// Uses output-based lookup via `class_info` reverse_relationships. +/// Uses output-based lookup via `class_info` reverse_relationships, which handles both +/// complete and partial (placeholder-containing) struct constructs uniformly. fn merge_constructs<'db>( info1: &EqualityState<'db>, info2: &EqualityState<'db>, intersections: &OrderedHashMap>, result: &mut EqualityState<'db>, + alloc_placeholder: &mut impl FnMut() -> FieldVar, ) { for (&rep1, class1) in info1.class_info.iter() { for &(rep2, intersection_output) in intersections.get(&rep1).unwrap_or(&vec![]) { @@ -639,7 +720,7 @@ fn merge_constructs<'db>( result.set_enum_construct(*variant1, input_intersection, intersection_output); } - // StructConstruct: all fields must have intersection reps. + // StructConstruct: field-by-field, allowing placeholders for unknown fields. if let ( Some(Relation::StructConstruct(ty1, fields1)), Some(Relation::StructConstruct(ty2, fields2)), @@ -649,19 +730,27 @@ fn merge_constructs<'db>( ) && ty1 == ty2 && fields1.len() == fields2.len() { - let result_fields: Option> = fields1 + let result_fields: Vec = fields1 .iter() .zip(fields2.iter()) - .map(|(&v1, &v2)| { - find_intersection_rep( + .map(|(f1, f2)| match (f1.as_var(), f2.as_var()) { + (Some(v1), Some(v2)) => find_intersection_rep( intersections, info1.find_immut(v1), info2.find_immut(v2), ) + .map(FieldVar::Var) + .unwrap_or_else(&mut *alloc_placeholder), + _ => alloc_placeholder(), }) .collect(); - if let Some(result_fields) = result_fields { - result.set_struct_construct(*ty1, result_fields, intersection_output); + // Only store if at least one field is a real variable (or empty struct). + if result_fields.is_empty() || result_fields.iter().any(|f| f.as_var().is_some()) { + result.set_struct_construct( + *ty1, + result_fields, + intersection_output, + ); } } } @@ -722,7 +811,8 @@ impl<'db, 'a> DataflowAnalyzer<'db, 'a> for EqualityAnalysis<'a, 'db> { merge_class_relationships(&info1, &info2, &intersections, &mut result); - merge_constructs(&info1, &info2, &intersections, &mut result); + let mut alloc = || self.alloc_placeholder(); + merge_constructs(&info1, &info2, &intersections, &mut result, &mut alloc); result } @@ -770,22 +860,27 @@ impl<'db, 'a> DataflowAnalyzer<'db, 'a> for EqualityAnalysis<'a, 'db> { // If we've already seen the same struct type with equivalent inputs, the outputs // are equal. let ty = self.lowered.variables[struct_stmt.output].ty; - let input_reps = struct_stmt.inputs.iter().map(|i| info.find(i.var_id)).collect(); - info.set_struct_construct(ty, input_reps, struct_stmt.output); + let fields = + struct_stmt.inputs.iter().map(|i| FieldVar::Var(info.find(i.var_id))).collect(); + info.set_struct_construct(ty, fields, struct_stmt.output); } Statement::StructDestructure(struct_stmt) => { // (outputs...) = struct_destructure(input) - // 1. If input was previously constructed, union outputs with original fields. + // 1. If input was previously constructed, union outputs with original fields. Skip + // placeholder fields (unknown after merge). if let Some((_, field_reps)) = info.get_struct_construct(struct_stmt.input.var_id) { - for (&output, &field_rep) in struct_stmt.outputs.iter().zip(field_reps.iter()) { - info.union(output, field_rep); + for (&output, field) in struct_stmt.outputs.iter().zip(field_reps.iter()) { + if let FieldVar::Var(field_rep) = field { + info.union(output, *field_rep); + } } } // 2. Record: struct_construct(outputs) == input (for future constructs). let ty = self.lowered.variables[struct_stmt.input.var_id].ty; - let output_reps = struct_stmt.outputs.iter().map(|&v| info.find(v)).collect(); - info.set_struct_construct(ty, output_reps, struct_stmt.input.var_id); + let fields = + struct_stmt.outputs.iter().map(|&v| FieldVar::Var(info.find(v))).collect(); + info.set_struct_construct(ty, fields, struct_stmt.input.var_id); } Statement::Call(call_stmt) => { @@ -800,7 +895,7 @@ impl<'db, 'a> DataflowAnalyzer<'db, 'a> for EqualityAnalysis<'a, 'db> { // Only track append if the input array is already tracked. Arrays from // function parameters or external calls are conservatively ignored. let mut new_elems = elems; - new_elems.push(info.find(call_stmt.inputs[1].var_id)); + new_elems.push(FieldVar::Var(info.find(call_stmt.inputs[1].var_id))); info.set_struct_construct(ty, new_elems, call_stmt.outputs[0]); } } diff --git a/crates/cairo-lang-lowering/src/analysis/test_data/equality b/crates/cairo-lang-lowering/src/analysis/test_data/equality index 228d739a296..cd46d2a2280 100644 --- a/crates/cairo-lang-lowering/src/analysis/test_data/equality +++ b/crates/cairo-lang-lowering/src/analysis/test_data/equality @@ -1383,3 +1383,88 @@ Block 3: Block 4: @v4 = v6, core::array::Array::() = v2, core::array::Array::(v0) = v3, core::array::Array::(v0, v1) = v4, v4 = v5, v6 = v9 + +//! > ========================================================================== + +//! > Test array construct across branches with partial element agreement. + +//! > Both branches build an array with different first element but same second element. + +//! > After merge, the struct hashcons entry survives with a placeholder for the differing + +//! > element, preserving the common element's identity. + +//! > test_runner_name +test_equality_analysis + +//! > function_code +fn foo(cond: bool, a: felt252, b: felt252) { + let arr = if cond { + let mut arr = ArrayTrait::new(); + arr.append(a); + arr.append(b); + arr + } else { + let mut arr = ArrayTrait::new(); + arr.append(a + 1); + arr.append(b); + arr + }; + use_arr(@arr); +} + +//! > function_name +foo + +//! > module_code +extern fn use_arr(x: @Array) nopanic; + +//! > semantic_diagnostics + +//! > lowering +Parameters: v0: core::bool, v1: core::felt252, v2: core::felt252 +blk0 (root): +Statements: +End: + Match(match_enum(v0) { + bool::False(v3) => blk1, + bool::True(v4) => blk2, + }) + +blk1: +Statements: + (v5: core::array::Array::) <- core::array::array_new::() + (v6: core::felt252) <- 1 + (v7: core::felt252) <- core::felt252_add(v1, v6) + (v8: core::array::Array::) <- core::array::array_append::(v5, v7) + (v9: core::array::Array::) <- core::array::array_append::(v8, v2) +End: + Goto(blk3, {v9 -> v10}) + +blk2: +Statements: + (v11: core::array::Array::) <- core::array::array_new::() + (v12: core::array::Array::) <- core::array::array_append::(v11, v1) + (v13: core::array::Array::) <- core::array::array_append::(v12, v2) +End: + Goto(blk3, {v13 -> v10}) + +blk3: +Statements: + (v14: core::array::Array::, v15: @core::array::Array::) <- snapshot(v10) + () <- test::use_arr(v15) +End: + Return() + +//! > analysis_state +Block 0: +(empty) + +Block 1: +False(v3) = v0, core::array::Array::() = v5, core::array::Array::(v7) = v8, core::array::Array::(v7, v2) = v9 + +Block 2: +True(v4) = v0, core::array::Array::() = v11, core::array::Array::(v1) = v12, core::array::Array::(v1, v2) = v13 + +Block 3: +@v10 = v15, core::array::Array::(?, v2) = v10, v10 = v14