Skip to content

Commit da0130d

Browse files
committed
AST: Replace calls to substBaseType() with getAssociatedType()
1 parent a8ed7ba commit da0130d

File tree

4 files changed

+25
-17
lines changed

4 files changed

+25
-17
lines changed

lib/AST/RequirementMachine/ConcreteContraction.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -302,9 +302,8 @@ ConcreteContraction::substTypeParameterRec(Type type, Position position) const {
302302
return std::nullopt;
303303
}
304304

305-
return assocType->getDeclaredInterfaceType()
306-
->castTo<DependentMemberType>()
307-
->substBaseType(*substBaseType);
305+
return conformance.getAssociatedType(
306+
*substBaseType, assocType->getDeclaredInterfaceType());
308307
}
309308

310309
// An unresolved DependentMemberType stores an identifier. Handle this

lib/AST/RequirementMachine/GenericSignatureQueries.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
//===----------------------------------------------------------------------===//
3434

3535
#include "swift/AST/ASTContext.h"
36+
#include "swift/AST/ConformanceLookup.h"
3637
#include "swift/AST/Decl.h"
3738
#include "swift/AST/GenericEnvironment.h"
3839
#include "swift/AST/GenericSignature.h"
@@ -332,6 +333,8 @@ bool RequirementMachine::isReducedType(Type type) const {
332333
/// Given a type parameter 'T.A1.A2...An', a suffix length m where m <= n,
333334
/// and a replacement type U, produce the type 'U.A(n-m)...An' by replacing
334335
/// 'T.A1...A(n-m-1)' with 'U'.
336+
///
337+
/// FIXME: Remove this.
335338
static Type substPrefixType(Type type, unsigned suffixLength, Type prefixType,
336339
GenericSignature sig) {
337340
if (suffixLength == 0)
@@ -340,9 +343,12 @@ static Type substPrefixType(Type type, unsigned suffixLength, Type prefixType,
340343
auto *memberType = type->castTo<DependentMemberType>();
341344
auto substBaseType = substPrefixType(memberType->getBase(), suffixLength - 1,
342345
prefixType, sig);
343-
return memberType->substBaseType(
344-
substBaseType, LookUpConformanceInModule(),
345-
std::nullopt);
346+
auto *assocDecl = memberType->getAssocType();
347+
auto *proto = assocDecl->getProtocol();
348+
auto conformance = lookupConformance(substBaseType, proto);
349+
return conformance.getAssociatedType(
350+
substBaseType,
351+
assocDecl->getDeclaredInterfaceType());
346352
}
347353

348354
Type RequirementMachine::getReducedTypeParameter(
@@ -380,6 +386,8 @@ Type RequirementMachine::getReducedTypeParameter(
380386
//
381387
// Note that V can be empty if T is fully valid; we expect this to be
382388
// true most of the time.
389+
//
390+
// FIXME: Remove all of this.
383391
auto prefix = getLongestValidPrefix(term);
384392

385393
// Get a type (concrete or dependent) for U.

lib/AST/RequirementMachine/RequirementLowering.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@
150150

151151
#include "RequirementLowering.h"
152152
#include "swift/AST/ASTContext.h"
153+
#include "swift/AST/ConformanceLookup.h"
153154
#include "swift/AST/Decl.h"
154155
#include "swift/AST/DiagnosticsSema.h"
155156
#include "swift/AST/Requirement.h"
@@ -663,9 +664,9 @@ struct InferRequirementsWalker : public TypeWalker {
663664
if (differentiableProtocol && fnTy->isDifferentiable()) {
664665
auto addSameTypeConstraint = [&](Type firstType,
665666
AssociatedTypeDecl *assocType) {
666-
auto secondType = assocType->getDeclaredInterfaceType()
667-
->castTo<DependentMemberType>()
668-
->substBaseType(firstType);
667+
auto conformance = lookupConformance(firstType, differentiableProtocol);
668+
auto secondType = conformance.getAssociatedType(
669+
firstType, assocType->getDeclaredInterfaceType());
669670
reqs.push_back({Requirement(RequirementKind::SameType,
670671
firstType, secondType),
671672
SourceLoc()});

lib/AST/Type.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3746,12 +3746,13 @@ void ParameterizedProtocolType::getRequirements(
37463746
auto argTypes = getArgs();
37473747
assert(argTypes.size() <= assocTypes.size());
37483748

3749+
auto conformance = lookupConformance(baseType, protoDecl);
3750+
37493751
for (unsigned i : indices(argTypes)) {
37503752
auto argType = argTypes[i];
37513753
auto *assocType = assocTypes[i];
3752-
auto subjectType = assocType->getDeclaredInterfaceType()
3753-
->castTo<DependentMemberType>()
3754-
->substBaseType(baseType);
3754+
auto subjectType = conformance.getAssociatedType(
3755+
baseType, assocType->getDeclaredInterfaceType());
37553756
reqs.emplace_back(RequirementKind::SameType, subjectType, argType);
37563757
}
37573758
}
@@ -4457,14 +4458,13 @@ TypeBase::getAutoDiffTangentSpace(LookupConformanceFn lookupConformance) {
44574458
auto associatedTypeLookup =
44584459
differentiableProtocol->lookupDirect(ctx.Id_TangentVector);
44594460
assert(associatedTypeLookup.size() == 1);
4460-
auto *dependentType = DependentMemberType::get(
4461-
differentiableProtocol->getDeclaredInterfaceType(),
4462-
cast<AssociatedTypeDecl>(associatedTypeLookup[0]));
4461+
auto *assocDecl = cast<AssociatedTypeDecl>(associatedTypeLookup[0]);
44634462

44644463
// Try to get the `TangentVector` associated type of `base`.
44654464
// Return the associated type if it is valid.
4466-
auto assocTy =
4467-
dependentType->substBaseType(this, lookupConformance, std::nullopt);
4465+
auto conformance = swift::lookupConformance(this, differentiableProtocol);
4466+
auto assocTy = conformance.getAssociatedType(
4467+
this, assocDecl->getDeclaredInterfaceType());
44684468
if (!assocTy->hasError())
44694469
return cache(TangentSpace::getTangentVector(assocTy));
44704470

0 commit comments

Comments
 (0)