Skip to content

Commit d2bcd64

Browse files
authored
Merge pull request #821 from WaffleLapkin/trait_upcast
Implement trait upcasting
2 parents c83151f + 5689335 commit d2bcd64

File tree

3 files changed

+350
-55
lines changed

3 files changed

+350
-55
lines changed

chalk-solve/src/clauses/builtin_traits/unsize.rs

Lines changed: 181 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::collections::HashSet;
22
use std::iter;
33
use std::ops::ControlFlow;
44

5+
use crate::clauses::super_traits::super_traits;
56
use crate::clauses::ClauseBuilder;
67
use crate::rust_ir::AdtKind;
78
use crate::{Interner, RustIrDatabase, TraitRef, WellKnownTrait};
@@ -136,17 +137,27 @@ fn uses_outer_binder_params<I: Interner>(
136137
matches!(flow, ControlFlow::Break(_))
137138
}
138139

139-
fn principal_id<I: Interner>(
140+
fn principal_trait_ref<I: Interner>(
140141
db: &dyn RustIrDatabase<I>,
141142
bounds: &Binders<QuantifiedWhereClauses<I>>,
142-
) -> Option<TraitId<I>> {
143-
let interner = db.interner();
144-
143+
) -> Option<Binders<Binders<TraitRef<I>>>> {
145144
bounds
146-
.skip_binders()
147-
.iter(interner)
148-
.filter_map(|b| b.trait_id())
149-
.find(|&id| !db.trait_datum(id).is_auto_trait())
145+
.map_ref(|b| b.iter(db.interner()))
146+
.into_iter()
147+
.find_map(|b| {
148+
b.filter_map(|qwc| {
149+
qwc.as_ref().filter_map(|wc| match wc {
150+
WhereClause::Implemented(trait_ref) => {
151+
if db.trait_datum(trait_ref.trait_id).is_auto_trait() {
152+
None
153+
} else {
154+
Some(trait_ref.clone())
155+
}
156+
}
157+
_ => None,
158+
})
159+
})
160+
})
150161
}
151162

152163
fn auto_trait_ids<'a, I: Interner>(
@@ -187,10 +198,10 @@ pub fn add_unsize_program_clauses<I: Interner>(
187198
// could be lifted.
188199
//
189200
// for more info visit `fn assemble_candidates_for_unsizing` and
190-
// `fn confirm_builtin_unisize_candidate` in rustc.
201+
// `fn confirm_builtin_unsize_candidate` in rustc.
191202

192203
match (source_ty.kind(interner), target_ty.kind(interner)) {
193-
// dyn Trait + AutoX + 'a -> dyn Trait + AutoY + 'b
204+
// dyn TraitA + AutoA + 'a -> dyn TraitB + AutoB + 'b
194205
(
195206
TyKind::Dyn(DynTy {
196207
bounds: bounds_a,
@@ -201,13 +212,33 @@ pub fn add_unsize_program_clauses<I: Interner>(
201212
lifetime: lifetime_b,
202213
}),
203214
) => {
204-
let principal_a = principal_id(db, bounds_a);
205-
let principal_b = principal_id(db, bounds_b);
215+
let principal_trait_ref_a = principal_trait_ref(db, bounds_a);
216+
let principal_a = principal_trait_ref_a
217+
.as_ref()
218+
.map(|trait_ref| trait_ref.skip_binders().skip_binders().trait_id);
219+
let principal_b = principal_trait_ref(db, bounds_b)
220+
.map(|trait_ref| trait_ref.skip_binders().skip_binders().trait_id);
221+
222+
// Include super traits in a list of auto traits for A,
223+
// to allow `dyn Trait -> dyn Trait + X` if `Trait: X`.
224+
let auto_trait_ids_a: Vec<_> = auto_trait_ids(db, bounds_a)
225+
.chain(principal_a.into_iter().flat_map(|principal_a| {
226+
super_traits(db, principal_a)
227+
.into_value_and_skipped_binders()
228+
.0
229+
.0
230+
.into_iter()
231+
.map(|x| x.skip_binders().trait_id)
232+
.filter(|&x| db.trait_datum(x).is_auto_trait())
233+
}))
234+
.collect();
206235

207-
let auto_trait_ids_a: Vec<_> = auto_trait_ids(db, bounds_a).collect();
208236
let auto_trait_ids_b: Vec<_> = auto_trait_ids(db, bounds_b).collect();
209237

210-
let may_apply = principal_a == principal_b
238+
// If B has a principal, then A must as well
239+
// (i.e. we allow dropping principal, but not creating a principal out of thin air).
240+
// `AutoB` must be a subset of `AutoA`.
241+
let may_apply = principal_a.is_some() >= principal_b.is_some()
211242
&& auto_trait_ids_b
212243
.iter()
213244
.all(|id_b| auto_trait_ids_a.iter().any(|id_a| id_a == id_b));
@@ -216,6 +247,13 @@ pub fn add_unsize_program_clauses<I: Interner>(
216247
return;
217248
}
218249

250+
// Check that source lifetime outlives target lifetime
251+
let lifetime_outlives_goal: Goal<I> = WhereClause::LifetimeOutlives(LifetimeOutlives {
252+
a: lifetime_a.clone(),
253+
b: lifetime_b.clone(),
254+
})
255+
.cast(interner);
256+
219257
// COMMENT FROM RUSTC:
220258
// ------------------
221259
// Require that the traits involved in this upcast are **equal**;
@@ -233,48 +271,138 @@ pub fn add_unsize_program_clauses<I: Interner>(
233271
// with what our behavior should be there. -nikomatsakis
234272
// ------------------
235273

236-
// Construct a new trait object type by taking the source ty,
237-
// filtering out auto traits of source that are not present in target
238-
// and changing source lifetime to target lifetime.
239-
//
240-
// In order for the coercion to be valid, this new type
241-
// should be equal to target type.
242-
let new_source_ty = TyKind::Dyn(DynTy {
243-
bounds: bounds_a.map_ref(|bounds| {
244-
QuantifiedWhereClauses::from_iter(
245-
interner,
246-
bounds.iter(interner).filter(|bound| {
247-
let trait_id = match bound.trait_id() {
248-
Some(id) => id,
249-
None => return true,
250-
};
251-
252-
if auto_trait_ids_a.iter().all(|&id_a| id_a != trait_id) {
253-
return true;
254-
}
255-
auto_trait_ids_b.iter().any(|&id_b| id_b == trait_id)
274+
if principal_a == principal_b || principal_b.is_none() {
275+
// Construct a new trait object type by taking the source ty,
276+
// replacing auto traits of source with those of target,
277+
// and changing source lifetime to target lifetime.
278+
//
279+
// In order for the coercion to be valid, this new type
280+
// should be equal to target type.
281+
let new_source_ty = TyKind::Dyn(DynTy {
282+
bounds: bounds_a.map_ref(|bounds| {
283+
QuantifiedWhereClauses::from_iter(
284+
interner,
285+
bounds
286+
.iter(interner)
287+
.cloned()
288+
.filter_map(|bound| {
289+
let Some(trait_id) = bound.trait_id() else {
290+
// Keep non-"implements" bounds as-is
291+
return Some(bound);
292+
};
293+
294+
// Auto traits are already checked above, ignore them
295+
// (we'll use the ones from B below)
296+
if db.trait_datum(trait_id).is_auto_trait() {
297+
return None;
298+
}
299+
300+
// The only "implements" bound that is not an auto trait, is the principal
301+
assert_eq!(Some(trait_id), principal_a);
302+
303+
// Only include principal_a if the principal_b is also present
304+
// (this allows dropping principal, `dyn Tr+A -> dyn A`)
305+
principal_b.is_some().then(|| bound)
306+
})
307+
// Add auto traits from B (again, they are already checked above).
308+
.chain(bounds_b.skip_binders().iter(interner).cloned().filter(
309+
|bound| {
310+
bound.trait_id().is_some_and(|trait_id| {
311+
db.trait_datum(trait_id).is_auto_trait()
312+
})
313+
},
314+
)),
315+
)
316+
}),
317+
lifetime: lifetime_b.clone(),
318+
})
319+
.intern(interner);
320+
321+
// Check that new source is equal to target
322+
let eq_goal = EqGoal {
323+
a: new_source_ty.cast(interner),
324+
b: target_ty.clone().cast(interner),
325+
}
326+
.cast(interner);
327+
328+
builder.push_clause(trait_ref, [eq_goal, lifetime_outlives_goal].iter());
329+
} else {
330+
// Conditions above imply that both of these are always `Some`
331+
// (b != None, b is Some iff a is Some).
332+
let principal_a = principal_a.unwrap();
333+
let principal_b = principal_b.unwrap();
334+
335+
let principal_trait_ref_a = principal_trait_ref_a.unwrap();
336+
let applicable_super_traits = super_traits(db, principal_a)
337+
.map(|(super_trait_refs, _)| super_trait_refs)
338+
.into_iter()
339+
.filter(|trait_ref| {
340+
trait_ref.skip_binders().skip_binders().trait_id == principal_b
341+
});
342+
343+
for super_trait_ref in applicable_super_traits {
344+
// `super_trait_ref` is, at this point, quantified over generic params of
345+
// `principal_a` and relevant higher-ranked lifetimes that come from super
346+
// trait elaboration (see comments on `super_traits()`).
347+
//
348+
// So if we have `trait Trait<'a, T>: for<'b> Super<'a, 'b, T> {}`,
349+
// `super_trait_ref` can be something like
350+
// `for<Self, 'a, T> for<'b> Self: Super<'a, 'b, T>`.
351+
//
352+
// We need to convert it into a bound for `DynTy`. We do this by substituting
353+
// bound vars of `principal_trait_ref_a` and then fusing inner binders for
354+
// higher-ranked lifetimes.
355+
let rebound_super_trait_ref = principal_trait_ref_a.map_ref(|q_trait_ref_a| {
356+
q_trait_ref_a
357+
.map_ref(|trait_ref_a| {
358+
super_trait_ref.substitute(interner, &trait_ref_a.substitution)
359+
})
360+
.fuse_binders(interner)
361+
});
362+
363+
// Skip `for<Self>` binder. We'll rebind it immediately below.
364+
let new_principal_trait_ref = rebound_super_trait_ref
365+
.into_value_and_skipped_binders()
366+
.0
367+
.map(|it| it.cast(interner));
368+
369+
// Swap trait ref for `principal_a` with the new trait ref, drop the auto
370+
// traits not included in the upcast target.
371+
let new_source_ty = TyKind::Dyn(DynTy {
372+
bounds: bounds_a.map_ref(|bounds| {
373+
QuantifiedWhereClauses::from_iter(
374+
interner,
375+
bounds.iter(interner).cloned().filter_map(|bound| {
376+
let trait_id = match bound.trait_id() {
377+
Some(id) => id,
378+
None => return Some(bound),
379+
};
380+
381+
if principal_a == trait_id {
382+
Some(new_principal_trait_ref.clone())
383+
} else {
384+
auto_trait_ids_b.contains(&trait_id).then_some(bound)
385+
}
386+
}),
387+
)
256388
}),
257-
)
258-
}),
259-
lifetime: lifetime_b.clone(),
260-
})
261-
.intern(interner);
389+
lifetime: lifetime_b.clone(),
390+
})
391+
.intern(interner);
392+
393+
// Check that new source is equal to target
394+
let eq_goal = EqGoal {
395+
a: new_source_ty.cast(interner),
396+
b: target_ty.clone().cast(interner),
397+
}
398+
.cast(interner);
262399

263-
// Check that new source is equal to target
264-
let eq_goal = EqGoal {
265-
a: new_source_ty.cast(interner),
266-
b: target_ty.clone().cast(interner),
400+
// We don't push goal for `principal_b`'s object safety because it's implied by
401+
// `principal_a`'s object safety.
402+
builder
403+
.push_clause(trait_ref.clone(), [eq_goal, lifetime_outlives_goal.clone()]);
404+
}
267405
}
268-
.cast(interner);
269-
270-
// Check that source lifetime outlives target lifetime
271-
let lifetime_outlives_goal: Goal<I> = WhereClause::LifetimeOutlives(LifetimeOutlives {
272-
a: lifetime_a.clone(),
273-
b: lifetime_b.clone(),
274-
})
275-
.cast(interner);
276-
277-
builder.push_clause(trait_ref, [eq_goal, lifetime_outlives_goal].iter());
278406
}
279407

280408
// T -> dyn Trait + 'a

chalk-solve/src/clauses/super_traits.rs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,28 @@ pub(super) fn push_trait_super_clauses<I: Interner>(
7373
}
7474
}
7575

76-
fn super_traits<I: Interner>(
76+
/// Returns super-`TraitRef`s and super-`Projection`s that are quantified over the parameters of
77+
/// `trait_id` and relevant higher-ranked lifetimes. The outer `Binders` is for the former and the
78+
/// inner `Binders` is for the latter.
79+
///
80+
/// For example, given the following trait definitions and `C` as `trait_id`,
81+
///
82+
/// ```
83+
/// trait A<'a, T> {}
84+
/// trait B<'b, U> where Self: for<'x> A<'x, U> {}
85+
/// trait C<'c, V> where Self: B<'c, V> {}
86+
/// ```
87+
///
88+
/// returns the following quantified `TraitRef`s.
89+
///
90+
/// ```notrust
91+
/// for<Self, 'c, V> {
92+
/// for<'x> { Self: A<'x, V> }
93+
/// for<> { Self: B<'c, V> }
94+
/// for<> { Self: C<'c, V> }
95+
/// }
96+
/// ```
97+
pub(crate) fn super_traits<I: Interner>(
7798
db: &dyn RustIrDatabase<I>,
7899
trait_id: TraitId<I>,
79100
) -> Binders<(

0 commit comments

Comments
 (0)