diff --git a/crates/cairo-lang-lowering/src/analysis/equality_analysis.rs b/crates/cairo-lang-lowering/src/analysis/equality_analysis.rs index c782527b9a0..80cbb6aac45 100644 --- a/crates/cairo-lang-lowering/src/analysis/equality_analysis.rs +++ b/crates/cairo-lang-lowering/src/analysis/equality_analysis.rs @@ -10,7 +10,7 @@ use cairo_lang_debug::DebugWithDb; use cairo_lang_defs::ids::{ExternFunctionId, NamedLanguageElementId}; use cairo_lang_semantic::helper::ModuleHelper; use cairo_lang_semantic::{ConcreteVariant, MatchArmSelector, TypeId}; -use cairo_lang_utils::ordered_hash_map::OrderedHashMap; +use cairo_lang_utils::ordered_hash_map::{Entry, OrderedHashMap}; use salsa::Database; use crate::analysis::core::Edge; @@ -19,51 +19,133 @@ use crate::{ BlockEnd, BlockId, Lowered, MatchArm, MatchExternInfo, MatchInfo, Statement, VariableId, }; -/// Tracks relationships between equivalence classes. +/// The kind of relationship between equivalence classes. +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] +enum RelationKind { + Box, + Snapshot, + EnumConstruct, + StructConstruct, +} + +/// A relationship between equivalence classes, carrying its payload data. +/// Hashable so it can be used as a hashcons key. +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +enum Relation<'db> { + Box(VariableId), + Snapshot(VariableId), + EnumConstruct(ConcreteVariant<'db>, VariableId), + StructConstruct(TypeId<'db>, Vec), +} + +impl<'db> Relation<'db> { + fn kind(&self) -> RelationKind { + match self { + Relation::Box(_) => RelationKind::Box, + Relation::Snapshot(_) => RelationKind::Snapshot, + Relation::EnumConstruct(_, _) => RelationKind::EnumConstruct, + Relation::StructConstruct(_, _) => RelationKind::StructConstruct, + } + } + + /// Extracts the single variable for simple (1-to-1) relations. + fn single_var(&self) -> Option { + match self { + Relation::Box(v) | Relation::Snapshot(v) | Relation::EnumConstruct(_, v) => Some(*v), + Relation::StructConstruct(_, _) => None, + } + } + + /// Returns an iterator over all variables referenced by this relation. + fn referenced_vars(&self) -> impl Iterator + '_ { + let fields: &[VariableId] = match self { + Relation::StructConstruct(_, vs) => vs, + _ => &[], + }; + self.single_var().into_iter().chain(fields.iter().copied()) + } + + /// Creates a new relation of the same kind with the single variable replaced. + /// Panics for StructConstruct (which has multiple variables). + fn with_var(&self, var: VariableId) -> Self { + match self { + Relation::Box(_) => Relation::Box(var), + Relation::Snapshot(_) => Relation::Snapshot(var), + Relation::EnumConstruct(v, _) => Relation::EnumConstruct(*v, var), + Relation::StructConstruct(_, _) => { + unreachable!("with_var not supported for StructConstruct") + } + } + } +} + +/// Tracks relationships and construct provenance for an equivalence class representative. #[derive(Clone, Debug, Default)] -struct ClassInfo { - /// If this class has a boxed version, the representative of that class. - boxed_class: Option, - /// If this class has an unboxed version, the representative of that class. - unboxed_class: Option, - /// If this class has a snapshot version, the representative of that class. - snapshot_class: Option, - /// If this class is a snapshot, the representative of the original class. - original_class: Option, +struct ClassInfo<'db> { + /// Forward relationships: this class → target class. + /// Box: my boxed version is `target`. Snapshot: my snapshot version is `target`. + /// Only used for Box and Snapshot (1-to-1 bidirectional relationships). + relationship: OrderedHashMap, + /// Reverse relationships: how this class was produced. + /// Box: I was produced by boxing `source`. Snapshot: I am a snapshot of `source`. + /// EnumConstruct: I was produced by Variant(`input`). + /// StructConstruct: I was produced by Type(`fields`). + reverse_relationship: OrderedHashMap>, } -impl ClassInfo { +impl<'db> ClassInfo<'db> { /// Returns all variables referenced by this ClassInfo's relationships. fn referenced_vars(&self) -> impl Iterator + '_ { - [self.boxed_class, self.original_class, self.snapshot_class, self.unboxed_class] - .into_iter() - .flatten() + self.relationship + .values() + .copied() + .chain(self.reverse_relationship.values().flat_map(Relation::referenced_vars)) } - /// Returns true if this ClassInfo has no relationships. + /// Returns true if this ClassInfo has no relationships or construct info. fn is_empty(&self) -> bool { - self.referenced_vars().next().is_none() + self.relationship.is_empty() && self.reverse_relationship.is_empty() } /// Merges another ClassInfo into this one. /// When both have the same relationship type, calls union_fn to merge the related classes. fn merge( - self, + mut self, other: Self, union_fn: &mut impl FnMut(VariableId, VariableId) -> VariableId, ) -> Self { - let mut merge_field = |new: Option, old: Option| match (new, old) { - (Some(new_val), Some(old_val)) if new_val != old_val => { - Some(union_fn(new_val, old_val)) + // Merge forward relationships. + for (kind, other_var) in other.relationship { + match self.relationship.entry(kind) { + Entry::Occupied(mut e) => { + if *e.get() != other_var { + *e.get_mut() = union_fn(*e.get(), other_var); + } + } + Entry::Vacant(e) => { + e.insert(other_var); + } + } + } + // Merge reverse relationships. + for (kind, other_rel) in other.reverse_relationship { + match self.reverse_relationship.entry(kind) { + Entry::Occupied(mut e) => { + // For simple relations (Box, Snapshot, EnumConstruct), union the vars. + if let (Some(self_var), Some(other_var)) = + (e.get().single_var(), other_rel.single_var()) + && self_var != other_var + { + *e.get_mut() = e.get().with_var(union_fn(self_var, other_var)); + } + // StructConstruct: keep whichever side we already have (no field-level merge). + } + Entry::Vacant(e) => { + e.insert(other_rel); + } } - (new, old) => new.or(old), - }; - Self { - boxed_class: merge_field(self.boxed_class, other.boxed_class), - unboxed_class: merge_field(self.unboxed_class, other.unboxed_class), - snapshot_class: merge_field(self.snapshot_class, other.snapshot_class), - original_class: merge_field(self.original_class, other.original_class), } + self } } @@ -79,33 +161,19 @@ pub struct EqualityState<'db> { /// Union-find parent map. If a variable is not in the map, it is its own representative. union_find: OrderedHashMap, - /// For each equivalence class representative, track relationships only if they exist. - class_info: OrderedHashMap, + /// For each equivalence class representative, track relationships and construct provenance. + class_info: OrderedHashMap>, - /// Hashcons for enum constructs: maps (variant, input_rep) -> output_rep. - /// This allows us to detect when two enum constructs with the same variant - /// and equivalent inputs should produce equivalent outputs. + /// Unified hashcons: maps Relation(inputs) -> output representative. + /// This allows us to detect when two constructs with equivalent inputs should produce + /// equivalent outputs. Covers enum constructs, struct/array constructs. /// - /// Keys use representatives at insertion time. In SSA form each variable is defined - /// exactly once, so representatives cannot change within a block and keys stay valid - /// without migration during `union`. At merge points the maps are rebuilt from scratch. - enum_hashcons: OrderedHashMap<(ConcreteVariant<'db>, VariableId), VariableId>, - - /// Reverse hashcons for enum constructs: maps output_rep -> (variant, input_rep). - /// This allows efficient lookup when matching on an enum to find the original input. - enum_hashcons_rev: OrderedHashMap, VariableId)>, - - /// Hashcons for struct/array constructs: maps (type, [field_reps...]) -> output_rep. - /// This allows us to detect when two constructs with the same type - /// and equivalent fields/elements should produce equivalent outputs. - /// Arrays reuse this same infrastructure — `array_new`/`array_append` chains are recorded - /// as constructs keyed by the array type, and `array_pop_front` acts as a destructure. - struct_hashcons: OrderedHashMap<(TypeId<'db>, Vec), VariableId>, - - /// Reverse hashcons for struct/array constructs: maps output_rep -> (type, - /// [field_reps/element_reps...]). - /// This allows efficient lookup when destructuring a struct or popping from an array. - struct_hashcons_rev: OrderedHashMap, Vec)>, + /// Keys use representatives at insertion time. In SSA form, representatives are generally + /// stable within a block, so keys stay valid without migration during `union`. A union + /// *can* change a representative to a lower ID, which may cause a subsequent identical + /// construct to miss the earlier entry — this is a known imprecision (conservative, not + /// unsound). At merge points the maps are rebuilt from scratch. + hashcons: OrderedHashMap, VariableId>, } impl<'db> EqualityState<'db> { @@ -114,11 +182,6 @@ impl<'db> EqualityState<'db> { self.union_find.get(&var).copied().unwrap_or(var) } - /// Gets the class info for a variable, returning a default if not present. - fn get_class_info(&self, var: VariableId) -> ClassInfo { - self.class_info.get(&var).cloned().unwrap_or_default() - } - /// Finds the representative of a variable's equivalence class. /// Uses path compression for efficiency. fn find(&mut self, var: VariableId) -> VariableId { @@ -175,65 +238,73 @@ impl<'db> EqualityState<'db> { self.find(new_root) } - /// Looks up a related variable through a ClassInfo field accessor. - fn get_related( - &mut self, - var: VariableId, - field: fn(&mut ClassInfo) -> &mut Option, - ) -> Option { + /// Looks up a forward relationship of the given kind on a variable's class. + fn get_forward(&mut self, var: VariableId, kind: RelationKind) -> Option { let rep = self.find(var); - let mut info = self.get_class_info(rep); - let related = (*field(&mut info))?; + let related = self.class_info.get(&rep)?.relationship.get(&kind).copied()?; Some(self.find(related)) } - /// Sets a bidirectional relationship between two variables' equivalence classes. - /// If inputs already have a relationship of the same kind, unions with the existing class. - fn set_relationship( - &mut self, - var_a: VariableId, - var_b: VariableId, - field_a_to_b: fn(&mut ClassInfo) -> &mut Option, - field_b_to_a: fn(&mut ClassInfo) -> &mut Option, - ) { + /// Looks up a reverse relationship of the given kind on a variable's class. + fn get_reverse(&mut self, var: VariableId, kind: RelationKind) -> Option { + let rep = self.find(var); + let related = self.class_info.get(&rep)?.reverse_relationship.get(&kind)?.single_var()?; + Some(self.find(related)) + } + + /// Sets a bidirectional relationship (Box or Snapshot) between two variables' classes. + /// If either side already has a relationship of the same kind, unions with the existing class. + fn set_relationship(&mut self, source: VariableId, target: VariableId, kind: RelationKind) { // Union with existing relationships if present. - if let Some(existing) = self.get_related(var_a, field_a_to_b) { - self.union(var_b, existing); + if let Some(existing) = self.get_forward(source, kind) { + self.union(target, existing); } - if let Some(existing) = self.get_related(var_b, field_b_to_a) { - self.union(var_a, existing); + if let Some(existing) = self.get_reverse(target, kind) { + self.union(source, existing); } // Re-find after potential unions — representatives may have changed. - let rep_a = self.find(var_a); - let rep_b = self.find(var_b); - - *field_a_to_b(self.class_info.entry(rep_a).or_default()) = Some(rep_b); - *field_b_to_a(self.class_info.entry(rep_b).or_default()) = Some(rep_a); + let rep_source = self.find(source); + let rep_target = self.find(target); + + self.class_info.entry(rep_source).or_default().relationship.insert(kind, rep_target); + let reverse = match kind { + RelationKind::Box => Relation::Box(rep_source), + RelationKind::Snapshot => Relation::Snapshot(rep_source), + _ => unreachable!("set_relationship only for Box/Snapshot"), + }; + self.class_info.entry(rep_target).or_default().reverse_relationship.insert(kind, reverse); } /// Sets a box relationship: boxed_var = Box(unboxed_var). fn set_box_relationship(&mut self, unboxed_var: VariableId, boxed_var: VariableId) { - self.set_relationship( - unboxed_var, - boxed_var, - |ci| &mut ci.boxed_class, - |ci| &mut ci.unboxed_class, - ); + self.set_relationship(unboxed_var, boxed_var, RelationKind::Box); } /// Sets a snapshot relationship: snapshot_var = @original_var. fn set_snapshot_relationship(&mut self, original_var: VariableId, snapshot_var: VariableId) { - self.set_relationship( - original_var, - snapshot_var, - |ci| &mut ci.snapshot_class, - |ci| &mut ci.original_class, - ); + self.set_relationship(original_var, snapshot_var, RelationKind::Snapshot); + } + + /// Records a construct (enum or struct) via the unified hashcons. + /// If we've already seen the same construct with equivalent inputs, unions the outputs. + fn set_construct(&mut self, relation: Relation<'db>, output: VariableId) { + let output_rep = if let Some(&existing_output) = self.hashcons.get(&relation) { + self.union(existing_output, output); + self.find(existing_output) + } else { + let output_rep = self.find(output); + self.hashcons.insert(relation.clone(), output_rep); + output_rep + }; + self.class_info + .entry(output_rep) + .or_default() + .reverse_relationship + .insert(relation.kind(), relation); } /// Records an enum construct: output = Variant(input). - /// If we've already seen the same variant with an equivalent input, unions the outputs. fn set_enum_construct( &mut self, variant: ConcreteVariant<'db>, @@ -241,41 +312,48 @@ impl<'db> EqualityState<'db> { output: VariableId, ) { let input_rep = self.find(input); - let output_rep = self.find(output); - - match self.enum_hashcons.entry((variant, input_rep)) { - cairo_lang_utils::ordered_hash_map::Entry::Occupied(entry) => { - let existing_output = *entry.get(); - self.union(existing_output, output); - // Union may have changed the representative. Update the reverse map - // so that transfer_edge lookups via find() hit the current representative. - let new_output_rep = self.find(existing_output); - self.enum_hashcons_rev.swap_remove(&existing_output); - self.enum_hashcons_rev.insert(new_output_rep, (variant, input_rep)); - } - cairo_lang_utils::ordered_hash_map::Entry::Vacant(entry) => { - entry.insert(output_rep); - self.enum_hashcons_rev.insert(output_rep, (variant, input_rep)); - } + self.set_construct(Relation::EnumConstruct(variant, input_rep), output); + } + + /// Looks up the struct construct info for a representative (immutable). + fn get_struct_construct_immut( + &self, + rep: VariableId, + ) -> Option<(TypeId<'db>, Vec)> { + match self.class_info.get(&rep)?.reverse_relationship.get(&RelationKind::StructConstruct)? { + Relation::StructConstruct(ty, fields) => Some((*ty, fields.clone())), + _ => unreachable!(), + } + } + + /// Looks up the struct construct info for a variable (mutable, uses find for path compression). + fn get_struct_construct( + &mut self, + var: VariableId, + ) -> Option<(TypeId<'db>, Vec)> { + let rep = self.find(var); + self.get_struct_construct_immut(rep) + } + + /// Looks up the enum construct info for a representative (immutable). + fn get_enum_construct_immut( + &self, + rep: VariableId, + ) -> Option<(ConcreteVariant<'db>, VariableId)> { + match self.class_info.get(&rep)?.reverse_relationship.get(&RelationKind::EnumConstruct)? { + Relation::EnumConstruct(variant, input) => Some((*variant, *input)), + _ => unreachable!(), } } /// Records a struct construct: output = StructType(inputs...). - /// If we've already seen the same type with equivalent inputs, unions the outputs. fn set_struct_construct( &mut self, ty: TypeId<'db>, input_reps: Vec, output: VariableId, ) { - let key = (ty, input_reps); - if let Some(&existing_output) = self.struct_hashcons.get(&key) { - self.union(existing_output, output); - } else { - let output_rep = self.find(output); - self.struct_hashcons.insert(key.clone(), output_rep); - self.struct_hashcons_rev.insert(output_rep, key); - } + self.set_construct(Relation::StructConstruct(ty, input_reps), output); } } @@ -286,21 +364,27 @@ impl<'db> DebugWithDb<'db> for EqualityState<'db> { let v = |id: VariableId| format!("v{}", self.find_immut(id).index()); let mut lines = Vec::::new(); for (&rep, info) in self.class_info.iter() { - if let Some(s) = info.snapshot_class { + if let Some(&s) = info.relationship.get(&RelationKind::Snapshot) { lines.push(format!("@{} = {}", v(rep), v(s))); } - if let Some(b) = info.boxed_class { + if let Some(&b) = info.relationship.get(&RelationKind::Box) { lines.push(format!("Box({}) = {}", v(rep), v(b))); } } - for (&(variant, input), &output) in self.enum_hashcons.iter() { - let name = variant.id.name(db).to_string(db); - lines.push(format!("{name}({}) = {}", v(input), v(output))); - } - for ((ty, inputs), &output) in self.struct_hashcons.iter() { - let type_name = ty.format(db); - let fields = inputs.iter().map(|&id| v(id)).collect::>().join(", "); - lines.push(format!("{type_name}({fields}) = {}", v(output))); + for (relation, &output) in self.hashcons.iter() { + match relation { + Relation::EnumConstruct(variant, input) => { + let name = variant.id.name(db).to_string(db); + lines.push(format!("{name}({}) = {}", v(*input), v(output))); + } + Relation::StructConstruct(ty, inputs) => { + let type_name = ty.format(db); + let fields = inputs.iter().map(|&id| v(id)).collect::>().join(", "); + lines.push(format!("{type_name}({fields}) = {}", v(output))); + } + // Box/Snapshot never appear in hashcons — they use class_info relationships. + _ => {} + } } for &var in self.union_find.keys() { let rep = self.find_immut(var); @@ -378,88 +462,81 @@ impl<'a, 'db> EqualityAnalysis<'a, 'db> { let Some((id, _)) = extern_info.function.get_extern(self.db) else { return }; if id == self.array_pop_front || id == self.array_pop_front_consume { - // Some arm: var_ids = [remaining_arr, boxed_elem] - if arm.var_ids.len() == 2 { - let input_arr = extern_info.inputs[0].var_id; - let remaining_arr = arm.var_ids[0]; - let boxed_elem = arm.var_ids[1]; - - let arr_rep = info.find(input_arr); - if let Some((ty, elems)) = info.struct_hashcons_rev.get(&arr_rep).cloned() - && let Some((&first, rest)) = elems.split_first() - { - // Popped element is boxed: boxed_elem = Box(first_element) - info.set_box_relationship(first, boxed_elem); - // Remaining array is the tail - let rest_reps: Vec<_> = rest.iter().map(|&v| info.find(v)).collect(); - info.set_struct_construct(ty, rest_reps, remaining_arr); + match arm.var_ids[..] { + // Some arm: [remaining_arr, boxed_elem]. + [remaining_arr, boxed_elem] => { + let input_arr = extern_info.inputs[0].var_id; + if let Some((ty, elems)) = info.get_struct_construct(input_arr) + && let Some((&first, rest)) = elems.split_first() + { + info.set_box_relationship(first, boxed_elem); + let rest_reps: Vec<_> = rest.iter().map(|&v| info.find(v)).collect(); + info.set_struct_construct(ty, rest_reps, remaining_arr); + } } - } - // None arm for array_pop_front: var_ids = [original_arr]. Union with input. - // None arm for array_pop_front_consume: var_ids = []. Nothing to do. - if arm.var_ids.len() == 1 { - info.union(arm.var_ids[0], extern_info.inputs[0].var_id); + // None arm: union output with input. + [original_arr] => { + info.union(original_arr, extern_info.inputs[0].var_id); + } + _ => {} } } else if id == self.array_snapshot_pop_front || id == self.array_snapshot_pop_back { - // Some arm: var_ids = [remaining_snap_arr, boxed_snap_elem] - if arm.var_ids.len() == 2 { - let input_snap_arr = extern_info.inputs[0].var_id; - let remaining_snap_arr = arm.var_ids[0]; - let boxed_snap_elem = arm.var_ids[1]; - - // The input is @Array. Look up the tracked elements. - // Two paths: (1) the input snapshot was created via `snapshot(arr)` where `arr` - // has a struct-hashcons entry under `Array`, or (2) the input is itself a - // remaining snapshot array from a prior snapshot pop, stored directly in the - // struct hashcons under its snapshot type `@Array`. - let snap_rep = info.find(input_snap_arr); - let elems_opt = info - .class_info - .get(&snap_rep) - .and_then(|ci| ci.original_class) - .and_then(|orig| { - let orig = info.find_immut(orig); - info.struct_hashcons_rev.get(&orig).cloned() - }) - .or_else(|| info.struct_hashcons_rev.get(&snap_rep).cloned()); - - if let Some((_orig_ty, elems)) = elems_opt { - let pop_front = id == self.array_snapshot_pop_front; - let (elem, rest) = if pop_front { - let Some((&first, tail)) = elems.split_first() else { return }; - (first, tail.to_vec()) - } else { - let Some((&last, init)) = elems.split_last() else { return }; - (last, init.to_vec()) - }; - - // The popped element is `Box<@T>`. The box wraps the *snapshot* of the - // original element. Record the box relationship against the snapshot - // class of `elem` if it exists, so that `unbox` correctly yields `@elem` - // rather than falsely equating the `@T` result with the `T` original. - let elem_rep = info.find(elem); - if let Some(snap_of_elem) = - info.class_info.get(&elem_rep).and_then(|ci| ci.snapshot_class) - { - info.set_box_relationship(snap_of_elem, boxed_snap_elem); + match arm.var_ids[..] { + // Some arm: [remaining_snap_arr, boxed_snap_elem]. + [remaining_snap_arr, boxed_snap_elem] => { + let input_snap_arr = extern_info.inputs[0].var_id; + + // Look up tracked elements via snapshot reverse relationship or direct lookup. + let snap_rep = info.find(input_snap_arr); + let original_rep = info.class_info.get(&snap_rep).and_then(|ci| { + ci.reverse_relationship + .get(&RelationKind::Snapshot) + .and_then(|r| r.single_var()) + }); + let elems_opt = original_rep + .and_then(|orig| { + let orig = info.find_immut(orig); + info.get_struct_construct_immut(orig) + }) + .or_else(|| info.get_struct_construct_immut(snap_rep)); + + if let Some((_orig_ty, elems)) = elems_opt { + let pop_front = id == self.array_snapshot_pop_front; + let (elem, rest) = if pop_front { + let Some((&first, tail)) = elems.split_first() else { return }; + (first, tail.to_vec()) + } else { + let Some((&last, init)) = elems.split_last() else { return }; + (last, init.to_vec()) + }; + + // The popped element is `Box<@T>`. Record the box relationship against + // the snapshot class of `elem` if it exists. + let elem_rep = info.find(elem); + if let Some(&snap_of_elem) = info + .class_info + .get(&elem_rep) + .and_then(|ci| ci.relationship.get(&RelationKind::Snapshot)) + { + info.set_box_relationship(snap_of_elem, boxed_snap_elem); + } + + let snap_ty = self.lowered.variables[remaining_snap_arr].ty; + let rest_reps: Vec<_> = rest.iter().map(|&v| info.find(v)).collect(); + info.set_struct_construct(snap_ty, rest_reps, remaining_snap_arr); } - - // Record the remaining snapshot array under its actual snapshot type - // (`@Array`) to avoid falsely equating it with non-snapshot arrays. - let snap_ty = self.lowered.variables[remaining_snap_arr].ty; - let rest_reps: Vec<_> = rest.iter().map(|&v| info.find(v)).collect(); - info.set_struct_construct(snap_ty, rest_reps, remaining_snap_arr); } - } - // None arm: var_ids = [original_snap_arr]. Union with input. - if arm.var_ids.len() == 1 { - info.union(arm.var_ids[0], extern_info.inputs[0].var_id); + // None arm: union output with input. + [original_snap_arr] => { + info.union(original_snap_arr, extern_info.inputs[0].var_id); + } + _ => {} } } } } -/// Returns an iterator over all variables with equality ir relationship information in the equality +/// Returns an iterator over all variables with equality or relationship information in the equality /// states. fn merge_referenced_vars<'db, 'a>( info1: &'a EqualityState<'db>, @@ -473,18 +550,12 @@ fn merge_referenced_vars<'db, 'a>( .chain(info2.class_info.values()) .flat_map(ClassInfo::referenced_vars); - let enum_vars = info1 - .enum_hashcons - .iter() - .chain(info2.enum_hashcons.iter()) - .flat_map(|(&(_, input), &output)| [input, output]); - - let struct_vars = - info1.struct_hashcons.iter().chain(info2.struct_hashcons.iter()).flat_map( - |((_, inputs), &output)| inputs.iter().copied().chain(std::iter::once(output)), - ); + let hashcons_vars = + info1.hashcons.iter().chain(info2.hashcons.iter()).flat_map(|(relation, &output)| { + relation.referenced_vars().chain(std::iter::once(output)) + }); - union_find_vars.chain(class_info_vars).chain(enum_vars).chain(struct_vars) + union_find_vars.chain(class_info_vars).chain(hashcons_vars) } /// Preserves only class relationships (box/snapshot) that exist in both branches. @@ -500,24 +571,16 @@ fn merge_class_relationships( continue; }; - if let Some(boxed_rep) = find_intersection_rep_opt( - info1, - info2, - intersections, - class1.boxed_class, - class2.boxed_class, - ) { - result.set_box_relationship(intersection_var, boxed_rep); - } - - if let Some(snap_rep) = find_intersection_rep_opt( - info1, - info2, - intersections, - class1.snapshot_class, - class2.snapshot_class, - ) { - result.set_snapshot_relationship(intersection_var, snap_rep); + for &kind in &[RelationKind::Box, RelationKind::Snapshot] { + if let Some(target_rep) = find_intersection_rep_opt( + info1, + info2, + intersections, + class1.relationship.get(&kind).copied(), + class2.relationship.get(&kind).copied(), + ) { + result.set_relationship(intersection_var, target_rep, kind); + } } } } @@ -547,68 +610,64 @@ fn find_intersection_rep_opt( find_intersection_rep(intersections, info1.find_immut(rep1?), info2.find_immut(rep2?)) } -/// Preserves enum hashcons entries that exist in both branches. -/// An entry survives if both input and output have intersection representatives, and info2 has the -/// same relation. -fn merge_enum_hashcons<'db>( +/// Preserves construct entries (enum, struct) that exist in both branches. +/// Uses output-based lookup via `class_info` reverse_relationships. +fn merge_constructs<'db>( info1: &EqualityState<'db>, info2: &EqualityState<'db>, intersections: &OrderedHashMap>, result: &mut EqualityState<'db>, ) { - for (&(variant, input1), &output1) in info1.enum_hashcons.iter() { - for &(input_rep2, input_intersection) in intersections.get(&input1).unwrap_or(&vec![]) { - let output2 = info2.enum_hashcons.get(&(variant, input_rep2)).copied(); - let Some(output_intersection) = - find_intersection_rep_opt(info1, info2, intersections, Some(output1), output2) - else { - continue; - }; + for (&rep1, class1) in info1.class_info.iter() { + for &(rep2, intersection_output) in intersections.get(&rep1).unwrap_or(&vec![]) { + let Some(class2) = info2.class_info.get(&rep2) else { continue }; + + // EnumConstruct: both must have same variant and intersecting input. + if let ( + Some(Relation::EnumConstruct(variant1, input1)), + Some(Relation::EnumConstruct(variant2, input2)), + ) = ( + class1.reverse_relationship.get(&RelationKind::EnumConstruct), + class2.reverse_relationship.get(&RelationKind::EnumConstruct), + ) && variant1 == variant2 + && let Some(input_intersection) = find_intersection_rep( + intersections, + info1.find_immut(*input1), + info2.find_immut(*input2), + ) + { + result.set_enum_construct(*variant1, input_intersection, intersection_output); + } - result.set_enum_construct(variant, input_intersection, output_intersection); + // StructConstruct: all fields must have intersection reps. + if let ( + Some(Relation::StructConstruct(ty1, fields1)), + Some(Relation::StructConstruct(ty2, fields2)), + ) = ( + class1.reverse_relationship.get(&RelationKind::StructConstruct), + class2.reverse_relationship.get(&RelationKind::StructConstruct), + ) && ty1 == ty2 + && fields1.len() == fields2.len() + { + let result_fields: Option> = fields1 + .iter() + .zip(fields2.iter()) + .map(|(&v1, &v2)| { + find_intersection_rep( + intersections, + info1.find_immut(v1), + info2.find_immut(v2), + ) + }) + .collect(); + if let Some(result_fields) = result_fields { + result.set_struct_construct(*ty1, result_fields, intersection_output); + } + } } } } -/// Preserves struct hashcons entries that exist in both branches. -/// An entry survives if all inputs and output have intersection representatives, and info2 has the -/// same relation. -fn merge_struct_hashcons<'db>( - info1: &EqualityState<'db>, - info2: &EqualityState<'db>, - intersections: &OrderedHashMap>, - result: &mut EqualityState<'db>, -) { - for ((ty, inputs1), &output1) in info1.struct_hashcons.iter() { - let input_reps2: Vec<_> = inputs1.iter().map(|&v| info2.find_immut(v)).collect(); - - let Some(&output2) = info2.struct_hashcons.get(&(*ty, input_reps2.clone())) else { - continue; - }; - - let result_inputs: Option> = inputs1 - .iter() - .zip(&input_reps2) - .map(|(&v, &rep2)| { - find_intersection_rep(intersections, info1.find_immut(v), info2.find_immut(rep2)) - }) - .collect(); - let Some(result_inputs) = result_inputs else { - continue; - }; - - let Some(output_intersection) = find_intersection_rep( - intersections, - info1.find_immut(output1), - info2.find_immut(output2), - ) else { - continue; - }; - - result.set_struct_construct(*ty, result_inputs, output_intersection); - } -} - impl<'db, 'a> DataflowAnalyzer<'db, 'a> for EqualityAnalysis<'a, 'db> { type Info = EqualityState<'db>; @@ -663,9 +722,7 @@ impl<'db, 'a> DataflowAnalyzer<'db, 'a> for EqualityAnalysis<'a, 'db> { merge_class_relationships(&info1, &info2, &intersections, &mut result); - merge_enum_hashcons(&info1, &info2, &intersections, &mut result); - - merge_struct_hashcons(&info1, &info2, &intersections, &mut result); + merge_constructs(&info1, &info2, &intersections, &mut result); result } @@ -720,8 +777,7 @@ impl<'db, 'a> DataflowAnalyzer<'db, 'a> for EqualityAnalysis<'a, 'db> { Statement::StructDestructure(struct_stmt) => { // (outputs...) = struct_destructure(input) // 1. If input was previously constructed, union outputs with original fields. - let input_rep = info.find(struct_stmt.input.var_id); - if let Some((_, field_reps)) = info.struct_hashcons_rev.get(&input_rep).cloned() { + if let Some((_, field_reps)) = info.get_struct_construct(struct_stmt.input.var_id) { for (&output, &field_rep) in struct_stmt.outputs.iter().zip(field_reps.iter()) { info.union(output, field_rep); } @@ -737,13 +793,15 @@ impl<'db, 'a> DataflowAnalyzer<'db, 'a> for EqualityAnalysis<'a, 'db> { if id == self.array_new { let ty = self.lowered.variables[call_stmt.outputs[0]].ty; info.set_struct_construct(ty, vec![], call_stmt.outputs[0]); - } else if id == self.array_append { - let arr_rep = info.find(call_stmt.inputs[0].var_id); - if let Some((ty, elems)) = info.struct_hashcons_rev.get(&arr_rep).cloned() { - let mut new_elems = elems; - new_elems.push(info.find(call_stmt.inputs[1].var_id)); - info.set_struct_construct(ty, new_elems, call_stmt.outputs[0]); - } + } else if id == self.array_append + && let Some((ty, elems)) = + info.get_struct_construct(call_stmt.inputs[0].var_id) + { + // Only track append if the input array is already tracked. Arrays from + // function parameters or external calls are conservatively ignored. + let mut new_elems = elems; + new_elems.push(info.find(call_stmt.inputs[1].var_id)); + info.set_struct_construct(ty, new_elems, call_stmt.outputs[0]); } } @@ -773,7 +831,7 @@ impl<'db, 'a> DataflowAnalyzer<'db, 'a> for EqualityAnalysis<'a, 'db> { // happen after optimizations merge states from different branches. let output_rep = new_info.find(matched_var); if let Some((old_variant, input)) = - new_info.enum_hashcons_rev.get(&output_rep).copied() + new_info.get_enum_construct_immut(output_rep) && variant == old_variant { new_info.union(arm_var, input); diff --git a/crates/cairo-lang-lowering/src/analysis/test_data/equality b/crates/cairo-lang-lowering/src/analysis/test_data/equality index a86ee242890..228d739a296 100644 --- a/crates/cairo-lang-lowering/src/analysis/test_data/equality +++ b/crates/cairo-lang-lowering/src/analysis/test_data/equality @@ -1253,10 +1253,42 @@ extern fn use_snap_felt(x: @felt252) nopanic; //! > semantic_diagnostics //! > lowering -todo +Parameters: v0: core::felt252, v1: core::felt252 +blk0 (root): +Statements: + (v2: core::felt252, v3: @core::felt252) <- snapshot(v0) + () <- test::use_snap_felt(v3) + (v4: core::array::Array::) <- core::array::array_new::() + (v5: core::array::Array::) <- core::array::array_append::(v4, v0) + (v6: core::array::Array::) <- core::array::array_append::(v5, v1) + (v7: core::array::Array::, v8: @core::array::Array::) <- snapshot(v6) +End: + Match(match core::array::array_snapshot_pop_front::(v8) { + Option::Some(v9, v10) => blk1, + Option::None(v11) => blk2, + }) + +blk1: +Statements: + (v12: @core::felt252) <- unbox(v10) + () <- test::use_snap_felt(v12) +End: + Return() + +blk2: +Statements: +End: + Return() //! > analysis_state -todo +Block 0: +@v0 = v3, @v6 = v8, core::array::Array::() = v4, core::array::Array::(v0) = v5, core::array::Array::(v0, v1) = v6, v0 = v2, v6 = v7 + +Block 1: +@core::array::Array::(v1) = v9, @v0 = v3, @v6 = v8, Box(v3) = v10, core::array::Array::() = v4, core::array::Array::(v0) = v5, core::array::Array::(v0, v1) = v6, v0 = v2, v3 = v12, v6 = v7 + +Block 2: +@v0 = v3, @v6 = v8, core::array::Array::() = v4, core::array::Array::(v0) = v5, core::array::Array::(v0, v1) = v6, v0 = v2, v6 = v7, v8 = v11 //! > ========================================================================== @@ -1296,7 +1328,58 @@ extern fn use_snap_felt(x: @felt252) nopanic; //! > semantic_diagnostics //! > lowering -todo +Parameters: v0: core::felt252, v1: core::felt252 +blk0 (root): +Statements: + (v2: core::array::Array::) <- core::array::array_new::() + (v3: core::array::Array::) <- core::array::array_append::(v2, v0) + (v4: core::array::Array::) <- core::array::array_append::(v3, v1) + (v5: core::array::Array::, v6: @core::array::Array::) <- snapshot(v4) +End: + Match(match core::array::array_snapshot_pop_front::(v6) { + Option::Some(v7, v8) => blk1, + Option::None(v9) => blk4, + }) + +blk1: +Statements: + (v10: @core::felt252) <- unbox(v8) + () <- test::use_snap_felt(v10) +End: + Match(match core::array::array_snapshot_pop_front::(v7) { + Option::Some(v11, v12) => blk2, + Option::None(v13) => blk3, + }) + +blk2: +Statements: + (v14: @core::felt252) <- unbox(v12) + () <- test::use_snap_felt(v14) +End: + Return() + +blk3: +Statements: +End: + Return() + +blk4: +Statements: +End: + Return() //! > analysis_state -todo +Block 0: +@v4 = v6, core::array::Array::() = v2, core::array::Array::(v0) = v3, core::array::Array::(v0, v1) = v4, v4 = v5 + +Block 1: +@core::array::Array::(v1) = v7, @v4 = v6, Box(v10) = v8, core::array::Array::() = v2, core::array::Array::(v0) = v3, core::array::Array::(v0, v1) = v4, v4 = v5 + +Block 2: +@core::array::Array::() = v11, @core::array::Array::(v1) = v7, @v4 = v6, Box(v10) = v8, Box(v14) = v12, core::array::Array::() = v2, core::array::Array::(v0) = v3, core::array::Array::(v0, v1) = v4, v4 = v5 + +Block 3: +@core::array::Array::(v1) = v7, @v4 = v6, Box(v10) = v8, core::array::Array::() = v2, core::array::Array::(v0) = v3, core::array::Array::(v0, v1) = v4, v4 = v5, v7 = v13 + +Block 4: +@v4 = v6, core::array::Array::() = v2, core::array::Array::(v0) = v3, core::array::Array::(v0, v1) = v4, v4 = v5, v6 = v9