Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 33 additions & 7 deletions source/vir/src/recursion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ struct Ctxt<'a> {
num_decreases: Option<usize>,
scc_rep: Node,
ctx: &'a Ctx,
/// Types of the caller's decreases clauses (used to validate compatibility with callees)
caller_decreases_typs: Vec<Typ>,
}

// Get edges, skipping past SpanInfo
Expand Down Expand Up @@ -209,13 +211,33 @@ fn check_decrease_call(
);
decreases_exps.push(SpannedTyped::new(&span, &dec_exp.typ, e_decx));
}
check_decrease(
ctxt.ctx,
span,
None,
&decreases_exps,
ctxt.num_decreases.expect("num_decreases"),
)

// Validate that the caller's and callee's decreases types are compatible.
// All functions in a mutually recursive group must have decreases clauses
// whose types match at each position (both int or both non-int).
let num_decreases = ctxt.num_decreases.expect("num_decreases");
let num_to_check = std::cmp::min(num_decreases, decreases_exps.len());
for i in 0..num_to_check {
let caller_is_int = height_is_int(&ctxt.caller_decreases_typs[i]);
let callee_is_int = height_is_int(&decreases_exps[i].typ);
if caller_is_int != callee_is_int {
let caller_kind = if caller_is_int { "int" } else { "datatype" };
let callee_kind = if callee_is_int { "int" } else { "datatype" };
return Err(error(
span,
format!(
"in mutually recursive functions, decreases clauses at the same position \
must have compatible types: the caller's decreases clause #{} has {} type, \
but the callee's has {} type",
i + 1,
caller_kind,
callee_kind
),
));
}
}

check_decrease(ctxt.ctx, span, None, &decreases_exps, num_decreases)
}

pub(crate) fn fun_is_recursive(ctx: &Ctx, function: &Function) -> bool {
Expand Down Expand Up @@ -294,6 +316,7 @@ pub(crate) fn rewrite_spec_recursive_fun_with_fueled_rec_call(
num_decreases: None,
scc_rep: scc_rep.clone(),
ctx,
caller_decreases_typs: vec![], // Not used when num_decreases is None
};

// New body: substitute rec%f(args, fuel) for f(args)
Expand Down Expand Up @@ -359,11 +382,14 @@ fn check_termination<'a>(
expr_to_exp_skip_checks(ctx, diagnostics, &params_to_pars(&function.x.params, true), e)
})?;
let scc_rep = ctx.global.func_call_graph.get_scc_rep(&Node::Fun(function.x.name.clone()));
let caller_decreases_typs: Vec<Typ> =
decreases_exps.iter().map(|exp| height_typ(ctx, exp)).collect();
let ctxt = Ctxt {
recursive_function_name: function.x.name.clone(),
num_decreases: Some(num_decreases),
scc_rep,
ctx,
caller_decreases_typs,
};
let stm = map_stm_visitor(body, &mut |s| match &s.x {
StmX::Call { fun, resolved_method, args, dest, .. }
Expand Down