Skip to content

Commit cff20a6

Browse files
committed
wip: de-asyncfn-ify iter macro
1 parent d46b2b9 commit cff20a6

File tree

5 files changed

+69
-6
lines changed

5 files changed

+69
-6
lines changed

compiler/rustc_const_eval/src/check_consts/check.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ impl<'tcx> Visitor<'tcx> for Checker<'_, 'tcx> {
591591
if let AggregateKind::Coroutine(def_id, ..) = kind.as_ref()
592592
&& let Some(
593593
coroutine_kind @ hir::CoroutineKind::Desugared(
594-
hir::CoroutineDesugaring::Async,
594+
hir::CoroutineDesugaring::Async | hir::CoroutineDesugaring::Gen,
595595
_,
596596
),
597597
) = self.tcx.coroutine_kind(def_id)

compiler/rustc_hir_typeck/src/closure.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,12 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
465465

466466
if let Some(trait_def_id) = trait_def_id {
467467
let found_kind = match closure_kind {
468-
hir::ClosureKind::Closure => self.tcx.fn_trait_kind_from_def_id(trait_def_id),
468+
hir::ClosureKind::Closure
469+
// FIXME(iter_macro): Someday we'll probably want iterator closures instead of
470+
// just using Fn* for iterators.
471+
| hir::ClosureKind::CoroutineClosure(hir::CoroutineDesugaring::Gen) => {
472+
self.tcx.fn_trait_kind_from_def_id(trait_def_id)
473+
}
469474
hir::ClosureKind::CoroutineClosure(hir::CoroutineDesugaring::Async) => self
470475
.tcx
471476
.async_fn_trait_kind_from_def_id(trait_def_id)

compiler/rustc_mir_transform/src/check_inline.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use rustc_middle::mir::{Body, TerminatorKind};
88
use rustc_middle::ty;
99
use rustc_middle::ty::TyCtxt;
1010
use rustc_span::sym;
11+
use tracing::debug;
1112

1213
use crate::pass_manager::MirLint;
1314

@@ -40,7 +41,9 @@ pub(super) fn is_inline_valid_on_fn<'tcx>(
4041
tcx: TyCtxt<'tcx>,
4142
def_id: DefId,
4243
) -> Result<(), &'static str> {
44+
debug!("is_inline_valid_on_fn({def_id:?})");
4345
let codegen_attrs = tcx.codegen_fn_attrs(def_id);
46+
debug!(" codegen_attrs: {codegen_attrs:?}");
4447
if tcx.has_attr(def_id, sym::rustc_no_mir_inline) {
4548
return Err("#[rustc_no_mir_inline]");
4649
}

compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use std::ops::ControlFlow;
1111
use hir::LangItem;
1212
use hir::def_id::DefId;
1313
use rustc_data_structures::fx::{FxHashSet, FxIndexSet};
14-
use rustc_hir as hir;
14+
use rustc_hir::{self as hir, CoroutineDesugaring, CoroutineKind};
1515
use rustc_infer::traits::{Obligation, PolyTraitObligation, SelectionError};
1616
use rustc_middle::ty::fast_reject::DeepRejectCtxt;
1717
use rustc_middle::ty::{self, Ty, TypeVisitableExt, TypingMode};
@@ -128,11 +128,15 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
128128
self.assemble_async_iterator_candidates(obligation, &mut candidates);
129129
} else if tcx.is_lang_item(def_id, LangItem::AsyncFnKindHelper) {
130130
self.assemble_async_fn_kind_helper_candidates(obligation, &mut candidates);
131+
} else if tcx.is_lang_item(def_id, LangItem::AsyncFn)
132+
|| tcx.is_lang_item(def_id, LangItem::AsyncFnOnce)
133+
|| tcx.is_lang_item(def_id, LangItem::AsyncFnMut)
134+
{
135+
self.assemble_async_closure_candidates(obligation, &mut candidates);
131136
}
132137

133138
// FIXME: Put these into `else if` blocks above, since they're built-in.
134139
self.assemble_closure_candidates(obligation, &mut candidates);
135-
self.assemble_async_closure_candidates(obligation, &mut candidates);
136140
self.assemble_fn_pointer_candidates(obligation, &mut candidates);
137141

138142
self.assemble_candidates_from_impls(obligation, &mut candidates);
@@ -428,6 +432,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
428432
}
429433
}
430434

435+
#[instrument(level = "debug", skip(self, candidates))]
431436
fn assemble_async_closure_candidates(
432437
&mut self,
433438
obligation: &PolyTraitObligation<'tcx>,
@@ -439,15 +444,33 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
439444
return;
440445
};
441446

447+
debug!("self_ty = {:?}", obligation.self_ty().skip_binder().kind());
442448
match *obligation.self_ty().skip_binder().kind() {
443-
ty::CoroutineClosure(_, args) => {
449+
ty::CoroutineClosure(def_id, args) => {
444450
if let Some(closure_kind) =
445451
args.as_coroutine_closure().kind_ty().to_opt_closure_kind()
446452
&& !closure_kind.extends(goal_kind)
447453
{
448454
return;
449455
}
450-
candidates.vec.push(AsyncClosureCandidate);
456+
457+
// Make sure this is actually an async closure.
458+
let Some(coroutine_kind) =
459+
self.tcx().coroutine_kind(self.tcx().coroutine_for_closure(def_id))
460+
else {
461+
bug!("coroutine with no kind");
462+
};
463+
464+
debug!(?coroutine_kind);
465+
match coroutine_kind {
466+
CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
467+
candidates.vec.push(AsyncClosureCandidate);
468+
}
469+
// CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
470+
// candidates.vec.push(IteratorClosureCandidate)
471+
// }
472+
_ => (),
473+
}
451474
}
452475
// Closures and fn pointers implement `AsyncFn*` if their return types
453476
// implement `Future`, which is checked later.
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
//@ run-pass
2+
3+
#![feature(iter_macro, yield_expr)]
4+
5+
use std::iter::iter;
6+
7+
fn main() {
8+
let i = {
9+
let s = String::new();
10+
iter! { move || {
11+
yield s.len();
12+
for x in 5..10 {
13+
yield x * 2;
14+
}
15+
}}
16+
};
17+
test_iterator(i);
18+
}
19+
20+
/// Exercise the iterator in a separate function to ensure it's not capturing anything it shoudln't.
21+
fn test_iterator<I: Iterator<Item = usize>>(i: impl FnOnce() -> I) {
22+
let mut i = i();
23+
assert_eq!(i.next(), Some(0));
24+
assert_eq!(i.next(), Some(10));
25+
assert_eq!(i.next(), Some(12));
26+
assert_eq!(i.next(), Some(14));
27+
assert_eq!(i.next(), Some(16));
28+
assert_eq!(i.next(), Some(18));
29+
assert_eq!(i.next(), None);
30+
assert_eq!(i.next(), None);
31+
assert_eq!(i.next(), None);
32+
}

0 commit comments

Comments
 (0)