Skip to content

Commit 97b5a0d

Browse files
committed
[Type checker] Stop devirtualizing the reference to IteratorProtocol.next().
Rather than having the type checker look for the specific witness to next() when type checking the for-each loop, which had the effect of devirtualizing next() even when it shouldn't be, leave the formation of the next() reference to SILGen. There, form it as a witness reference, so that the SIL optimizer can choose whether to devirtualization (or not).
1 parent f22a7e2 commit 97b5a0d

File tree

10 files changed

+43
-69
lines changed

10 files changed

+43
-69
lines changed

include/swift/AST/Stmt.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -808,7 +808,6 @@ class ForEachStmt : public LabeledStmt {
808808

809809
// Set by Sema:
810810
ProtocolConformanceRef sequenceConformance = ProtocolConformanceRef();
811-
ConcreteDeclRef iteratorNext;
812811
VarDecl *iteratorVar = nullptr;
813812
Expr *iteratorVarRef = nullptr;
814813
OpaqueValueExpr *elementExpr = nullptr;
@@ -837,9 +836,6 @@ class ForEachStmt : public LabeledStmt {
837836
void setConvertElementExpr(Expr *expr) { convertElementExpr = expr; }
838837
Expr *getConvertElementExpr() const { return convertElementExpr; }
839838

840-
void setIteratorNext(ConcreteDeclRef declRef) { iteratorNext = declRef; }
841-
ConcreteDeclRef getIteratorNext() const { return iteratorNext; }
842-
843839
void setSequenceConformance(ProtocolConformanceRef conformance) {
844840
sequenceConformance = conformance;
845841
}

lib/AST/ASTDumper.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1593,9 +1593,6 @@ class PrintStmt : public StmtVisitor<PrintStmt> {
15931593
}
15941594
void visitForEachStmt(ForEachStmt *S) {
15951595
printCommon(S, "for_each_stmt");
1596-
PrintWithColorRAII(OS, LiteralValueColor) << " next=";
1597-
S->getIteratorNext().dump(
1598-
PrintWithColorRAII(OS, LiteralValueColor).getOS());
15991596
OS << '\n';
16001597
printRec(S->getPattern());
16011598
OS << '\n';

lib/SILGen/SILGenStmt.cpp

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -903,6 +903,14 @@ void StmtEmitter::visitRepeatWhileStmt(RepeatWhileStmt *S) {
903903
}
904904

905905
void StmtEmitter::visitForEachStmt(ForEachStmt *S) {
906+
// Dig out information about the sequence conformance.
907+
auto sequenceConformance = S->getSequenceConformance();
908+
Type sequenceType = S->getSequence()->getType();
909+
auto sequenceProto =
910+
SGF.getASTContext().getProtocol(KnownProtocolKind::Sequence);
911+
auto sequenceSubs = SubstitutionMap::getProtocolSubstitutions(
912+
sequenceProto, sequenceType, sequenceConformance);
913+
906914
// Emit the 'iterator' variable that we'll be using for iteration.
907915
LexicalScope OuterForScope(SGF, CleanupLocation(S));
908916
{
@@ -912,12 +920,6 @@ void StmtEmitter::visitForEachStmt(ForEachStmt *S) {
912920

913921
// Compute the reference to the Sequence's makeIterator().
914922
FuncDecl *makeIteratorReq = SGF.getASTContext().getSequenceMakeIterator();
915-
auto sequenceProto =
916-
SGF.getASTContext().getProtocol(KnownProtocolKind::Sequence);
917-
auto sequenceConformance = S->getSequenceConformance();
918-
Type sequenceType = S->getSequence()->getType();
919-
auto sequenceSubs = SubstitutionMap::getProtocolSubstitutions(
920-
sequenceProto, sequenceType, sequenceConformance);
921923
ConcreteDeclRef makeIteratorRef(makeIteratorReq, sequenceSubs);
922924

923925
// Call makeIterator().
@@ -964,8 +966,26 @@ void StmtEmitter::visitForEachStmt(ForEachStmt *S) {
964966
JumpDest endDest = createJumpDest(S->getBody());
965967
SGF.BreakContinueDestStack.push_back({ S, endDest, loopDest });
966968

969+
// Compute the reference to the the iterator's next().
970+
auto iteratorProto =
971+
SGF.getASTContext().getProtocol(KnownProtocolKind::IteratorProtocol);
972+
ValueDecl *iteratorNextReq = iteratorProto->getSingleRequirement(
973+
DeclName(SGF.getASTContext(), SGF.getASTContext().Id_next,
974+
ArrayRef<Identifier>()));
975+
auto iteratorAssocType =
976+
sequenceProto->getAssociatedType(SGF.getASTContext().Id_Iterator);
977+
auto iteratorMemberRef = DependentMemberType::get(
978+
sequenceProto->getSelfInterfaceType(), iteratorAssocType);
979+
auto iteratorType = sequenceConformance.getAssociatedType(
980+
sequenceType, iteratorMemberRef);
981+
auto iteratorConformance = sequenceConformance.getAssociatedConformance(
982+
sequenceType, iteratorMemberRef, iteratorProto);
983+
auto iteratorSubs = SubstitutionMap::getProtocolSubstitutions(
984+
iteratorProto, iteratorType, iteratorConformance);
985+
ConcreteDeclRef iteratorNextRef(iteratorNextReq, iteratorSubs);
986+
967987
auto buildArgumentSource = [&]() {
968-
if (cast<FuncDecl>(S->getIteratorNext().getDecl())->getSelfAccessKind() ==
988+
if (cast<FuncDecl>(iteratorNextRef.getDecl())->getSelfAccessKind() ==
969989
SelfAccessKind::Mutating) {
970990
LValue lv =
971991
SGF.emitLValue(S->getIteratorVarRef(), SGFAccessKind::ReadWrite);
@@ -981,7 +1001,7 @@ void StmtEmitter::visitForEachStmt(ForEachStmt *S) {
9811001
auto buildElementRValue = [&](SILLocation loc, SGFContext ctx) {
9821002
RValue result;
9831003
result = SGF.emitApplyMethod(
984-
loc, S->getIteratorNext(), buildArgumentSource(),
1004+
loc, iteratorNextRef, buildArgumentSource(),
9851005
PreparedArguments(ArrayRef<AnyFunctionType::Param>({})),
9861006
S->getElementExpr() ? SGFContext() : ctx);
9871007
if (S->getElementExpr()) {

lib/Sema/TypeCheckConstraints.cpp

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2926,12 +2926,6 @@ auto TypeChecker::typeCheckForEachBinding(
29262926
/// The type of the iterator.
29272927
Type IteratorType;
29282928

2929-
/// The conformance of the iterator type to IteratorProtocol.
2930-
ProtocolConformanceRef IteratorConformance;
2931-
2932-
/// The type of makeIterator.
2933-
Type MakeIteratorType;
2934-
29352929
public:
29362930
explicit BindingListener(ForEachStmt *stmt) : Stmt(stmt) { }
29372931

@@ -2993,10 +2987,11 @@ auto TypeChecker::typeCheckForEachBinding(
29932987
// Reference the makeIterator witness.
29942988
ASTContext &ctx = cs.getASTContext();
29952989
FuncDecl *makeIterator = ctx.getSequenceMakeIterator();
2996-
MakeIteratorType = cs.createTypeVariable(Locator, TVO_CanBindToNoEscape);
2990+
Type makeIteratorType =
2991+
cs.createTypeVariable(Locator, TVO_CanBindToNoEscape);
29972992
cs.addValueWitnessConstraint(
29982993
LValueType::get(SequenceType), makeIterator,
2999-
MakeIteratorType, cs.DC, FunctionRefKind::Compound,
2994+
makeIteratorType, cs.DC, FunctionRefKind::Compound,
30002995
ContextualLocator);
30012996

30022997
Stmt->setSequence(expr);
@@ -3037,23 +3032,12 @@ auto TypeChecker::typeCheckForEachBinding(
30373032
"Couldn't find sequence conformance");
30383033
Stmt->setSequenceConformance(SequenceConformance);
30393034

3040-
// Retrieve the conformance of the iterator type to IteratorProtocol.
3041-
// FIXME: We probably don't even need this. If we do, get it from
3042-
// SequenceConformance instead.
3043-
IteratorConformance = TypeChecker::conformsToProtocol(
3044-
IteratorType, IteratorProto, cs.DC,
3045-
ConformanceCheckFlags::InExpression,
3046-
expr->getLoc());
3047-
30483035
solution.setExprTypes(expr);
30493036
return expr;
30503037
}
30513038

30523039
ForEachBinding getBinding() const {
3053-
return {
3054-
SequenceType, SequenceConformance, IteratorType, IteratorConformance,
3055-
ElementType
3056-
};
3040+
return { SequenceType, SequenceConformance, IteratorType, ElementType };
30573041
}
30583042
};
30593043

lib/Sema/TypeCheckStmt.cpp

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -777,6 +777,7 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
777777
// Invoke iterator() to get an iterator from the sequence.
778778
Type iteratorTy = binding->iteratorType;
779779
VarDecl *iterator;
780+
Type nextResultType = OptionalType::get(binding->elementType);
780781
{
781782
// Create a local variable to capture the iterator.
782783
std::string name;
@@ -797,7 +798,6 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
797798

798799
// TODO: test/DebugInfo/iteration.swift requires this extra info to
799800
// be around.
800-
auto nextResultType = OptionalType::get(binding->elementType);
801801
PatternBindingDecl::createImplicit(
802802
getASTContext(), StaticSpellingKind::None, genPat,
803803
new (getASTContext()) OpaqueValueExpr(S->getInLoc(), nextResultType),
@@ -808,33 +808,15 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
808808
if (TypeChecker::requireOptionalIntrinsics(getASTContext(), S->getForLoc()))
809809
return nullptr;
810810

811-
// Gather the witnesses from the Iterator protocol conformance, which
812-
// we'll use to drive the loop.
813-
// FIXME: Would like to customize the diagnostic emitted in
814-
// conformsToProtocol().
815-
auto genConformance = binding->iteratorConformance;
816-
if (genConformance.isInvalid())
817-
return nullptr;
818-
811+
// Create the iterator variable.
819812
auto *varRef = TypeChecker::buildCheckedRefExpr(iterator, DC,
820813
DeclNameLoc(S->getInLoc()),
821814
/*implicit*/ true);
822815
if (!varRef)
823816
return nullptr;
824-
825817
S->setIteratorVarRef(varRef);
826818

827-
auto witness =
828-
genConformance.getWitnessByName(iteratorTy, getASTContext().Id_next);
829-
if (!witness)
830-
return nullptr;
831-
S->setIteratorNext(witness);
832-
833-
auto nextResultType = cast<FuncDecl>(S->getIteratorNext().getDecl())
834-
->getResultInterfaceType()
835-
.subst(S->getIteratorNext().getSubstitutions());
836-
837-
// Convert that Optional<T> value to Optional<Element>.
819+
// Convert that Optional<Element> value to the type of the pattern.
838820
auto optPatternType = OptionalType::get(S->getPattern()->getType());
839821
if (!optPatternType->isEqual(nextResultType)) {
840822
OpaqueValueExpr *elementExpr =

lib/Sema/TypeChecker.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1037,7 +1037,6 @@ class TypeChecker final {
10371037
Type sequenceType;
10381038
ProtocolConformanceRef sequenceConformance;
10391039
Type iteratorType;
1040-
ProtocolConformanceRef iteratorConformance;
10411040
Type elementType;
10421041
};
10431042

test/Constraints/generic_protocol_witness.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,5 +61,4 @@ struct L<T>: Sequence {} // expected-error {{type 'L<T>' does not conform to pro
6161
func z(_ x: L<Int>) {
6262
for xx in x {}
6363
// expected-warning@-1{{immutable value 'xx' was never used; consider replacing with '_' or removing it}}
64-
// expected-error@-2{{type 'L<Int>.Iterator' does not conform to protocol 'IteratorProtocol'}}
6564
}

test/SILGen/foreach.swift

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,8 @@ func trivialStructBreak(_ xx: [Int]) {
115115
// CHECK: [[LOOP_DEST]]:
116116
// CHECK: [[GET_ELT_STACK:%.*]] = alloc_stack $Optional<Int>
117117
// CHECK: [[WRITE:%.*]] = begin_access [modify] [unknown] [[PROJECT_ITERATOR_BOX]] : $*IndexingIterator<Array<Int>>
118-
// CHECK: [[FUNC_REF:%.*]] = function_ref @$ss16IndexingIteratorV4next7ElementQzSgyF : $@convention(method)
119-
// CHECK: apply [[FUNC_REF]]<Array<Int>>([[GET_ELT_STACK]], [[WRITE]])
118+
// CHECK: [[FUNC_REF:%.*]] = witness_method $IndexingIterator<Array<Int>>, #IteratorProtocol.next!1 : <Self where Self : IteratorProtocol> (inout Self) -> () -> Self.Element? : $@convention(witness_method: IteratorProtocol) <τ_0_0 where τ_0_0 : IteratorProtocol> (@inout τ_0_0) -> @out Optional<τ_0_0.Element>
119+
// CHECK: apply [[FUNC_REF]]<IndexingIterator<Array<Int>>>([[GET_ELT_STACK]], [[WRITE]])
120120
// CHECK: [[IND_VAR:%.*]] = load [trivial] [[GET_ELT_STACK]]
121121
// CHECK: switch_enum [[IND_VAR]] : $Optional<Int>, case #Optional.some!enumelt.1: [[SOME_BB:bb[0-9]+]], case #Optional.none!enumelt: [[NONE_BB:bb[0-9]+]]
122122
//
@@ -215,8 +215,8 @@ func existentialBreak(_ xx: [P]) {
215215
//
216216
// CHECK: [[LOOP_DEST]]:
217217
// CHECK: [[WRITE:%.*]] = begin_access [modify] [unknown] [[PROJECT_ITERATOR_BOX]] : $*IndexingIterator<Array<P>>
218-
// CHECK: [[FUNC_REF:%.*]] = function_ref @$ss16IndexingIteratorV4next7ElementQzSgyF : $@convention(method)
219-
// CHECK: apply [[FUNC_REF]]<Array<P>>([[ELT_STACK]], [[WRITE]])
218+
// CHECK: [[FUNC_REF:%.*]] = witness_method $IndexingIterator<Array<P>>, #IteratorProtocol.next!1 : <Self where Self : IteratorProtocol> (inout Self) -> () -> Self.Element? : $@convention(witness_method: IteratorProtocol) <τ_0_0 where τ_0_0 : IteratorProtocol> (@inout τ_0_0) -> @out Optional<τ_0_0.Element>
219+
// CHECK: apply [[FUNC_REF]]<IndexingIterator<Array<P>>>([[ELT_STACK]], [[WRITE]])
220220
// CHECK: switch_enum_addr [[ELT_STACK]] : $*Optional<P>, case #Optional.some!enumelt.1: [[SOME_BB:bb[0-9]+]], case #Optional.none!enumelt: [[NONE_BB:bb[0-9]+]]
221221
//
222222
// CHECK: [[SOME_BB]]:
@@ -375,8 +375,8 @@ func genericStructBreak<T>(_ xx: [GenericStruct<T>]) {
375375
//
376376
// CHECK: [[LOOP_DEST]]:
377377
// CHECK: [[WRITE:%.*]] = begin_access [modify] [unknown] [[PROJECT_ITERATOR_BOX]] : $*IndexingIterator<Array<GenericStruct<T>>>
378-
// CHECK: [[FUNC_REF:%.*]] = function_ref @$ss16IndexingIteratorV4next7ElementQzSgyF : $@convention(method)
379-
// CHECK: apply [[FUNC_REF]]<Array<GenericStruct<T>>>([[ELT_STACK]], [[WRITE]])
378+
// CHECK: [[FUNC_REF:%.*]] = witness_method $IndexingIterator<Array<GenericStruct<T>>>, #IteratorProtocol.next!1 : <Self where Self : IteratorProtocol> (inout Self) -> () -> Self.Element? : $@convention(witness_method: IteratorProtocol) <τ_0_0 where τ_0_0 : IteratorProtocol> (@inout τ_0_0) -> @out Optional<τ_0_0.Element>
379+
// CHECK: apply [[FUNC_REF]]<IndexingIterator<Array<GenericStruct<T>>>>([[ELT_STACK]], [[WRITE]])
380380
// CHECK: switch_enum_addr [[ELT_STACK]] : $*Optional<GenericStruct<T>>, case #Optional.some!enumelt.1: [[SOME_BB:bb[0-9]+]], case #Optional.none!enumelt: [[NONE_BB:bb[0-9]+]]
381381
//
382382
// CHECK: [[SOME_BB]]:

test/SILGen/statements.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,8 @@ func for_loops2() {
167167
// rdar://problem/19316670
168168
// CHECK: alloc_stack $Optional<MyClass>
169169
// CHECK-NEXT: [[WRITE:%.*]] = begin_access [modify] [unknown]
170-
// CHECK: [[NEXT:%[0-9]+]] = function_ref @$ss16IndexingIteratorV4next{{[_0-9a-zA-Z]*}}F
171-
// CHECK-NEXT: apply [[NEXT]]<Array<MyClass>>
170+
// CHECK: [[NEXT:%[0-9]+]] = witness_method $IndexingIterator<Array<MyClass>>, #IteratorProtocol.next!1 : <Self where Self : IteratorProtocol> (inout Self) -> () -> Self.Element? : $@convention(witness_method: IteratorProtocol) <τ_0_0 where τ_0_0 : IteratorProtocol> (@inout τ_0_0) -> @out Optional<τ_0_0.Element>
171+
// CHECK-NEXT: apply [[NEXT]]<IndexingIterator<Array<MyClass>>>
172172
// CHECK: class_method [[OBJ:%[0-9]+]] : $MyClass, #MyClass.foo!1
173173
let objects = [MyClass(), MyClass() ]
174174
for obj in objects {

test/stmt/foreach.swift

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ struct BadContainer2 : Sequence { // expected-error{{type 'BadContainer2' does n
1515
func bad_containers_2(bc: BadContainer2) {
1616
for e in bc { }
1717
// expected-warning@-1 {{immutable value 'e' was never used; consider replacing with '_' or removing it}}
18-
// expected-error@-2{{type 'BadContainer2.Iterator' does not conform to protocol 'IteratorProtocol'}}
1918
}
2019

2120
struct BadContainer3 : Sequence { // expected-error{{type 'BadContainer3' does not conform to protocol 'Sequence'}}
@@ -25,7 +24,6 @@ struct BadContainer3 : Sequence { // expected-error{{type 'BadContainer3' does n
2524
func bad_containers_3(bc: BadContainer3) {
2625
for e in bc { }
2726
// expected-warning@-1 {{immutable value 'e' was never used; consider replacing with '_' or removing it}}
28-
// expected-error@-2{{type 'BadContainer3.Iterator' does not conform to protocol 'IteratorProtocol'}}
2927
}
3028

3129
struct BadIterator1 {}
@@ -38,7 +36,6 @@ struct BadContainer4 : Sequence { // expected-error{{type 'BadContainer4' does n
3836
func bad_containers_4(bc: BadContainer4) {
3937
for e in bc { }
4038
// expected-warning@-1 {{immutable value 'e' was never used; consider replacing with '_' or removing it}}
41-
// expected-error@-2{{type 'BadContainer4.Iterator' does not conform to protocol 'IteratorProtocol'}}
4239
}
4340

4441
// Pattern type-checking

0 commit comments

Comments
 (0)