Skip to content

Commit 22cb6f1

Browse files
committed
AST: Introduce ProtocolDecl::get{AssociatedType,ProtocolRequirement}()
1 parent 7f6ef1e commit 22cb6f1

20 files changed

+114
-196
lines changed

include/swift/AST/Decl.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4195,6 +4195,14 @@ class ProtocolDecl final : public NominalTypeDecl {
41954195
/// a protocol having nested types (ObjC protocols).
41964196
llvm::TinyPtrVector<AssociatedTypeDecl *> getAssociatedTypeMembers() const;
41974197

4198+
/// Returns a protocol requirement with the given name, or nullptr if the
4199+
/// name has multiple overloads, or no overloads at all.
4200+
ValueDecl *getSingleRequirement(DeclName name) const;
4201+
4202+
/// Returns an associated type with the given name, or nullptr if one does
4203+
/// not exist.
4204+
AssociatedTypeDecl *getAssociatedType(Identifier name) const;
4205+
41984206
/// Walk this protocol and all of the protocols inherited by this protocol,
41994207
/// transitively, invoking the callback function for each protocol.
42004208
///

lib/AST/ASTDemangler.cpp

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,7 @@ TypeDecl *ASTBuilder::createTypeDecl(NodePointer node) {
8484
return nullptr;
8585

8686
auto name = Ctx.getIdentifier(node->getChild(1)->getText());
87-
auto results = proto->lookupDirect(name);
88-
if (results.size() != 1)
89-
return nullptr;
90-
91-
return dyn_cast<AssociatedTypeDecl>(results[0]);
87+
return proto->getAssociatedType(name);
9288
}
9389

9490
auto *DC = findDeclContext(node);
@@ -585,13 +581,8 @@ Type ASTBuilder::createDependentMemberType(StringRef member,
585581
if (!base->isTypeParameter())
586582
return Type();
587583

588-
auto flags = OptionSet<NominalTypeDecl::LookupDirectFlags>();
589-
flags |= NominalTypeDecl::LookupDirectFlags::IgnoreNewExtensions;
590-
for (auto member : protocol->lookupDirect(Ctx.getIdentifier(member),
591-
flags)) {
592-
if (auto assocType = dyn_cast<AssociatedTypeDecl>(member))
593-
return DependentMemberType::get(base, assocType);
594-
}
584+
if (auto assocType = protocol->getAssociatedType(Ctx.getIdentifier(member)))
585+
return DependentMemberType::get(base, assocType);
595586

596587
return Type();
597588
}

lib/AST/Decl.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4364,6 +4364,34 @@ ProtocolDecl::getAssociatedTypeMembers() const {
43644364
return result;
43654365
}
43664366

4367+
ValueDecl *ProtocolDecl::getSingleRequirement(DeclName name) const {
4368+
auto results = const_cast<ProtocolDecl *>(this)->lookupDirect(name);
4369+
ValueDecl *result = nullptr;
4370+
for (auto candidate : results) {
4371+
if (candidate->getDeclContext() != this ||
4372+
!candidate->isProtocolRequirement())
4373+
continue;
4374+
if (result) {
4375+
// Multiple results.
4376+
return nullptr;
4377+
}
4378+
result = candidate;
4379+
}
4380+
4381+
return result;
4382+
}
4383+
4384+
AssociatedTypeDecl *ProtocolDecl::getAssociatedType(Identifier name) const {
4385+
auto results = const_cast<ProtocolDecl *>(this)->lookupDirect(name);
4386+
for (auto candidate : results) {
4387+
if (candidate->getDeclContext() == this &&
4388+
isa<AssociatedTypeDecl>(candidate)) {
4389+
return cast<AssociatedTypeDecl>(candidate);
4390+
}
4391+
}
4392+
return nullptr;
4393+
}
4394+
43674395
Type ProtocolDecl::getSuperclass() const {
43684396
ASTContext &ctx = getASTContext();
43694397
return evaluateOrDefault(ctx.evaluator,

lib/AST/ProtocolConformance.cpp

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -162,14 +162,7 @@ ProtocolConformanceRef::getTypeWitnessByName(Type type, Identifier name) const {
162162

163163
// Find the named requirement.
164164
ProtocolDecl *proto = getRequirement();
165-
AssociatedTypeDecl *assocType = nullptr;
166-
auto members = proto->lookupDirect(name,
167-
NominalTypeDecl::LookupDirectFlags::IgnoreNewExtensions);
168-
for (auto member : members) {
169-
assocType = dyn_cast<AssociatedTypeDecl>(member);
170-
if (assocType)
171-
break;
172-
}
165+
auto *assocType = proto->getAssociatedType(name);
173166

174167
// FIXME: Shouldn't this be a hard error?
175168
if (!assocType)
@@ -183,16 +176,7 @@ ConcreteDeclRef
183176
ProtocolConformanceRef::getWitnessByName(Type type, DeclName name) const {
184177
// Find the named requirement.
185178
auto *proto = getRequirement();
186-
auto results =
187-
proto->lookupDirect(name,
188-
NominalTypeDecl::LookupDirectFlags::IgnoreNewExtensions);
189-
190-
ValueDecl *requirement = nullptr;
191-
for (auto *result : results) {
192-
if (isa<ProtocolDecl>(result->getDeclContext()))
193-
requirement = result;
194-
}
195-
179+
auto *requirement = proto->getSingleRequirement(name);
196180
if (requirement == nullptr)
197181
return ConcreteDeclRef();
198182

lib/SILGen/SILGen.cpp

Lines changed: 8 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -219,16 +219,9 @@ FuncDecl *SILGenModule::getBridgeToObjectiveCRequirement(SILLocation loc) {
219219

220220
// Look for _bridgeToObjectiveC().
221221
auto &ctx = getASTContext();
222-
FuncDecl *found = nullptr;
223222
DeclName name(ctx, ctx.Id_bridgeToObjectiveC, llvm::ArrayRef<Identifier>());
224-
auto flags = OptionSet<NominalTypeDecl::LookupDirectFlags>();
225-
flags |= NominalTypeDecl::LookupDirectFlags::IgnoreNewExtensions;
226-
for (auto member : proto->lookupDirect(name, flags)) {
227-
if (auto func = dyn_cast<FuncDecl>(member)) {
228-
found = func;
229-
break;
230-
}
231-
}
223+
auto *found = dyn_cast_or_null<FuncDecl>(
224+
proto->getSingleRequirement(name));
232225

233226
if (!found)
234227
diagnose(loc, diag::bridging_objcbridgeable_broken, name);
@@ -251,17 +244,10 @@ FuncDecl *SILGenModule::getUnconditionallyBridgeFromObjectiveCRequirement(
251244

252245
// Look for _bridgeToObjectiveC().
253246
auto &ctx = getASTContext();
254-
FuncDecl *found = nullptr;
255247
DeclName name(ctx, ctx.getIdentifier("_unconditionallyBridgeFromObjectiveC"),
256248
llvm::makeArrayRef(Identifier()));
257-
auto flags = OptionSet<NominalTypeDecl::LookupDirectFlags>();
258-
flags |= NominalTypeDecl::LookupDirectFlags::IgnoreNewExtensions;
259-
for (auto member : proto->lookupDirect(name, flags)) {
260-
if (auto func = dyn_cast<FuncDecl>(member)) {
261-
found = func;
262-
break;
263-
}
264-
}
249+
auto *found = dyn_cast_or_null<FuncDecl>(
250+
proto->getSingleRequirement(name));
265251

266252
if (!found)
267253
diagnose(loc, diag::bridging_objcbridgeable_broken, name);
@@ -284,19 +270,9 @@ SILGenModule::getBridgedObjectiveCTypeRequirement(SILLocation loc) {
284270

285271
// Look for _bridgeToObjectiveC().
286272
auto &ctx = getASTContext();
287-
AssociatedTypeDecl *found = nullptr;
288-
DeclName name(ctx.Id_ObjectiveCType);
289-
auto flags = OptionSet<NominalTypeDecl::LookupDirectFlags>();
290-
flags |= NominalTypeDecl::LookupDirectFlags::IgnoreNewExtensions;
291-
for (auto member : proto->lookupDirect(name, flags)) {
292-
if (auto assocType = dyn_cast<AssociatedTypeDecl>(member)) {
293-
found = assocType;
294-
break;
295-
}
296-
}
297-
273+
auto *found = proto->getAssociatedType(ctx.Id_ObjectiveCType);
298274
if (!found)
299-
diagnose(loc, diag::bridging_objcbridgeable_broken, name);
275+
diagnose(loc, diag::bridging_objcbridgeable_broken, ctx.Id_ObjectiveCType);
300276

301277
BridgedObjectiveCType = found;
302278
return found;
@@ -337,15 +313,8 @@ VarDecl *SILGenModule::getNSErrorRequirement(SILLocation loc) {
337313

338314
// Look for _nsError.
339315
auto &ctx = getASTContext();
340-
VarDecl *found = nullptr;
341-
auto flags = OptionSet<NominalTypeDecl::LookupDirectFlags>();
342-
flags |= NominalTypeDecl::LookupDirectFlags::IgnoreNewExtensions;
343-
for (auto member : proto->lookupDirect(ctx.Id_nsError, flags)) {
344-
if (auto var = dyn_cast<VarDecl>(member)) {
345-
found = var;
346-
break;
347-
}
348-
}
316+
auto *found = dyn_cast_or_null<VarDecl>(
317+
proto->getSingleRequirement(ctx.Id_nsError));
349318

350319
NSErrorRequirement = found;
351320
return found;

lib/SILGen/SILGenExpr.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3019,7 +3019,8 @@ getOrCreateKeyPathEqualsAndHash(SILGenModule &SGM,
30193019
// Compare each pair of index values using the == witness from the
30203020
// conformance.
30213021
auto equatableProtocol = C.getProtocol(KnownProtocolKind::Equatable);
3022-
auto equalsMethod = equatableProtocol->lookupDirect(C.Id_EqualsOperator)[0];
3022+
auto equalsMethod = equatableProtocol->getSingleRequirement(
3023+
C.Id_EqualsOperator);
30233024
auto equalsRef = SILDeclRef(equalsMethod);
30243025
auto equalsTy = subSGF.SGM.Types.getConstantType(equalsRef);
30253026

@@ -3196,7 +3197,7 @@ getOrCreateKeyPathEqualsAndHash(SILGenModule &SGM,
31963197
}
31973198

31983199
VarDecl *hashValueVar =
3199-
cast<VarDecl>(hashableProto->lookupDirect(C.Id_hashValue)[0]);
3200+
cast<VarDecl>(hashableProto->getSingleRequirement(C.Id_hashValue));
32003201

32013202
auto formalTy = index.FormalType;
32023203
auto hashable = index.Hashable;

lib/SILOptimizer/IPO/EagerSpecializer.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -580,11 +580,10 @@ void EagerDispatch::emitRefCountedObjectCheck(SILBasicBlock *FailedTypeCheckBB,
580580
{GenericMT});
581581
// Extract the i1 from the Bool struct.
582582
StructDecl *BoolStruct = cast<StructDecl>(Ctx.getBoolDecl());
583-
auto Members = BoolStruct->lookupDirect(Ctx.Id_value_);
583+
auto Members = BoolStruct->getStoredProperties();
584584
assert(Members.size() == 1 &&
585585
"Bool should have only one property with name '_value'");
586-
auto Member = dyn_cast<VarDecl>(Members[0]);
587-
assert(Member &&"Bool should have a property with name '_value' of type Int1");
586+
auto Member = Members[0];
588587
auto BoolValue =
589588
Builder.emitStructExtract(Loc, IsClassRuntimeCheck, Member, BoolTy);
590589
Builder.createCondBranch(Loc, BoolValue, SuccessBB, FailedTypeCheckBB);

lib/SILOptimizer/Transforms/Outliner.cpp

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -142,18 +142,10 @@ static SILDeclRef getBridgeToObjectiveC(CanType NativeType,
142142
return SILDeclRef();
143143

144144
auto Conformance = ConformanceRef->getConcrete();
145-
FuncDecl *Requirement = nullptr;
146145
// bridgeToObjectiveC
147146
DeclName Name(Ctx, Ctx.Id_bridgeToObjectiveC, llvm::ArrayRef<Identifier>());
148-
auto flags = OptionSet<NominalTypeDecl::LookupDirectFlags>();
149-
flags |= NominalTypeDecl::LookupDirectFlags::IgnoreNewExtensions;
150-
for (auto Member : Proto->lookupDirect(Name, flags)) {
151-
if (auto Func = dyn_cast<FuncDecl>(Member)) {
152-
Requirement = Func;
153-
break;
154-
}
155-
}
156-
assert(Requirement);
147+
auto *Requirement = dyn_cast_or_null<FuncDecl>(
148+
Proto->getSingleRequirement(Name));
157149
if (!Requirement)
158150
return SILDeclRef();
159151

@@ -173,19 +165,11 @@ SILDeclRef getBridgeFromObjectiveC(CanType NativeType,
173165
if (!ConformanceRef)
174166
return SILDeclRef();
175167
auto Conformance = ConformanceRef->getConcrete();
176-
FuncDecl *Requirement = nullptr;
177168
// _unconditionallyBridgeFromObjectiveC
178169
DeclName Name(Ctx, Ctx.getIdentifier("_unconditionallyBridgeFromObjectiveC"),
179170
llvm::makeArrayRef(Identifier()));
180-
auto flags = OptionSet<NominalTypeDecl::LookupDirectFlags>();
181-
flags |= NominalTypeDecl::LookupDirectFlags::IgnoreNewExtensions;
182-
for (auto Member : Proto->lookupDirect(Name, flags)) {
183-
if (auto Func = dyn_cast<FuncDecl>(Member)) {
184-
Requirement = Func;
185-
break;
186-
}
187-
}
188-
assert(Requirement);
171+
auto *Requirement = dyn_cast_or_null<FuncDecl>(
172+
Proto->getSingleRequirement(Name));
189173
if (!Requirement)
190174
return SILDeclRef();
191175

lib/SILOptimizer/Utils/CastOptimizer.cpp

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -487,26 +487,6 @@ findBridgeToObjCFunc(SILOptFunctionBuilder &functionBuilder,
487487
(void)conf;
488488

489489
// Generate code to invoke _bridgeToObjectiveC
490-
491-
auto *ntd = sourceType.getNominalOrBoundGenericNominal();
492-
assert(ntd);
493-
auto members = ntd->lookupDirect(mod.getASTContext().Id_bridgeToObjectiveC);
494-
if (members.empty()) {
495-
SmallVector<ValueDecl *, 4> foundMembers;
496-
if (ntd->getDeclContext()->lookupQualified(
497-
ntd, mod.getASTContext().Id_bridgeToObjectiveC,
498-
NLOptions::NL_ProtocolMembers, foundMembers)) {
499-
// Returned members are starting with the most specialized ones.
500-
// Thus, the first element is what we are looking for.
501-
members.push_back(foundMembers.front());
502-
}
503-
}
504-
505-
// There should be exactly one implementation of _bridgeToObjectiveC.
506-
if (members.size() != 1)
507-
return None;
508-
509-
auto bridgeFuncDecl = members.front();
510490
ModuleDecl *modDecl =
511491
mod.getASTContext().getLoadedModule(mod.getASTContext().Id_Foundation);
512492
if (!modDecl)
@@ -532,7 +512,7 @@ findBridgeToObjCFunc(SILOptFunctionBuilder &functionBuilder,
532512

533513
// Get substitutions, if source is a bound generic type.
534514
auto subMap = sourceType->getContextSubstitutionMap(
535-
mod.getSwiftModule(), bridgeFuncDecl->getDeclContext());
515+
mod.getSwiftModule(), resultDecl->getDeclContext());
536516

537517
// Implementation of _bridgeToObjectiveC could not be found.
538518
if (!bridgedFunc)

lib/Sema/CSApply.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2127,16 +2127,15 @@ namespace {
21272127
return witness;
21282128
};
21292129

2130-
auto associatedTypeArray =
2130+
auto *interpolationProto =
21312131
tc.getProtocol(expr->getLoc(),
2132-
KnownProtocolKind::ExpressibleByStringInterpolation)
2133-
->lookupDirect(tc.Context.Id_StringInterpolation);
2134-
if (associatedTypeArray.empty()) {
2132+
KnownProtocolKind::ExpressibleByStringInterpolation);
2133+
auto associatedTypeDecl = interpolationProto->getAssociatedType(
2134+
tc.Context.Id_StringInterpolation);
2135+
if (associatedTypeDecl == nullptr) {
21352136
tc.diagnose(expr->getStartLoc(), diag::interpolation_broken_proto);
21362137
return nullptr;
21372138
}
2138-
auto associatedTypeDecl =
2139-
cast<AssociatedTypeDecl>(associatedTypeArray.front());
21402139
auto interpolationType =
21412140
simplifyType(DependentMemberType::get(openedType, associatedTypeDecl));
21422141

0 commit comments

Comments
 (0)