Skip to content

search graph: improve rebasing and add forced ambiguity support #143054

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion compiler/rustc_next_trait_solver/src/solve/search_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ where
) -> QueryResult<I> {
match kind {
PathKind::Coinductive => response_no_constraints(cx, input, Certainty::Yes),
PathKind::Unknown => response_no_constraints(cx, input, Certainty::overflow(false)),
PathKind::Unknown | PathKind::ForcedAmbiguity => {
response_no_constraints(cx, input, Certainty::overflow(false))
}
// Even though we know these cycles to be unproductive, we still return
// overflow during coherence. This is both as we are not 100% confident in
// the implementation yet and any incorrect errors would be unsound there.
Expand Down
182 changes: 93 additions & 89 deletions compiler/rustc_type_ir/src/search_graph/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,9 @@ use std::marker::PhantomData;
use derive_where::derive_where;
#[cfg(feature = "nightly")]
use rustc_macros::{Decodable_NoContext, Encodable_NoContext, HashStable_NoContext};
use rustc_type_ir::data_structures::HashMap;
use tracing::{debug, instrument};

use crate::data_structures::HashMap;

mod stack;
use stack::{Stack, StackDepth, StackEntry};
mod global_cache;
Expand Down Expand Up @@ -137,6 +136,12 @@ pub enum PathKind {
Unknown,
/// A path with at least one coinductive step. Such cycles hold.
Coinductive,
/// A path which is treated as ambiguous. Once a path has this path kind
/// any other segment does not change its kind.
///
/// This is currently only used when fuzzing to support negative reasoning.
/// For more details, see #143054.
ForcedAmbiguity,
}

impl PathKind {
Expand All @@ -149,6 +154,9 @@ impl PathKind {
/// to `max(self, rest)`.
fn extend(self, rest: PathKind) -> PathKind {
match (self, rest) {
(PathKind::ForcedAmbiguity, _) | (_, PathKind::ForcedAmbiguity) => {
PathKind::ForcedAmbiguity
}
(PathKind::Coinductive, _) | (_, PathKind::Coinductive) => PathKind::Coinductive,
(PathKind::Unknown, _) | (_, PathKind::Unknown) => PathKind::Unknown,
(PathKind::Inductive, PathKind::Inductive) => PathKind::Inductive,
Expand Down Expand Up @@ -187,41 +195,6 @@ impl UsageKind {
}
}

/// For each goal we track whether the paths from this goal
/// to its cycle heads are coinductive.
///
/// This is a necessary condition to rebase provisional cache
/// entries.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AllPathsToHeadCoinductive {
Yes,
No,
}
impl From<PathKind> for AllPathsToHeadCoinductive {
fn from(path: PathKind) -> AllPathsToHeadCoinductive {
match path {
PathKind::Coinductive => AllPathsToHeadCoinductive::Yes,
_ => AllPathsToHeadCoinductive::No,
}
}
}
impl AllPathsToHeadCoinductive {
#[must_use]
fn merge(self, other: impl Into<Self>) -> Self {
match (self, other.into()) {
(AllPathsToHeadCoinductive::Yes, AllPathsToHeadCoinductive::Yes) => {
AllPathsToHeadCoinductive::Yes
}
(AllPathsToHeadCoinductive::No, _) | (_, AllPathsToHeadCoinductive::No) => {
AllPathsToHeadCoinductive::No
}
}
}
fn and_merge(&mut self, other: impl Into<Self>) {
*self = self.merge(other);
}
}

#[derive(Debug, Clone, Copy)]
struct AvailableDepth(usize);
impl AvailableDepth {
Expand Down Expand Up @@ -261,9 +234,9 @@ impl AvailableDepth {
///
/// We also track all paths from this goal to that head. This is necessary
/// when rebasing provisional cache results.
#[derive(Clone, Debug, PartialEq, Eq, Default)]
#[derive(Clone, Debug, Default)]
struct CycleHeads {
heads: BTreeMap<StackDepth, AllPathsToHeadCoinductive>,
heads: BTreeMap<StackDepth, PathsToNested>,
}

impl CycleHeads {
Expand All @@ -283,27 +256,16 @@ impl CycleHeads {
self.heads.first_key_value().map(|(k, _)| *k)
}

fn remove_highest_cycle_head(&mut self) {
fn remove_highest_cycle_head(&mut self) -> PathsToNested {
let last = self.heads.pop_last();
debug_assert_ne!(last, None);
}

fn insert(
&mut self,
head: StackDepth,
path_from_entry: impl Into<AllPathsToHeadCoinductive> + Copy,
) {
self.heads.entry(head).or_insert(path_from_entry.into()).and_merge(path_from_entry);
last.unwrap().1
}

fn merge(&mut self, heads: &CycleHeads) {
for (&head, &path_from_entry) in heads.heads.iter() {
self.insert(head, path_from_entry);
debug_assert!(matches!(self.heads[&head], AllPathsToHeadCoinductive::Yes));
}
fn insert(&mut self, head: StackDepth, path_from_entry: impl Into<PathsToNested> + Copy) {
*self.heads.entry(head).or_insert(path_from_entry.into()) |= path_from_entry.into();
}

fn iter(&self) -> impl Iterator<Item = (StackDepth, AllPathsToHeadCoinductive)> + '_ {
fn iter(&self) -> impl Iterator<Item = (StackDepth, PathsToNested)> + '_ {
self.heads.iter().map(|(k, v)| (*k, *v))
}

Expand All @@ -317,13 +279,7 @@ impl CycleHeads {
Ordering::Equal => continue,
Ordering::Greater => unreachable!(),
}

let path_from_entry = match step_kind {
PathKind::Coinductive => AllPathsToHeadCoinductive::Yes,
PathKind::Unknown | PathKind::Inductive => path_from_entry,
};

self.insert(head, path_from_entry);
self.insert(head, path_from_entry.extend_with(step_kind));
}
}
}
Expand All @@ -332,13 +288,14 @@ bitflags::bitflags! {
/// Tracks how nested goals have been accessed. This is necessary to disable
/// global cache entries if computing them would otherwise result in a cycle or
/// access a provisional cache entry.
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PathsToNested: u8 {
/// The initial value when adding a goal to its own nested goals.
const EMPTY = 1 << 0;
const INDUCTIVE = 1 << 1;
const UNKNOWN = 1 << 2;
const COINDUCTIVE = 1 << 3;
const FORCED_AMBIGUITY = 1 << 4;
}
}
impl From<PathKind> for PathsToNested {
Expand All @@ -347,6 +304,7 @@ impl From<PathKind> for PathsToNested {
PathKind::Inductive => PathsToNested::INDUCTIVE,
PathKind::Unknown => PathsToNested::UNKNOWN,
PathKind::Coinductive => PathsToNested::COINDUCTIVE,
PathKind::ForcedAmbiguity => PathsToNested::FORCED_AMBIGUITY,
}
}
}
Expand Down Expand Up @@ -379,10 +337,45 @@ impl PathsToNested {
self.insert(PathsToNested::COINDUCTIVE);
}
}
PathKind::ForcedAmbiguity => {
if self.intersects(
PathsToNested::EMPTY
| PathsToNested::INDUCTIVE
| PathsToNested::UNKNOWN
| PathsToNested::COINDUCTIVE,
) {
self.remove(
PathsToNested::EMPTY
| PathsToNested::INDUCTIVE
| PathsToNested::UNKNOWN
| PathsToNested::COINDUCTIVE,
);
self.insert(PathsToNested::FORCED_AMBIGUITY);
}
}
}

self
}

#[must_use]
fn extend_with_paths(self, path: PathsToNested) -> Self {
let mut new = PathsToNested::empty();
for p in path.iter_paths() {
new |= self.extend_with(p);
}
new
}

fn iter_paths(self) -> impl Iterator<Item = PathKind> {
let (PathKind::Inductive
| PathKind::Unknown
| PathKind::Coinductive
| PathKind::ForcedAmbiguity);
[PathKind::Inductive, PathKind::Unknown, PathKind::Coinductive, PathKind::ForcedAmbiguity]
.into_iter()
.filter(move |&p| self.contains(p.into()))
}
}

/// The nested goals of each stack entry and the path from the
Expand Down Expand Up @@ -693,7 +686,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
if let Some((_scope, expected)) = validate_cache {
// Do not try to move a goal into the cache again if we're testing
// the global cache.
assert_eq!(evaluation_result.result, expected, "input={input:?}");
assert_eq!(expected, evaluation_result.result, "input={input:?}");
} else if D::inspect_is_noop(inspect) {
self.insert_global_cache(cx, input, evaluation_result, dep_node)
}
Expand Down Expand Up @@ -763,14 +756,11 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
/// provisional cache entry is still applicable. We need to keep the cache entries to
/// prevent hangs.
///
/// What we therefore do is check whether the cycle kind of all cycles the goal of a
/// provisional cache entry is involved in would stay the same when computing the
/// goal without its cycle head on the stack. For more details, see the relevant
/// This can be thought of as pretending to reevaluate the popped head as nested goals
/// of this provisional result. For this to be correct, all cycles encountered while
/// we'd reevaluate the cycle head as a nested goal must keep the same cycle kind.
/// [rustc-dev-guide chapter](https://rustc-dev-guide.rust-lang.org/solve/caching.html).
///
/// This can be thought of rotating the sub-tree of this provisional result and changing
/// its entry point while making sure that all paths through this sub-tree stay the same.
///
/// In case the popped cycle head failed to reach a fixpoint anything which depends on
/// its provisional result is invalid. Actually discarding provisional cache entries in
/// this case would cause hangs, so we instead change the result of dependant provisional
Expand All @@ -782,7 +772,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
stack_entry: &StackEntry<X>,
mut mutate_result: impl FnMut(X::Input, X::Result) -> X::Result,
) {
let head = self.stack.next_index();
let popped_head = self.stack.next_index();
#[allow(rustc::potential_query_instability)]
self.provisional_cache.retain(|&input, entries| {
entries.retain_mut(|entry| {
Expand All @@ -792,30 +782,44 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
path_from_head,
result,
} = entry;
if heads.highest_cycle_head() == head {
let ep = if heads.highest_cycle_head() == popped_head {
heads.remove_highest_cycle_head()
} else {
return true;
}

// We only try to rebase if all paths from the cache entry
// to its heads are coinductive. In this case these cycle
// kinds won't change, no matter the goals between these
// heads and the provisional cache entry.
if heads.iter().any(|(_, p)| matches!(p, AllPathsToHeadCoinductive::No)) {
return false;
}
};

// The same for nested goals of the cycle head.
if stack_entry.heads.iter().any(|(_, p)| matches!(p, AllPathsToHeadCoinductive::No))
{
return false;
// We're rebasing an entry `e` over a head `p`. This head
// has a number of own heads `h` it depends on. We need to
// make sure that the path kind of all paths `hph` remain the
// same after rebasing.
//
// After rebasing the cycles `hph` will go through `e`. We need to make
// sure that forall possible paths `hep`, `heph` is equal to `hph.`
for (h, ph) in stack_entry.heads.iter() {
let hp =
Self::cycle_path_kind(&self.stack, stack_entry.step_kind_from_parent, h);

// We first validate that all cycles while computing `p` would stay
// the same if we were to recompute it as a nested goal of `e`.
let he = hp.extend(*path_from_head);
for ph in ph.iter_paths() {
let hph = hp.extend(ph);
for ep in ep.iter_paths() {
let hep = ep.extend(he);
let heph = hep.extend(ph);
if hph != heph {
return false;
}
}
}

// If so, all paths reached while computing `p` have to get added
// the heads of `e` to make sure that rebasing `e` again also considers
// them.
let eph = ep.extend_with_paths(ph);
heads.insert(h, eph);
}

// Merge the cycle heads of the provisional cache entry and the
// popped head. If the popped cycle head was a root, discard all
// provisional cache entries which depend on it.
heads.merge(&stack_entry.heads);
let Some(head) = heads.opt_highest_cycle_head() else {
return false;
};
Expand Down
5 changes: 4 additions & 1 deletion compiler/rustc_type_ir/src/search_graph/stack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::ops::{Index, IndexMut};
use derive_where::derive_where;
use rustc_index::IndexVec;

use super::{AvailableDepth, Cx, CycleHeads, NestedGoals, PathKind, UsageKind};
use crate::search_graph::{AvailableDepth, Cx, CycleHeads, NestedGoals, PathKind, UsageKind};

rustc_index::newtype_index! {
#[orderable]
Expand Down Expand Up @@ -79,6 +79,9 @@ impl<X: Cx> Stack<X> {
}

pub(super) fn push(&mut self, entry: StackEntry<X>) -> StackDepth {
if cfg!(debug_assertions) && self.entries.iter().any(|e| e.input == entry.input) {
panic!("pushing duplicate entry on stack: {entry:?} {:?}", self.entries);
}
self.entries.push(entry)
}

Expand Down
Loading