Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 147 additions & 52 deletions crates/cairo-lang-lowering/src/analysis/equality_analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,33 @@ use crate::{
BlockEnd, BlockId, Lowered, MatchArm, MatchExternInfo, MatchInfo, Statement, VariableId,
};

/// A struct field variable: either a real variable or a unique placeholder for an unknown field.
/// Placeholders are created during merge when a field has no intersection representative.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
enum FieldVar {
Var(VariableId),
/// A globally unique placeholder representing an unknown field.
Placeholder(usize),
}

impl FieldVar {
/// Returns the real variable if this is a `Var`, or `None` if it's a `Placeholder`.
fn as_var(self) -> Option<VariableId> {
match self {
FieldVar::Var(v) => Some(v),
FieldVar::Placeholder(_) => None,
}
}

/// Path-compresses the variable inside a `Var`, leaves `Placeholder` unchanged.
fn find_rep(self, info: &mut EqualityState<'_>) -> Self {
match self {
FieldVar::Var(v) => FieldVar::Var(info.find(v)),
p @ FieldVar::Placeholder(_) => p,
}
}
}

/// The kind of relationship between equivalence classes.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
enum RelationKind {
Expand All @@ -35,7 +62,7 @@ enum Relation<'db> {
Box(VariableId),
Snapshot(VariableId),
EnumConstruct(ConcreteVariant<'db>, VariableId),
StructConstruct(TypeId<'db>, Vec<VariableId>),
StructConstruct(TypeId<'db>, Vec<FieldVar>),
}

impl<'db> Relation<'db> {
Expand All @@ -56,13 +83,13 @@ impl<'db> Relation<'db> {
}
}

/// Returns an iterator over all variables referenced by this relation.
/// Returns an iterator over all real variables referenced by this relation.
fn referenced_vars(&self) -> impl Iterator<Item = VariableId> + '_ {
let fields: &[VariableId] = match self {
let fields: &[FieldVar] = match self {
Relation::StructConstruct(_, vs) => vs,
_ => &[],
};
self.single_var().into_iter().chain(fields.iter().copied())
self.single_var().into_iter().chain(fields.iter().filter_map(|f| f.as_var()))
}

/// Creates a new relation of the same kind with the single variable replaced.
Expand Down Expand Up @@ -138,7 +165,33 @@ impl<'db> ClassInfo<'db> {
{
*e.get_mut() = e.get().with_var(union_fn(self_var, other_var));
}
// StructConstruct: keep whichever side we already have (no field-level merge).
// For StructConstruct, merge element-wise: fill in placeholders.
if let (
Relation::StructConstruct(ty, self_fields),
Relation::StructConstruct(_, other_fields),
) = (e.get().clone(), &other_rel)
&& self_fields.len() == other_fields.len()
{
let merged: Vec<FieldVar> = self_fields
.iter()
.zip(other_fields.iter())
.map(|(sf, of)| match (sf, of) {
(FieldVar::Var(a), FieldVar::Var(b)) => {
if a == b {
FieldVar::Var(*a)
} else {
FieldVar::Var(union_fn(*a, *b))
}
}
(FieldVar::Var(a), FieldVar::Placeholder(_)) => FieldVar::Var(*a),
(FieldVar::Placeholder(_), FieldVar::Var(b)) => FieldVar::Var(*b),
(FieldVar::Placeholder(p), FieldVar::Placeholder(_)) => {
FieldVar::Placeholder(*p)
}
})
.collect();
*e.get_mut() = Relation::StructConstruct(ty, merged);
}
}
Entry::Vacant(e) => {
e.insert(other_rel);
Expand Down Expand Up @@ -316,21 +369,15 @@ impl<'db> EqualityState<'db> {
}

/// Looks up the struct construct info for a representative (immutable).
fn get_struct_construct_immut(
&self,
rep: VariableId,
) -> Option<(TypeId<'db>, Vec<VariableId>)> {
fn get_struct_construct_immut(&self, rep: VariableId) -> Option<(TypeId<'db>, Vec<FieldVar>)> {
match self.class_info.get(&rep)?.reverse_relationship.get(&RelationKind::StructConstruct)? {
Relation::StructConstruct(ty, fields) => Some((*ty, fields.clone())),
_ => unreachable!(),
}
}

/// Looks up the struct construct info for a variable (mutable, uses find for path compression).
fn get_struct_construct(
&mut self,
var: VariableId,
) -> Option<(TypeId<'db>, Vec<VariableId>)> {
fn get_struct_construct(&mut self, var: VariableId) -> Option<(TypeId<'db>, Vec<FieldVar>)> {
let rep = self.find(var);
self.get_struct_construct_immut(rep)
}
Expand All @@ -346,14 +393,15 @@ impl<'db> EqualityState<'db> {
}
}

/// Records a struct construct: output = StructType(inputs...).
/// Records a struct construct: output = StructType(fields...).
/// Fields may contain placeholders from merge operations.
fn set_struct_construct(
&mut self,
ty: TypeId<'db>,
input_reps: Vec<VariableId>,
fields: Vec<FieldVar>,
output: VariableId,
) {
self.set_construct(Relation::StructConstruct(ty, input_reps), output);
self.set_construct(Relation::StructConstruct(ty, fields), output);
}
}

Expand All @@ -379,7 +427,14 @@ impl<'db> DebugWithDb<'db> for EqualityState<'db> {
}
Relation::StructConstruct(ty, inputs) => {
let type_name = ty.format(db);
let fields = inputs.iter().map(|&id| v(id)).collect::<Vec<_>>().join(", ");
let fields = inputs
.iter()
.map(|f| match f {
FieldVar::Var(id) => v(*id),
FieldVar::Placeholder(_) => "?".to_string(),
})
.collect::<Vec<_>>()
.join(", ");
lines.push(format!("{type_name}({fields}) = {}", v(output)));
}
// Box/Snapshot never appear in hashcons — they use class_info relationships.
Expand All @@ -405,6 +460,8 @@ impl<'db> DebugWithDb<'db> for EqualityState<'db> {
pub struct EqualityAnalysis<'a, 'db> {
db: &'db dyn Database,
lowered: &'a Lowered<'db>,
/// Counter for allocating globally unique placeholder IDs.
next_placeholder: usize,
/// The `array_new` extern function id.
array_new: ExternFunctionId<'db>,
/// The `array_append` extern function id.
Expand All @@ -426,6 +483,7 @@ impl<'a, 'db> EqualityAnalysis<'a, 'db> {
Self {
db,
lowered,
next_placeholder: 0,
array_new: array_module.extern_function_id("array_new"),
array_append: array_module.extern_function_id("array_append"),
array_pop_front: array_module.extern_function_id("array_pop_front"),
Expand All @@ -435,6 +493,13 @@ impl<'a, 'db> EqualityAnalysis<'a, 'db> {
}
}

/// Allocates a fresh, globally unique placeholder ID.
fn alloc_placeholder(&mut self) -> FieldVar {
let id = self.next_placeholder;
self.next_placeholder += 1;
FieldVar::Placeholder(id)
}

/// Runs equality analysis on a lowered function.
/// Returns the equality state at the exit of each block.
pub fn analyze(
Expand Down Expand Up @@ -467,11 +532,14 @@ impl<'a, 'db> EqualityAnalysis<'a, 'db> {
[remaining_arr, boxed_elem] => {
let input_arr = extern_info.inputs[0].var_id;
if let Some((ty, elems)) = info.get_struct_construct(input_arr)
&& let Some((&first, rest)) = elems.split_first()
&& let Some((first, rest)) = elems.split_first()
{
info.set_box_relationship(first, boxed_elem);
let rest_reps: Vec<_> = rest.iter().map(|&v| info.find(v)).collect();
info.set_struct_construct(ty, rest_reps, remaining_arr);
if let FieldVar::Var(first_var) = first {
info.set_box_relationship(*first_var, boxed_elem);
}
let rest_fields: Vec<FieldVar> =
rest.iter().map(|f| f.find_rep(info)).collect();
info.set_struct_construct(ty, rest_fields, remaining_arr);
}
}
// None arm: union output with input.
Expand Down Expand Up @@ -500,31 +568,42 @@ impl<'a, 'db> EqualityAnalysis<'a, 'db> {
})
.or_else(|| info.get_struct_construct_immut(snap_rep));

if let Some((_orig_ty, elems)) = elems_opt {
let pop_front = id == self.array_snapshot_pop_front;
let (elem, rest) = if pop_front {
let Some((&first, tail)) = elems.split_first() else { return };
(first, tail.to_vec())
} else {
let Some((&last, init)) = elems.split_last() else { return };
(last, init.to_vec())
let snap_ty = self.lowered.variables[remaining_snap_arr].ty;
let Some((_orig_ty, elems)) = elems_opt else {
return;
};
let pop_front = id == self.array_snapshot_pop_front;
let (elem, rest) = if pop_front {
let Some((first, tail)) = elems.split_first() else {
// Empty array — record the empty remaining array and return.
info.set_struct_construct(snap_ty, vec![], remaining_snap_arr);
return;
};
(*first, tail.to_vec())
} else {
let Some((last, init)) = elems.split_last() else {
info.set_struct_construct(snap_ty, vec![], remaining_snap_arr);
return;
};
(*last, init.to_vec())
};

// The popped element is `Box<@T>`. Record the box relationship against
// the snapshot class of `elem` if it exists.
let elem_rep = info.find(elem);
// The popped element is `Box<@T>`. Record the box relationship against
// the snapshot class of `elem` if it exists.
if let FieldVar::Var(elem_var) = elem {
let elem_rep = info.find(elem_var);
if let Some(&snap_of_elem) = info
.class_info
.get(&elem_rep)
.and_then(|ci| ci.relationship.get(&RelationKind::Snapshot))
{
info.set_box_relationship(snap_of_elem, boxed_snap_elem);
}

let snap_ty = self.lowered.variables[remaining_snap_arr].ty;
let rest_reps: Vec<_> = rest.iter().map(|&v| info.find(v)).collect();
info.set_struct_construct(snap_ty, rest_reps, remaining_snap_arr);
}

let rest_fields: Vec<FieldVar> =
rest.iter().map(|f| f.find_rep(info)).collect();
info.set_struct_construct(snap_ty, rest_fields, remaining_snap_arr);
}
// None arm: union output with input.
[original_snap_arr] => {
Expand Down Expand Up @@ -611,12 +690,14 @@ fn find_intersection_rep_opt(
}

/// Preserves construct entries (enum, struct) that exist in both branches.
/// Uses output-based lookup via `class_info` reverse_relationships.
/// Uses output-based lookup via `class_info` reverse_relationships, which handles both
/// complete and partial (placeholder-containing) struct constructs uniformly.
fn merge_constructs<'db>(
info1: &EqualityState<'db>,
info2: &EqualityState<'db>,
intersections: &OrderedHashMap<VariableId, Vec<(VariableId, VariableId)>>,
result: &mut EqualityState<'db>,
alloc_placeholder: &mut impl FnMut() -> FieldVar,
) {
for (&rep1, class1) in info1.class_info.iter() {
for &(rep2, intersection_output) in intersections.get(&rep1).unwrap_or(&vec![]) {
Expand All @@ -639,7 +720,7 @@ fn merge_constructs<'db>(
result.set_enum_construct(*variant1, input_intersection, intersection_output);
}

// StructConstruct: all fields must have intersection reps.
// StructConstruct: field-by-field, allowing placeholders for unknown fields.
if let (
Some(Relation::StructConstruct(ty1, fields1)),
Some(Relation::StructConstruct(ty2, fields2)),
Expand All @@ -649,19 +730,27 @@ fn merge_constructs<'db>(
) && ty1 == ty2
&& fields1.len() == fields2.len()
{
let result_fields: Option<Vec<_>> = fields1
let result_fields: Vec<FieldVar> = fields1
.iter()
.zip(fields2.iter())
.map(|(&v1, &v2)| {
find_intersection_rep(
.map(|(f1, f2)| match (f1.as_var(), f2.as_var()) {
(Some(v1), Some(v2)) => find_intersection_rep(
intersections,
info1.find_immut(v1),
info2.find_immut(v2),
)
.map(FieldVar::Var)
.unwrap_or_else(&mut *alloc_placeholder),
_ => alloc_placeholder(),
})
.collect();
if let Some(result_fields) = result_fields {
result.set_struct_construct(*ty1, result_fields, intersection_output);
// Only store if at least one field is a real variable (or empty struct).
if result_fields.is_empty() || result_fields.iter().any(|f| f.as_var().is_some()) {
result.set_struct_construct(
*ty1,
result_fields,
intersection_output,
);
}
}
}
Expand Down Expand Up @@ -722,7 +811,8 @@ impl<'db, 'a> DataflowAnalyzer<'db, 'a> for EqualityAnalysis<'a, 'db> {

merge_class_relationships(&info1, &info2, &intersections, &mut result);

merge_constructs(&info1, &info2, &intersections, &mut result);
let mut alloc = || self.alloc_placeholder();
merge_constructs(&info1, &info2, &intersections, &mut result, &mut alloc);

result
}
Expand Down Expand Up @@ -770,22 +860,27 @@ impl<'db, 'a> DataflowAnalyzer<'db, 'a> for EqualityAnalysis<'a, 'db> {
// If we've already seen the same struct type with equivalent inputs, the outputs
// are equal.
let ty = self.lowered.variables[struct_stmt.output].ty;
let input_reps = struct_stmt.inputs.iter().map(|i| info.find(i.var_id)).collect();
info.set_struct_construct(ty, input_reps, struct_stmt.output);
let fields =
struct_stmt.inputs.iter().map(|i| FieldVar::Var(info.find(i.var_id))).collect();
info.set_struct_construct(ty, fields, struct_stmt.output);
}

Statement::StructDestructure(struct_stmt) => {
// (outputs...) = struct_destructure(input)
// 1. If input was previously constructed, union outputs with original fields.
// 1. If input was previously constructed, union outputs with original fields. Skip
// placeholder fields (unknown after merge).
if let Some((_, field_reps)) = info.get_struct_construct(struct_stmt.input.var_id) {
for (&output, &field_rep) in struct_stmt.outputs.iter().zip(field_reps.iter()) {
info.union(output, field_rep);
for (&output, field) in struct_stmt.outputs.iter().zip(field_reps.iter()) {
if let FieldVar::Var(field_rep) = field {
info.union(output, *field_rep);
}
}
}
// 2. Record: struct_construct(outputs) == input (for future constructs).
let ty = self.lowered.variables[struct_stmt.input.var_id].ty;
let output_reps = struct_stmt.outputs.iter().map(|&v| info.find(v)).collect();
info.set_struct_construct(ty, output_reps, struct_stmt.input.var_id);
let fields =
struct_stmt.outputs.iter().map(|&v| FieldVar::Var(info.find(v))).collect();
info.set_struct_construct(ty, fields, struct_stmt.input.var_id);
}

Statement::Call(call_stmt) => {
Expand All @@ -800,7 +895,7 @@ impl<'db, 'a> DataflowAnalyzer<'db, 'a> for EqualityAnalysis<'a, 'db> {
// Only track append if the input array is already tracked. Arrays from
// function parameters or external calls are conservatively ignored.
let mut new_elems = elems;
new_elems.push(info.find(call_stmt.inputs[1].var_id));
new_elems.push(FieldVar::Var(info.find(call_stmt.inputs[1].var_id)));
info.set_struct_construct(ty, new_elems, call_stmt.outputs[0]);
}
}
Expand Down
Loading
Loading