Skip to content

Commit 7977643

Browse files
authored
Merge pull request swiftlang#28984 from DougGregor/de-de-virtualize-for-each
[Type checker] Stop devirtualizing the reference to IteratorProtocol.next()
2 parents c513e61 + 97b5a0d commit 7977643

File tree

12 files changed

+89
-92
lines changed

12 files changed

+89
-92
lines changed

include/swift/AST/Stmt.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -808,8 +808,6 @@ class ForEachStmt : public LabeledStmt {
808808

809809
// Set by Sema:
810810
ProtocolConformanceRef sequenceConformance = ProtocolConformanceRef();
811-
ConcreteDeclRef makeIterator;
812-
ConcreteDeclRef iteratorNext;
813811
VarDecl *iteratorVar = nullptr;
814812
Expr *iteratorVarRef = nullptr;
815813
OpaqueValueExpr *elementExpr = nullptr;
@@ -838,12 +836,6 @@ class ForEachStmt : public LabeledStmt {
838836
void setConvertElementExpr(Expr *expr) { convertElementExpr = expr; }
839837
Expr *getConvertElementExpr() const { return convertElementExpr; }
840838

841-
void setMakeIterator(ConcreteDeclRef declRef) { makeIterator = declRef; }
842-
ConcreteDeclRef getMakeIterator() const { return makeIterator; }
843-
844-
void setIteratorNext(ConcreteDeclRef declRef) { iteratorNext = declRef; }
845-
ConcreteDeclRef getIteratorNext() const { return iteratorNext; }
846-
847839
void setSequenceConformance(ProtocolConformanceRef conformance) {
848840
sequenceConformance = conformance;
849841
}

lib/AST/ASTDumper.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1593,12 +1593,6 @@ class PrintStmt : public StmtVisitor<PrintStmt> {
15931593
}
15941594
void visitForEachStmt(ForEachStmt *S) {
15951595
printCommon(S, "for_each_stmt");
1596-
PrintWithColorRAII(OS, LiteralValueColor) << " make_generator=";
1597-
S->getMakeIterator().dump(
1598-
PrintWithColorRAII(OS, LiteralValueColor).getOS());
1599-
PrintWithColorRAII(OS, LiteralValueColor) << " next=";
1600-
S->getIteratorNext().dump(
1601-
PrintWithColorRAII(OS, LiteralValueColor).getOS());
16021596
OS << '\n';
16031597
printRec(S->getPattern());
16041598
OS << '\n';

lib/SILGen/SILGenStmt.cpp

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -903,14 +903,28 @@ void StmtEmitter::visitRepeatWhileStmt(RepeatWhileStmt *S) {
903903
}
904904

905905
void StmtEmitter::visitForEachStmt(ForEachStmt *S) {
906+
// Dig out information about the sequence conformance.
907+
auto sequenceConformance = S->getSequenceConformance();
908+
Type sequenceType = S->getSequence()->getType();
909+
auto sequenceProto =
910+
SGF.getASTContext().getProtocol(KnownProtocolKind::Sequence);
911+
auto sequenceSubs = SubstitutionMap::getProtocolSubstitutions(
912+
sequenceProto, sequenceType, sequenceConformance);
913+
906914
// Emit the 'iterator' variable that we'll be using for iteration.
907915
LexicalScope OuterForScope(SGF, CleanupLocation(S));
908916
{
909917
auto initialization =
910918
SGF.emitInitializationForVarDecl(S->getIteratorVar(), false);
911919
SILLocation loc = SILLocation(S->getSequence());
920+
921+
// Compute the reference to the Sequence's makeIterator().
922+
FuncDecl *makeIteratorReq = SGF.getASTContext().getSequenceMakeIterator();
923+
ConcreteDeclRef makeIteratorRef(makeIteratorReq, sequenceSubs);
924+
925+
// Call makeIterator().
912926
RValue result = SGF.emitApplyMethod(
913-
loc, S->getMakeIterator(), ArgumentSource(S->getSequence()),
927+
loc, makeIteratorRef, ArgumentSource(S->getSequence()),
914928
PreparedArguments(ArrayRef<AnyFunctionType::Param>({})),
915929
SGFContext(initialization.get()));
916930
if (!result.isInContext()) {
@@ -952,8 +966,26 @@ void StmtEmitter::visitForEachStmt(ForEachStmt *S) {
952966
JumpDest endDest = createJumpDest(S->getBody());
953967
SGF.BreakContinueDestStack.push_back({ S, endDest, loopDest });
954968

969+
// Compute the reference to the the iterator's next().
970+
auto iteratorProto =
971+
SGF.getASTContext().getProtocol(KnownProtocolKind::IteratorProtocol);
972+
ValueDecl *iteratorNextReq = iteratorProto->getSingleRequirement(
973+
DeclName(SGF.getASTContext(), SGF.getASTContext().Id_next,
974+
ArrayRef<Identifier>()));
975+
auto iteratorAssocType =
976+
sequenceProto->getAssociatedType(SGF.getASTContext().Id_Iterator);
977+
auto iteratorMemberRef = DependentMemberType::get(
978+
sequenceProto->getSelfInterfaceType(), iteratorAssocType);
979+
auto iteratorType = sequenceConformance.getAssociatedType(
980+
sequenceType, iteratorMemberRef);
981+
auto iteratorConformance = sequenceConformance.getAssociatedConformance(
982+
sequenceType, iteratorMemberRef, iteratorProto);
983+
auto iteratorSubs = SubstitutionMap::getProtocolSubstitutions(
984+
iteratorProto, iteratorType, iteratorConformance);
985+
ConcreteDeclRef iteratorNextRef(iteratorNextReq, iteratorSubs);
986+
955987
auto buildArgumentSource = [&]() {
956-
if (cast<FuncDecl>(S->getIteratorNext().getDecl())->getSelfAccessKind() ==
988+
if (cast<FuncDecl>(iteratorNextRef.getDecl())->getSelfAccessKind() ==
957989
SelfAccessKind::Mutating) {
958990
LValue lv =
959991
SGF.emitLValue(S->getIteratorVarRef(), SGFAccessKind::ReadWrite);
@@ -969,7 +1001,7 @@ void StmtEmitter::visitForEachStmt(ForEachStmt *S) {
9691001
auto buildElementRValue = [&](SILLocation loc, SGFContext ctx) {
9701002
RValue result;
9711003
result = SGF.emitApplyMethod(
972-
loc, S->getIteratorNext(), buildArgumentSource(),
1004+
loc, iteratorNextRef, buildArgumentSource(),
9731005
PreparedArguments(ArrayRef<AnyFunctionType::Param>({})),
9741006
S->getElementExpr() ? SGFContext() : ctx);
9751007
if (S->getElementExpr()) {

lib/Sema/CSApply.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ Solution::computeSubstitutions(GenericSignature sig,
8787
return ProtocolConformanceRef(protoType);
8888
}
8989

90+
// FIXME: Retrieve the conformance from the solution itself.
9091
return TypeChecker::conformsToProtocol(replacement, protoType,
9192
getConstraintSystem().DC,
9293
ConformanceCheckFlags::InExpression);
@@ -7345,6 +7346,37 @@ class SetExprTypes : public ASTWalker {
73457346
};
73467347
}
73477348

7349+
ProtocolConformanceRef Solution::resolveConformance(
7350+
ConstraintLocator *locator, ProtocolDecl *proto) {
7351+
for (const auto &conformance : Conformances) {
7352+
if (conformance.first != locator)
7353+
continue;
7354+
if (conformance.second.getRequirement() != proto)
7355+
continue;
7356+
7357+
// If the conformance doesn't require substitution, return it immediately.
7358+
auto conformanceRef = conformance.second;
7359+
if (conformanceRef.isAbstract())
7360+
return conformanceRef;
7361+
7362+
auto concrete = conformanceRef.getConcrete();
7363+
auto conformingType = concrete->getType();
7364+
if (!conformingType->hasTypeVariable())
7365+
return conformanceRef;
7366+
7367+
// Substitute into the conformance type, then look for a conformance
7368+
// again.
7369+
// FIXME: Should be able to perform the substitution using the Solution
7370+
// itself rather than another conforms-to-protocol check.
7371+
Type substConformingType = simplifyType(conformingType);
7372+
return TypeChecker::conformsToProtocol(
7373+
substConformingType, proto, constraintSystem->DC,
7374+
ConformanceCheckFlags::InExpression);
7375+
}
7376+
7377+
return ProtocolConformanceRef::forInvalid();
7378+
}
7379+
73487380
void Solution::setExprTypes(Expr *expr) const {
73497381
if (!expr)
73507382
return;

lib/Sema/ConstraintSystem.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -898,6 +898,11 @@ class Solution {
898898
return None;
899899
}
900900

901+
/// Retrieve a fully-resolved protocol conformance at the given locator
902+
/// and with the given protocol.
903+
ProtocolConformanceRef resolveConformance(ConstraintLocator *locator,
904+
ProtocolDecl *proto);
905+
901906
void setExprTypes(Expr *expr) const;
902907

903908
SWIFT_DEBUG_DUMP;

lib/Sema/TypeCheckConstraints.cpp

Lines changed: 6 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2926,12 +2926,6 @@ auto TypeChecker::typeCheckForEachBinding(
29262926
/// The type of the iterator.
29272927
Type IteratorType;
29282928

2929-
/// The conformance of the iterator type to IteratorProtocol.
2930-
ProtocolConformanceRef IteratorConformance;
2931-
2932-
/// The type of makeIterator.
2933-
Type MakeIteratorType;
2934-
29352929
public:
29362930
explicit BindingListener(ForEachStmt *stmt) : Stmt(stmt) { }
29372931

@@ -2993,10 +2987,11 @@ auto TypeChecker::typeCheckForEachBinding(
29932987
// Reference the makeIterator witness.
29942988
ASTContext &ctx = cs.getASTContext();
29952989
FuncDecl *makeIterator = ctx.getSequenceMakeIterator();
2996-
MakeIteratorType = cs.createTypeVariable(Locator, TVO_CanBindToNoEscape);
2990+
Type makeIteratorType =
2991+
cs.createTypeVariable(Locator, TVO_CanBindToNoEscape);
29972992
cs.addValueWitnessConstraint(
29982993
LValueType::get(SequenceType), makeIterator,
2999-
MakeIteratorType, cs.DC, FunctionRefKind::Compound,
2994+
makeIteratorType, cs.DC, FunctionRefKind::Compound,
30002995
ContextualLocator);
30012996

30022997
Stmt->setSequence(expr);
@@ -3013,17 +3008,8 @@ auto TypeChecker::typeCheckForEachBinding(
30133008

30143009
// Perform any necessary conversions of the sequence (e.g. [T]! -> [T]).
30153010
expr = solution.coerceToType(expr, SequenceType, Locator);
3016-
30173011
if (!expr) return nullptr;
30183012

3019-
// Convert the sequence as appropriate for the makeIterator() call.
3020-
auto makeIteratorOverload = solution.getOverloadChoice(ContextualLocator);
3021-
auto makeIteratorSelfType = solution.simplifyType(
3022-
makeIteratorOverload.openedFullType
3023-
)->castTo<AnyFunctionType>()->getParams()[0].getPlainType();
3024-
expr = solution.coerceToType(expr, makeIteratorSelfType,
3025-
ContextualLocator);
3026-
30273013
cs.cacheExprTypes(expr);
30283014
Stmt->setSequence(expr);
30293015

@@ -3040,39 +3026,18 @@ auto TypeChecker::typeCheckForEachBinding(
30403026
Stmt->setPattern(pattern);
30413027

30423028
// Get the conformance of the sequence type to the Sequence protocol.
3043-
// FIXME: Get this from the solution and substitute into that.
3044-
SequenceConformance = TypeChecker::conformsToProtocol(
3045-
SequenceType, SequenceProto, cs.DC,
3046-
ConformanceCheckFlags::InExpression,
3047-
expr->getLoc());
3029+
SequenceConformance = solution.resolveConformance(
3030+
ContextualLocator, SequenceProto);
30483031
assert(!SequenceConformance.isInvalid() &&
30493032
"Couldn't find sequence conformance");
30503033
Stmt->setSequenceConformance(SequenceConformance);
30513034

3052-
// Retrieve the conformance of the iterator type to IteratorProtocol.
3053-
// FIXME: Get this from the solution and substitute into that.
3054-
IteratorConformance = TypeChecker::conformsToProtocol(
3055-
IteratorType, IteratorProto, cs.DC,
3056-
ConformanceCheckFlags::InExpression,
3057-
expr->getLoc());
3058-
3059-
// Record the makeIterator declaration we used.
3060-
auto makeIteratorDecl = makeIteratorOverload.choice.getDecl();
3061-
auto makeIteratorSubs = SequenceType->getMemberSubstitutionMap(
3062-
cs.DC->getParentModule(), makeIteratorDecl);
3063-
auto makeIteratorDeclRef =
3064-
ConcreteDeclRef(makeIteratorDecl, makeIteratorSubs);
3065-
Stmt->setMakeIterator(makeIteratorDeclRef);
3066-
30673035
solution.setExprTypes(expr);
30683036
return expr;
30693037
}
30703038

30713039
ForEachBinding getBinding() const {
3072-
return {
3073-
SequenceType, SequenceConformance, IteratorType, IteratorConformance,
3074-
ElementType
3075-
};
3040+
return { SequenceType, SequenceConformance, IteratorType, ElementType };
30763041
}
30773042
};
30783043

lib/Sema/TypeCheckStmt.cpp

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -777,6 +777,7 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
777777
// Invoke iterator() to get an iterator from the sequence.
778778
Type iteratorTy = binding->iteratorType;
779779
VarDecl *iterator;
780+
Type nextResultType = OptionalType::get(binding->elementType);
780781
{
781782
// Create a local variable to capture the iterator.
782783
std::string name;
@@ -797,7 +798,6 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
797798

798799
// TODO: test/DebugInfo/iteration.swift requires this extra info to
799800
// be around.
800-
auto nextResultType = OptionalType::get(binding->elementType);
801801
PatternBindingDecl::createImplicit(
802802
getASTContext(), StaticSpellingKind::None, genPat,
803803
new (getASTContext()) OpaqueValueExpr(S->getInLoc(), nextResultType),
@@ -808,33 +808,15 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
808808
if (TypeChecker::requireOptionalIntrinsics(getASTContext(), S->getForLoc()))
809809
return nullptr;
810810

811-
// Gather the witnesses from the Iterator protocol conformance, which
812-
// we'll use to drive the loop.
813-
// FIXME: Would like to customize the diagnostic emitted in
814-
// conformsToProtocol().
815-
auto genConformance = binding->iteratorConformance;
816-
if (genConformance.isInvalid())
817-
return nullptr;
818-
811+
// Create the iterator variable.
819812
auto *varRef = TypeChecker::buildCheckedRefExpr(iterator, DC,
820813
DeclNameLoc(S->getInLoc()),
821814
/*implicit*/ true);
822815
if (!varRef)
823816
return nullptr;
824-
825817
S->setIteratorVarRef(varRef);
826818

827-
auto witness =
828-
genConformance.getWitnessByName(iteratorTy, getASTContext().Id_next);
829-
if (!witness)
830-
return nullptr;
831-
S->setIteratorNext(witness);
832-
833-
auto nextResultType = cast<FuncDecl>(S->getIteratorNext().getDecl())
834-
->getResultInterfaceType()
835-
.subst(S->getIteratorNext().getSubstitutions());
836-
837-
// Convert that Optional<T> value to Optional<Element>.
819+
// Convert that Optional<Element> value to the type of the pattern.
838820
auto optPatternType = OptionalType::get(S->getPattern()->getType());
839821
if (!optPatternType->isEqual(nextResultType)) {
840822
OpaqueValueExpr *elementExpr =

lib/Sema/TypeChecker.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1037,7 +1037,6 @@ class TypeChecker final {
10371037
Type sequenceType;
10381038
ProtocolConformanceRef sequenceConformance;
10391039
Type iteratorType;
1040-
ProtocolConformanceRef iteratorConformance;
10411040
Type elementType;
10421041
};
10431042

test/Constraints/generic_protocol_witness.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,5 +61,4 @@ struct L<T>: Sequence {} // expected-error {{type 'L<T>' does not conform to pro
6161
func z(_ x: L<Int>) {
6262
for xx in x {}
6363
// expected-warning@-1{{immutable value 'xx' was never used; consider replacing with '_' or removing it}}
64-
// expected-error@-2{{type 'L<Int>.Iterator' does not conform to protocol 'IteratorProtocol'}}
6564
}

test/SILGen/foreach.swift

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,8 @@ func trivialStructBreak(_ xx: [Int]) {
115115
// CHECK: [[LOOP_DEST]]:
116116
// CHECK: [[GET_ELT_STACK:%.*]] = alloc_stack $Optional<Int>
117117
// CHECK: [[WRITE:%.*]] = begin_access [modify] [unknown] [[PROJECT_ITERATOR_BOX]] : $*IndexingIterator<Array<Int>>
118-
// CHECK: [[FUNC_REF:%.*]] = function_ref @$ss16IndexingIteratorV4next7ElementQzSgyF : $@convention(method)
119-
// CHECK: apply [[FUNC_REF]]<Array<Int>>([[GET_ELT_STACK]], [[WRITE]])
118+
// CHECK: [[FUNC_REF:%.*]] = witness_method $IndexingIterator<Array<Int>>, #IteratorProtocol.next!1 : <Self where Self : IteratorProtocol> (inout Self) -> () -> Self.Element? : $@convention(witness_method: IteratorProtocol) <τ_0_0 where τ_0_0 : IteratorProtocol> (@inout τ_0_0) -> @out Optional<τ_0_0.Element>
119+
// CHECK: apply [[FUNC_REF]]<IndexingIterator<Array<Int>>>([[GET_ELT_STACK]], [[WRITE]])
120120
// CHECK: [[IND_VAR:%.*]] = load [trivial] [[GET_ELT_STACK]]
121121
// CHECK: switch_enum [[IND_VAR]] : $Optional<Int>, case #Optional.some!enumelt.1: [[SOME_BB:bb[0-9]+]], case #Optional.none!enumelt: [[NONE_BB:bb[0-9]+]]
122122
//
@@ -215,8 +215,8 @@ func existentialBreak(_ xx: [P]) {
215215
//
216216
// CHECK: [[LOOP_DEST]]:
217217
// CHECK: [[WRITE:%.*]] = begin_access [modify] [unknown] [[PROJECT_ITERATOR_BOX]] : $*IndexingIterator<Array<P>>
218-
// CHECK: [[FUNC_REF:%.*]] = function_ref @$ss16IndexingIteratorV4next7ElementQzSgyF : $@convention(method)
219-
// CHECK: apply [[FUNC_REF]]<Array<P>>([[ELT_STACK]], [[WRITE]])
218+
// CHECK: [[FUNC_REF:%.*]] = witness_method $IndexingIterator<Array<P>>, #IteratorProtocol.next!1 : <Self where Self : IteratorProtocol> (inout Self) -> () -> Self.Element? : $@convention(witness_method: IteratorProtocol) <τ_0_0 where τ_0_0 : IteratorProtocol> (@inout τ_0_0) -> @out Optional<τ_0_0.Element>
219+
// CHECK: apply [[FUNC_REF]]<IndexingIterator<Array<P>>>([[ELT_STACK]], [[WRITE]])
220220
// CHECK: switch_enum_addr [[ELT_STACK]] : $*Optional<P>, case #Optional.some!enumelt.1: [[SOME_BB:bb[0-9]+]], case #Optional.none!enumelt: [[NONE_BB:bb[0-9]+]]
221221
//
222222
// CHECK: [[SOME_BB]]:
@@ -375,8 +375,8 @@ func genericStructBreak<T>(_ xx: [GenericStruct<T>]) {
375375
//
376376
// CHECK: [[LOOP_DEST]]:
377377
// CHECK: [[WRITE:%.*]] = begin_access [modify] [unknown] [[PROJECT_ITERATOR_BOX]] : $*IndexingIterator<Array<GenericStruct<T>>>
378-
// CHECK: [[FUNC_REF:%.*]] = function_ref @$ss16IndexingIteratorV4next7ElementQzSgyF : $@convention(method)
379-
// CHECK: apply [[FUNC_REF]]<Array<GenericStruct<T>>>([[ELT_STACK]], [[WRITE]])
378+
// CHECK: [[FUNC_REF:%.*]] = witness_method $IndexingIterator<Array<GenericStruct<T>>>, #IteratorProtocol.next!1 : <Self where Self : IteratorProtocol> (inout Self) -> () -> Self.Element? : $@convention(witness_method: IteratorProtocol) <τ_0_0 where τ_0_0 : IteratorProtocol> (@inout τ_0_0) -> @out Optional<τ_0_0.Element>
379+
// CHECK: apply [[FUNC_REF]]<IndexingIterator<Array<GenericStruct<T>>>>([[ELT_STACK]], [[WRITE]])
380380
// CHECK: switch_enum_addr [[ELT_STACK]] : $*Optional<GenericStruct<T>>, case #Optional.some!enumelt.1: [[SOME_BB:bb[0-9]+]], case #Optional.none!enumelt: [[NONE_BB:bb[0-9]+]]
381381
//
382382
// CHECK: [[SOME_BB]]:

0 commit comments

Comments
 (0)