Skip to content

Commit b2ae546

Browse files
committed
RequirementMachine: Implement GenericSignature::lookupNestedType() query
This logic is mostly carried over from GenericSignatureBuilder::lookupNestedType().
1 parent 422ae0a commit b2ae546

File tree

4 files changed

+169
-9
lines changed

4 files changed

+169
-9
lines changed

include/swift/AST/RequirementMachine.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ class RequirementMachine final {
6666
bool areSameTypeParameterInContext(Type depType1, Type depType2) const;
6767
Type getCanonicalTypeInContext(Type type,
6868
TypeArrayView<GenericTypeParamType> genericParams) const;
69+
TypeDecl *lookupNestedType(Type depType, Identifier name) const;
6970

7071
void dump(llvm::raw_ostream &out) const;
7172
};

lib/AST/GenericSignature.cpp

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -883,15 +883,49 @@ TypeDecl *
883883
GenericSignatureImpl::lookupNestedType(Type type, Identifier name) const {
884884
assert(type->isTypeParameter());
885885

886-
auto *builder = getGenericSignatureBuilder();
887-
auto equivClass =
888-
builder->resolveEquivalenceClass(
889-
type,
890-
ArchetypeResolutionKind::CompleteWellFormed);
891-
if (!equivClass)
892-
return nullptr;
893-
894-
return equivClass->lookupNestedType(*builder, name);
886+
auto computeViaGSB = [&]() -> TypeDecl * {
887+
auto *builder = getGenericSignatureBuilder();
888+
auto equivClass =
889+
builder->resolveEquivalenceClass(
890+
type,
891+
ArchetypeResolutionKind::CompleteWellFormed);
892+
if (!equivClass)
893+
return nullptr;
894+
895+
return equivClass->lookupNestedType(*builder, name);
896+
};
897+
898+
auto computeViaRQM = [&]() {
899+
auto *machine = getRequirementMachine();
900+
return machine->lookupNestedType(type, name);
901+
};
902+
903+
auto &ctx = getASTContext();
904+
if (ctx.LangOpts.EnableRequirementMachine) {
905+
auto rqmResult = computeViaRQM();
906+
907+
#ifndef NDEBUG
908+
auto gsbResult = computeViaGSB();
909+
910+
if (gsbResult != rqmResult) {
911+
llvm::errs() << "RequirementMachine::lookupNestedType() is broken\n";
912+
llvm::errs() << "Generic signature: " << GenericSignature(this) << "\n";
913+
llvm::errs() << "Dependent type: "; type.dump(llvm::errs());
914+
llvm::errs() << "GenericSignatureBuilder says: ";
915+
gsbResult->dumpRef(llvm::errs());
916+
llvm::errs() << "\n";
917+
llvm::errs() << "RequirementMachine says: ";
918+
rqmResult->dumpRef(llvm::errs());
919+
llvm::errs() << "\n";
920+
getRequirementMachine()->dump(llvm::errs());
921+
abort();
922+
}
923+
#endif
924+
925+
return rqmResult;
926+
} else {
927+
return computeViaGSB();
928+
}
895929
}
896930

897931
unsigned GenericParamKey::findIndexIn(

lib/AST/RequirementMachine/EquivalenceClassMap.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@ class EquivalenceClass {
8787
return Superclass.hasValue();
8888
}
8989

90+
Type getSuperclassBound() const {
91+
return Superclass->getSuperclass();
92+
}
93+
9094
Type getSuperclassBound(
9195
TypeArrayView<GenericTypeParamType> genericParams,
9296
const ProtocolGraph &protos,
@@ -96,6 +100,10 @@ class EquivalenceClass {
96100
return ConcreteType.hasValue();
97101
}
98102

103+
Type getConcreteType() const {
104+
return ConcreteType->getConcreteType();
105+
}
106+
99107
Type getConcreteType(
100108
TypeArrayView<GenericTypeParamType> genericParams,
101109
const ProtocolGraph &protos,

lib/AST/RequirementMachine/RequirementMachine.cpp

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "swift/AST/ASTContext.h"
1515
#include "swift/AST/Decl.h"
1616
#include "swift/AST/GenericSignature.h"
17+
#include "swift/AST/Module.h"
1718
#include "swift/AST/PrettyStackTrace.h"
1819
#include "swift/AST/Requirement.h"
1920
#include "llvm/ADT/DenseSet.h"
@@ -746,3 +747,119 @@ Type RequirementMachine::getCanonicalTypeInContext(
746747
return getCanonicalTypeInContext(substType, genericParams);
747748
});
748749
}
750+
751+
/// Compare two associated types.
752+
static int compareAssociatedTypes(AssociatedTypeDecl *assocType1,
753+
AssociatedTypeDecl *assocType2) {
754+
// - by name.
755+
if (int result = assocType1->getName().str().compare(
756+
assocType2->getName().str()))
757+
return result;
758+
759+
// Prefer an associated type with no overrides (i.e., an anchor) to one
760+
// that has overrides.
761+
bool hasOverridden1 = !assocType1->getOverriddenDecls().empty();
762+
bool hasOverridden2 = !assocType2->getOverriddenDecls().empty();
763+
if (hasOverridden1 != hasOverridden2)
764+
return hasOverridden1 ? +1 : -1;
765+
766+
// - by protocol, so t_n_m.`P.T` < t_n_m.`Q.T` (given P < Q)
767+
auto proto1 = assocType1->getProtocol();
768+
auto proto2 = assocType2->getProtocol();
769+
if (int compareProtocols = TypeDecl::compare(proto1, proto2))
770+
return compareProtocols;
771+
772+
// Error case: if we have two associated types with the same name in the
773+
// same protocol, just tie-break based on address.
774+
if (assocType1 != assocType2)
775+
return assocType1 < assocType2 ? -1 : +1;
776+
777+
return 0;
778+
}
779+
780+
static void lookupConcreteNestedType(NominalTypeDecl *decl,
781+
Identifier name,
782+
SmallVectorImpl<TypeDecl *> &concreteDecls) {
783+
SmallVector<ValueDecl *, 2> foundMembers;
784+
decl->getParentModule()->lookupQualified(
785+
decl, DeclNameRef(name),
786+
NL_QualifiedDefault | NL_OnlyTypes | NL_ProtocolMembers,
787+
foundMembers);
788+
for (auto member : foundMembers)
789+
concreteDecls.push_back(cast<TypeDecl>(member));
790+
}
791+
792+
static TypeDecl *
793+
findBestConcreteNestedType(SmallVectorImpl<TypeDecl *> &concreteDecls) {
794+
return *std::min_element(concreteDecls.begin(), concreteDecls.end(),
795+
[](TypeDecl *type1, TypeDecl *type2) {
796+
return TypeDecl::compare(type1, type2) < 0;
797+
});
798+
}
799+
800+
TypeDecl *
801+
RequirementMachine::lookupNestedType(Type depType, Identifier name) const {
802+
auto term = Impl->Context.getMutableTermForType(depType->getCanonicalType(),
803+
/*proto=*/nullptr);
804+
Impl->System.simplify(term);
805+
Impl->verify(term);
806+
807+
auto *equivClass = Impl->Map.lookUpEquivalenceClass(term);
808+
if (!equivClass)
809+
return nullptr;
810+
811+
// Look for types with the given name in protocols that we know about.
812+
AssociatedTypeDecl *bestAssocType = nullptr;
813+
SmallVector<TypeDecl *, 4> concreteDecls;
814+
815+
for (const auto *proto : equivClass->getConformsTo()) {
816+
// Look for an associated type and/or concrete type with this name.
817+
for (auto member : const_cast<ProtocolDecl *>(proto)->lookupDirect(name)) {
818+
// If this is an associated type, record whether it is the best
819+
// associated type we've seen thus far.
820+
if (auto assocType = dyn_cast<AssociatedTypeDecl>(member)) {
821+
// Retrieve the associated type anchor.
822+
assocType = assocType->getAssociatedTypeAnchor();
823+
824+
if (!bestAssocType ||
825+
compareAssociatedTypes(assocType, bestAssocType) < 0)
826+
bestAssocType = assocType;
827+
828+
continue;
829+
}
830+
831+
// If this is another type declaration, record it.
832+
if (auto type = dyn_cast<TypeDecl>(member)) {
833+
concreteDecls.push_back(type);
834+
continue;
835+
}
836+
}
837+
}
838+
839+
// If we haven't found anything yet but have a concrete type or a superclass,
840+
// look for a type in that.
841+
// FIXME: Shouldn't we always look here?
842+
if (!bestAssocType && concreteDecls.empty()) {
843+
Type typeToSearch;
844+
if (equivClass->isConcreteType())
845+
typeToSearch = equivClass->getConcreteType();
846+
else if (equivClass->hasSuperclassBound())
847+
typeToSearch = equivClass->getSuperclassBound();
848+
849+
if (typeToSearch)
850+
if (auto *decl = typeToSearch->getAnyNominal())
851+
lookupConcreteNestedType(decl, name, concreteDecls);
852+
}
853+
854+
if (bestAssocType) {
855+
assert(bestAssocType->getOverriddenDecls().empty() &&
856+
"Lookup should never keep a non-anchor associated type");
857+
return bestAssocType;
858+
859+
} else if (!concreteDecls.empty()) {
860+
// Find the best concrete type.
861+
return findBestConcreteNestedType(concreteDecls);
862+
}
863+
864+
return nullptr;
865+
}

0 commit comments

Comments
 (0)