Skip to content

Commit b004396

Browse files
Merge pull request #40269 from AnthonyLatsis/assoc-inference-system
AssociatedTypeInference: Initial refactoring of abstract type witness inference
2 parents e30ba5f + 2ebb123 commit b004396

File tree

13 files changed

+1444
-277
lines changed

13 files changed

+1444
-277
lines changed

include/swift/AST/ASTContext.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1167,7 +1167,7 @@ class ASTContext final {
11671167
/// conformance itself, along with a bit indicating whether this diagnostic
11681168
/// produces an error.
11691169
struct DelayedConformanceDiag {
1170-
ValueDecl *Requirement;
1170+
const ValueDecl *Requirement;
11711171
std::function<void()> Callback;
11721172
bool IsError;
11731173
};

include/swift/AST/Decl.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3495,7 +3495,9 @@ class NominalTypeDecl : public GenericTypeDecl, public IterableDeclContext {
34953495
SmallVectorImpl<ProtocolConformance *> &conformances) const;
34963496

34973497
/// Retrieve all of the protocols that this nominal type conforms to.
3498-
SmallVector<ProtocolDecl *, 2> getAllProtocols() const;
3498+
///
3499+
/// \param sorted Whether to sort the protocols in canonical order.
3500+
SmallVector<ProtocolDecl *, 2> getAllProtocols(bool sorted = false) const;
34993501

35003502
/// Retrieve all of the protocol conformances for this nominal type.
35013503
SmallVector<ProtocolConformance *, 2> getAllConformances(

include/swift/Basic/LangOptions.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,9 @@ namespace swift {
530530
RequirementMachineMode RequirementMachineInferredSignatures =
531531
RequirementMachineMode::Disabled;
532532

533+
/// Enables dumping type witness systems from associated type inference.
534+
bool DumpTypeWitnessSystems = false;
535+
533536
/// Sets the target we are building for and updates platform conditions
534537
/// to match.
535538
///

include/swift/Option/FrontendOptions.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,9 @@ def requirement_machine_max_concrete_nesting : Joined<["-"], "requirement-machin
352352
Flags<[FrontendOption, HelpHidden, DoesNotAffectIncrementalBuild]>,
353353
HelpText<"Set the maximum concrete type nesting depth before giving up">;
354354

355+
def dump_type_witness_systems : Flag<["-"], "dump-type-witness-systems">,
356+
HelpText<"Enables dumping type witness systems from associated type inference">;
357+
355358
def debug_generic_signatures : Flag<["-"], "debug-generic-signatures">,
356359
HelpText<"Debug generic signatures">;
357360

lib/AST/ConformanceLookupTable.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1061,8 +1061,8 @@ void ConformanceLookupTable::lookupConformances(
10611061
}
10621062

10631063
void ConformanceLookupTable::getAllProtocols(
1064-
NominalTypeDecl *nominal,
1065-
SmallVectorImpl<ProtocolDecl *> &scratch) {
1064+
NominalTypeDecl *nominal, SmallVectorImpl<ProtocolDecl *> &scratch,
1065+
bool sorted) {
10661066
// We need to expand all implied conformances to find the complete
10671067
// set of protocols to which this nominal type conforms.
10681068
updateLookupTable(nominal, ConformanceStage::ExpandedImplied);
@@ -1075,7 +1075,9 @@ void ConformanceLookupTable::getAllProtocols(
10751075
scratch.push_back(conformance.first);
10761076
}
10771077

1078-
// FIXME: sort the protocols in some canonical order?
1078+
if (sorted) {
1079+
llvm::array_pod_sort(scratch.begin(), scratch.end(), TypeDecl::compare);
1080+
}
10791081
}
10801082

10811083
int ConformanceLookupTable::compareProtocolConformances(

lib/AST/ConformanceLookupTable.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -451,9 +451,12 @@ class ConformanceLookupTable : public ASTAllocated<ConformanceLookupTable> {
451451
SmallVectorImpl<ConformanceDiagnostic> *diagnostics);
452452

453453
/// Retrieve the complete set of protocols to which this nominal
454-
/// type conforms.
454+
/// type conforms (if the set contains a protocol, the same is true for any
455+
/// inherited protocols).
456+
///
457+
/// \param sorted Whether to sort the protocols in canonical order.
455458
void getAllProtocols(NominalTypeDecl *nominal,
456-
SmallVectorImpl<ProtocolDecl *> &scratch);
459+
SmallVectorImpl<ProtocolDecl *> &scratch, bool sorted);
457460

458461
/// Retrieve the complete set of protocol conformances for this
459462
/// nominal type.

lib/AST/ProtocolConformance.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1316,11 +1316,12 @@ bool NominalTypeDecl::lookupConformance(
13161316
conformances);
13171317
}
13181318

1319-
SmallVector<ProtocolDecl *, 2> NominalTypeDecl::getAllProtocols() const {
1319+
SmallVector<ProtocolDecl *, 2>
1320+
NominalTypeDecl::getAllProtocols(bool sorted) const {
13201321
prepareConformanceTable();
13211322
SmallVector<ProtocolDecl *, 2> result;
1322-
ConformanceTable->getAllProtocols(const_cast<NominalTypeDecl *>(this),
1323-
result);
1323+
ConformanceTable->getAllProtocols(const_cast<NominalTypeDecl *>(this), result,
1324+
sorted);
13241325
return result;
13251326
}
13261327

lib/Frontend/CompilerInvocation.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -959,6 +959,8 @@ static bool ParseLangArgs(LangOptions &Opts, ArgList &Args,
959959
}
960960
}
961961

962+
Opts.DumpTypeWitnessSystems = Args.hasArg(OPT_dump_type_witness_systems);
963+
962964
return HadError || UnsupportedOS || UnsupportedArch;
963965
}
964966

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5383,8 +5383,8 @@ void swift::diagnoseConformanceFailure(Type T,
53835383
}
53845384

53855385
void ConformanceChecker::diagnoseOrDefer(
5386-
ValueDecl *requirement, bool isError,
5387-
std::function<void(NormalProtocolConformance *)> fn) {
5386+
const ValueDecl *requirement, bool isError,
5387+
std::function<void(NormalProtocolConformance *)> fn) {
53885388
if (isError)
53895389
Conformance->setInvalid();
53905390

lib/Sema/TypeCheckProtocol.h

Lines changed: 143 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -100,61 +100,6 @@ CheckTypeWitnessResult checkTypeWitness(Type type,
100100
const NormalProtocolConformance *Conf,
101101
SubstOptions options = None);
102102

103-
/// Describes the means of inferring an abstract type witness.
104-
enum class AbstractTypeWitnessKind : uint8_t {
105-
/// The type witness was inferred via a same-type-to-concrete constraint
106-
/// in a protocol requirement signature.
107-
Fixed,
108-
109-
/// The type witness was inferred via a defaulted associated type.
110-
Default,
111-
112-
/// The type witness was inferred to a generic parameter of the
113-
/// conforming type.
114-
GenericParam,
115-
};
116-
117-
/// A type witness inferred without the aid of a specific potential
118-
/// value witness.
119-
class AbstractTypeWitness {
120-
AbstractTypeWitnessKind Kind;
121-
AssociatedTypeDecl *AssocType;
122-
Type TheType;
123-
124-
/// When this is a default type witness, the declaration responsible for it.
125-
/// May not necessarilly match \c AssocType.
126-
AssociatedTypeDecl *DefaultedAssocType;
127-
128-
AbstractTypeWitness(AbstractTypeWitnessKind Kind,
129-
AssociatedTypeDecl *AssocType, Type TheType,
130-
AssociatedTypeDecl *DefaultedAssocType)
131-
: Kind(Kind), AssocType(AssocType), TheType(TheType),
132-
DefaultedAssocType(DefaultedAssocType) {
133-
assert(AssocType && TheType);
134-
}
135-
136-
public:
137-
static AbstractTypeWitness forFixed(AssociatedTypeDecl *assocType, Type type);
138-
139-
static AbstractTypeWitness forDefault(AssociatedTypeDecl *assocType,
140-
Type type,
141-
AssociatedTypeDecl *defaultedAssocType);
142-
143-
static AbstractTypeWitness forGenericParam(AssociatedTypeDecl *assocType,
144-
Type type);
145-
146-
public:
147-
AbstractTypeWitnessKind getKind() const { return Kind; }
148-
149-
AssociatedTypeDecl *getAssocType() const { return AssocType; }
150-
151-
Type getType() const { return TheType; }
152-
153-
AssociatedTypeDecl *getDefaultedAssocType() const {
154-
return DefaultedAssocType;
155-
}
156-
};
157-
158103
/// The set of associated types that have been inferred by matching
159104
/// the given value witness to its corresponding requirement.
160105
struct InferredAssociatedTypesByWitness {
@@ -855,9 +800,8 @@ class ConformanceChecker : public WitnessChecker {
855800
///
856801
/// \param fn A function to call to emit the actual diagnostic. If
857802
/// diagnostics are being deferred,
858-
void diagnoseOrDefer(
859-
ValueDecl *requirement, bool isError,
860-
std::function<void(NormalProtocolConformance *)> fn);
803+
void diagnoseOrDefer(const ValueDecl *requirement, bool isError,
804+
std::function<void(NormalProtocolConformance *)> fn);
861805

862806
ArrayRef<MissingWitness> getLocalMissingWitness() {
863807
return GlobalMissingWitnesses.getArrayRef().
@@ -929,6 +873,136 @@ class ConformanceChecker : public WitnessChecker {
929873
llvm::function_ref<bool(AbstractFunctionDecl *)>predicate);
930874
};
931875

876+
/// A system for recording and probing the intergrity of a type witness solution
877+
/// for a set of unresolved associated type declarations.
878+
///
879+
/// Right now can reason only about abstract type witnesses, i.e., same-type
880+
/// constraints, default type definitions, and bindings to generic parameters.
881+
class TypeWitnessSystem final {
882+
/// Equivalence classes are used on demand to express equivalences between
883+
/// witness candidates and reflect changes to resolved types across their
884+
/// members.
885+
class EquivalenceClass final {
886+
/// The pointer:
887+
/// - The resolved type for witness candidates belonging to this equivalence
888+
/// class. The resolved type may be a type parameter, but cannot directly
889+
/// pertain to a name variable in the owning system; instead, witness
890+
/// candidates that should resolve to the same type share an equivalence
891+
/// class.
892+
/// The int:
893+
/// - A flag indicating whether the resolved type is ambiguous. When set,
894+
/// the resolved type is null.
895+
llvm::PointerIntPair<Type, 1, bool> ResolvedTyAndIsAmbiguous;
896+
897+
public:
898+
EquivalenceClass(Type ty) : ResolvedTyAndIsAmbiguous(ty, false) {}
899+
900+
EquivalenceClass(const EquivalenceClass &) = delete;
901+
EquivalenceClass(EquivalenceClass &&) = delete;
902+
EquivalenceClass &operator=(const EquivalenceClass &) = delete;
903+
EquivalenceClass &operator=(EquivalenceClass &&) = delete;
904+
905+
Type getResolvedType() const {
906+
return ResolvedTyAndIsAmbiguous.getPointer();
907+
}
908+
void setResolvedType(Type ty);
909+
910+
bool isAmbiguous() const {
911+
return ResolvedTyAndIsAmbiguous.getInt();
912+
}
913+
void setAmbiguous() {
914+
ResolvedTyAndIsAmbiguous = {nullptr, true};
915+
}
916+
};
917+
918+
/// A type witness candidate for a name variable.
919+
struct TypeWitnessCandidate final {
920+
/// The defaulted associated type declaration correlating with this
921+
/// candidate, if present.
922+
const AssociatedTypeDecl *DefaultedAssocType;
923+
924+
/// The equivalence class of this candidate.
925+
EquivalenceClass *EquivClass;
926+
};
927+
928+
/// The set of equivalence classes in the system.
929+
llvm::SmallPtrSet<EquivalenceClass *, 4> EquivalenceClasses;
930+
931+
/// The mapping from name variables (the names of unresolved associated
932+
/// type declarations) to their corresponding type witness candidates.
933+
llvm::SmallDenseMap<Identifier, TypeWitnessCandidate, 4> TypeWitnesses;
934+
935+
public:
936+
TypeWitnessSystem(ArrayRef<AssociatedTypeDecl *> assocTypes);
937+
~TypeWitnessSystem();
938+
939+
TypeWitnessSystem(const TypeWitnessSystem &) = delete;
940+
TypeWitnessSystem(TypeWitnessSystem &&) = delete;
941+
TypeWitnessSystem &operator=(const TypeWitnessSystem &) = delete;
942+
TypeWitnessSystem &operator=(TypeWitnessSystem &&) = delete;
943+
944+
/// Get the resolved type witness for the associated type with the given name.
945+
Type getResolvedTypeWitness(Identifier name) const;
946+
bool hasResolvedTypeWitness(Identifier name) const;
947+
948+
/// Get the defaulted associated type relating to the resolved type witness
949+
/// for the associated type with the given name, if present.
950+
const AssociatedTypeDecl *getDefaultedAssocType(Identifier name) const;
951+
952+
/// Record a type witness for the given associated type name.
953+
///
954+
/// \note This need not lead to the resolution of a type witness, e.g.
955+
/// an associated type may be defaulted to another.
956+
void addTypeWitness(Identifier name, Type type);
957+
958+
/// Record a default type witness.
959+
///
960+
/// \param defaultedAssocType The specific associated type declaration that
961+
/// defines the given default type.
962+
///
963+
/// \note This need not lead to the resolution of a type witness.
964+
void addDefaultTypeWitness(Type type,
965+
const AssociatedTypeDecl *defaultedAssocType);
966+
967+
/// Record the given same-type requirement, if regarded of interest to
968+
/// the system.
969+
///
970+
/// \note This need not lead to the resolution of a type witness.
971+
void addSameTypeRequirement(const Requirement &req);
972+
973+
void dump(llvm::raw_ostream &out,
974+
const NormalProtocolConformance *conformance) const;
975+
976+
private:
977+
/// Form an equivalence between the given name variables.
978+
void addEquivalence(Identifier name1, Identifier name2);
979+
980+
/// Merge \p equivClass2 into \p equivClass1.
981+
///
982+
/// \note This will delete \p equivClass2 after migrating its members to
983+
/// \p equivClass1.
984+
void mergeEquivalenceClasses(EquivalenceClass *equivClass1,
985+
const EquivalenceClass *equivClass2);
986+
987+
/// The result of comparing two resolved types targeting a single equivalence
988+
/// class, in terms of their relative impact on solving the system.
989+
enum class ResolvedTypeComparisonResult {
990+
/// The first resolved type is a better choice than the second one.
991+
Better,
992+
993+
/// The first resolved type is an equivalent or worse choice than the
994+
/// second one.
995+
EquivalentOrWorse,
996+
997+
/// Both resolved types are concrete and mutually exclusive.
998+
Ambiguity
999+
};
1000+
1001+
/// Compare the given resolved types as targeting a single equivalence class,
1002+
/// in terms of the their relative impact on solving the system.
1003+
static ResolvedTypeComparisonResult compareResolvedTypes(Type ty1, Type ty2);
1004+
};
1005+
9321006
/// Captures the state needed to infer associated types.
9331007
class AssociatedTypeInference {
9341008
/// The type checker we'll need to validate declarations etc.
@@ -956,7 +1030,7 @@ class AssociatedTypeInference {
9561030
typeWitnesses;
9571031

9581032
/// Information about a failed, defaulted associated type.
959-
AssociatedTypeDecl *failedDefaultedAssocType = nullptr;
1033+
const AssociatedTypeDecl *failedDefaultedAssocType = nullptr;
9601034
Type failedDefaultedWitness;
9611035
CheckTypeWitnessResult failedDefaultedResult;
9621036

@@ -1005,23 +1079,20 @@ class AssociatedTypeInference {
10051079
ConformanceChecker &checker,
10061080
const llvm::SetVector<AssociatedTypeDecl *> &assocTypes);
10071081

1008-
/// Compute a "fixed" type witness for an associated type, e.g.,
1009-
/// if the refined protocol requires it to be equivalent to some other type.
1010-
Type computeFixedTypeWitness(AssociatedTypeDecl *assocType);
1011-
10121082
/// Compute the default type witness from an associated type default,
10131083
/// if there is one.
1014-
Optional<AbstractTypeWitness>
1015-
computeDefaultTypeWitness(AssociatedTypeDecl *assocType);
1084+
Optional<std::pair<AssociatedTypeDecl *, Type>>
1085+
computeDefaultTypeWitness(AssociatedTypeDecl *assocType) const;
10161086

10171087
/// Compute the "derived" type witness for an associated type that is
10181088
/// known to the compiler.
10191089
std::pair<Type, TypeDecl *>
10201090
computeDerivedTypeWitness(AssociatedTypeDecl *assocType);
10211091

1022-
/// Compute a type witness without using a specific potential witness.
1023-
Optional<AbstractTypeWitness>
1024-
computeAbstractTypeWitness(AssociatedTypeDecl *assocType);
1092+
/// Collect abstract type witnesses and feed them to the given system.
1093+
void collectAbstractTypeWitnesses(
1094+
TypeWitnessSystem &system,
1095+
ArrayRef<AssociatedTypeDecl *> unresolvedAssocTypes) const;
10251096

10261097
/// Substitute the current type witnesses into the given interface type.
10271098
Type substCurrentTypeWitnesses(Type type);
@@ -1040,14 +1111,12 @@ class AssociatedTypeInference {
10401111
/// requirements of the given constrained extension.
10411112
bool checkConstrainedExtension(ExtensionDecl *ext);
10421113

1043-
/// Validate the current tentative solution represented by \p typeWitnesses
1044-
/// and attempt to resolve abstract type witnesses for associated types that
1045-
/// could not be inferred otherwise.
1114+
/// Attempt to infer abstract type witnesses for the given set of associated
1115+
/// types.
10461116
///
10471117
/// \returns \c nullptr, or the associated type that failed.
1048-
AssociatedTypeDecl *
1049-
completeSolution(ArrayRef<AssociatedTypeDecl *> unresolvedAssocTypes,
1050-
unsigned reqDepth);
1118+
AssociatedTypeDecl *inferAbstractTypeWitnesses(
1119+
ArrayRef<AssociatedTypeDecl *> unresolvedAssocTypes, unsigned reqDepth);
10511120

10521121
/// Top-level operation to find solutions for the given unresolved
10531122
/// associated types.

0 commit comments

Comments
 (0)