Skip to content

Commit 4bb8f46

Browse files
authored
Merge pull request #76536 from slavapestov/small-subst-cleanups
Tiny optimization and cleanups
2 parents e044a37 + a27d6cf commit 4bb8f46

17 files changed

+226
-248
lines changed

include/swift/AST/Types.h

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7158,22 +7158,6 @@ class DependentMemberType : public TypeBase {
71587158
AssociatedTypeDecl *getAssocType() const {
71597159
return NameOrAssocType.dyn_cast<AssociatedTypeDecl *>();
71607160
}
7161-
7162-
/// Substitute the base type, looking up our associated type in it if it is
7163-
/// non-dependent. Returns null if the member could not be found in the new
7164-
/// base.
7165-
Type substBaseType(Type base);
7166-
7167-
/// Substitute the base type, looking up our associated type in it if it is
7168-
/// non-dependent. Returns null if the member could not be found in the new
7169-
/// base.
7170-
Type substBaseType(Type base, LookupConformanceFn lookupConformance,
7171-
SubstOptions options);
7172-
7173-
/// Substitute the root generic type, looking up the chain of associated types.
7174-
/// Returns null if the member could not be found in the new root.
7175-
Type substRootParam(Type newRoot, LookupConformanceFn lookupConformance,
7176-
SubstOptions options);
71777161

71787162
// Implement isa/cast/dyncast/etc.
71797163
static bool classof(const TypeBase *T) {

lib/AST/ASTDemangler.cpp

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -305,29 +305,29 @@ Type ASTBuilder::createBoundGenericType(GenericTypeDecl *decl,
305305
if (auto *nominalDecl = dyn_cast<NominalTypeDecl>(decl))
306306
return BoundGenericType::get(nominalDecl, parent, args);
307307

308-
// Combine the substitutions from our parent type with our generic
309-
// arguments.
310-
TypeSubstitutionMap subs;
311-
if (parent)
312-
subs = parent->getContextSubstitutions(decl->getDeclContext());
313-
314308
auto *aliasDecl = cast<TypeAliasDecl>(decl);
309+
auto *dc = aliasDecl->getDeclContext();
315310

316-
auto genericSig = aliasDecl->getGenericSignature();
317-
for (unsigned i = 0, e = args.size(); i < e; ++i) {
318-
auto origTy = genericSig.getInnermostGenericParams()[i];
319-
auto substTy = args[i];
311+
SmallVector<Type, 2> subs;
320312

321-
subs[origTy->getCanonicalType()->castTo<GenericTypeParamType>()] =
322-
substTy;
313+
// Combine the substitutions from our parent type with our generic
314+
// arguments.
315+
if (dc->isLocalContext()) {
316+
for (auto *param : dc->getGenericSignatureOfContext().getGenericParams()) {
317+
subs.push_back(param);
318+
}
319+
} else if (parent) {
320+
auto parentSubs = parent->getContextSubstitutionMap(
321+
dc).getReplacementTypes();
322+
subs.append(parentSubs.begin(), parentSubs.end());
323323
}
324324

325-
auto subMap = SubstitutionMap::get(genericSig,
326-
QueryTypeSubstitutionMap{subs},
327-
LookUpConformanceInModule());
328-
if (!subMap)
329-
return Type();
325+
auto genericSig = aliasDecl->getGenericSignature();
326+
ASSERT(genericSig.getInnermostGenericParams().size() == args.size());
327+
subs.append(args.begin(), args.end());
330328

329+
auto subMap = SubstitutionMap::get(genericSig, subs,
330+
LookUpConformanceInModule());
331331
return aliasDecl->getDeclaredInterfaceType().subst(subMap);
332332
}
333333

lib/AST/ProtocolConformanceRef.cpp

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -198,28 +198,40 @@ Type ProtocolConformanceRef::getAssociatedType(Type conformingType,
198198
Type assocType) const {
199199
if (isPack()) {
200200
auto *pack = getPack();
201-
assert(conformingType->isEqual(pack->getType()));
201+
ASSERT(conformingType->isEqual(pack->getType()));
202202
return pack->getAssociatedType(assocType);
203203
}
204204

205-
assert(!isConcrete() || getConcrete()->getType()->isEqual(conformingType));
206-
207205
auto type = assocType->getCanonicalType();
208-
auto proto = getRequirement();
209206

210207
// Fast path for generic parameters.
211-
if (isa<GenericTypeParamType>(type)) {
212-
assert(type->isEqual(proto->getSelfInterfaceType()) &&
208+
if (auto paramTy = dyn_cast<GenericTypeParamType>(type)) {
209+
ASSERT(paramTy->getDepth() == 0 && paramTy->getIndex() == 0 &&
213210
"type parameter in protocol was not Self");
214211
return conformingType;
215212
}
216213

217-
// Fast path for dependent member types on 'Self' of our associated types.
218-
auto memberType = cast<DependentMemberType>(type);
219-
if (memberType.getBase()->isEqual(proto->getSelfInterfaceType()) &&
220-
memberType->getAssocType()->getProtocol() == proto &&
221-
isConcrete())
222-
return getConcrete()->getTypeWitness(memberType->getAssocType());
214+
if (isInvalid())
215+
return ErrorType::get(assocType->getASTContext());
216+
217+
auto proto = getRequirement();
218+
219+
if (isConcrete()) {
220+
if (auto selfType = conformingType->getAs<DynamicSelfType>())
221+
conformingType = selfType->getSelfType();
222+
ASSERT(getConcrete()->getType()->isEqual(conformingType));
223+
224+
// Fast path for dependent member types on 'Self' of our associated types.
225+
auto memberType = cast<DependentMemberType>(type);
226+
if (memberType.getBase()->isEqual(proto->getSelfInterfaceType()) &&
227+
memberType->getAssocType()->getProtocol() == proto) {
228+
auto witnessType = getConcrete()->getTypeWitness(
229+
memberType->getAssocType());
230+
if (!witnessType)
231+
return ErrorType::get(assocType->getASTContext());
232+
return witnessType;
233+
}
234+
}
223235

224236
// General case: consult the substitution map.
225237
auto substMap =

lib/AST/RequirementMachine/ConcreteContraction.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -302,9 +302,8 @@ ConcreteContraction::substTypeParameterRec(Type type, Position position) const {
302302
return std::nullopt;
303303
}
304304

305-
return assocType->getDeclaredInterfaceType()
306-
->castTo<DependentMemberType>()
307-
->substBaseType(*substBaseType);
305+
return conformance.getAssociatedType(
306+
*substBaseType, assocType->getDeclaredInterfaceType());
308307
}
309308

310309
// An unresolved DependentMemberType stores an identifier. Handle this

lib/AST/RequirementMachine/GenericSignatureQueries.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
//===----------------------------------------------------------------------===//
3434

3535
#include "swift/AST/ASTContext.h"
36+
#include "swift/AST/ConformanceLookup.h"
3637
#include "swift/AST/Decl.h"
3738
#include "swift/AST/GenericEnvironment.h"
3839
#include "swift/AST/GenericSignature.h"
@@ -332,6 +333,8 @@ bool RequirementMachine::isReducedType(Type type) const {
332333
/// Given a type parameter 'T.A1.A2...An', a suffix length m where m <= n,
333334
/// and a replacement type U, produce the type 'U.A(n-m)...An' by replacing
334335
/// 'T.A1...A(n-m-1)' with 'U'.
336+
///
337+
/// FIXME: Remove this.
335338
static Type substPrefixType(Type type, unsigned suffixLength, Type prefixType,
336339
GenericSignature sig) {
337340
if (suffixLength == 0)
@@ -340,9 +343,12 @@ static Type substPrefixType(Type type, unsigned suffixLength, Type prefixType,
340343
auto *memberType = type->castTo<DependentMemberType>();
341344
auto substBaseType = substPrefixType(memberType->getBase(), suffixLength - 1,
342345
prefixType, sig);
343-
return memberType->substBaseType(
344-
substBaseType, LookUpConformanceInModule(),
345-
std::nullopt);
346+
auto *assocDecl = memberType->getAssocType();
347+
auto *proto = assocDecl->getProtocol();
348+
auto conformance = lookupConformance(substBaseType, proto);
349+
return conformance.getAssociatedType(
350+
substBaseType,
351+
assocDecl->getDeclaredInterfaceType());
346352
}
347353

348354
Type RequirementMachine::getReducedTypeParameter(
@@ -380,6 +386,8 @@ Type RequirementMachine::getReducedTypeParameter(
380386
//
381387
// Note that V can be empty if T is fully valid; we expect this to be
382388
// true most of the time.
389+
//
390+
// FIXME: Remove all of this.
383391
auto prefix = getLongestValidPrefix(term);
384392

385393
// Get a type (concrete or dependent) for U.

lib/AST/RequirementMachine/RequirementLowering.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@
150150

151151
#include "RequirementLowering.h"
152152
#include "swift/AST/ASTContext.h"
153+
#include "swift/AST/ConformanceLookup.h"
153154
#include "swift/AST/Decl.h"
154155
#include "swift/AST/DiagnosticsSema.h"
155156
#include "swift/AST/Requirement.h"
@@ -663,9 +664,9 @@ struct InferRequirementsWalker : public TypeWalker {
663664
if (differentiableProtocol && fnTy->isDifferentiable()) {
664665
auto addSameTypeConstraint = [&](Type firstType,
665666
AssociatedTypeDecl *assocType) {
666-
auto secondType = assocType->getDeclaredInterfaceType()
667-
->castTo<DependentMemberType>()
668-
->substBaseType(firstType);
667+
auto conformance = lookupConformance(firstType, differentiableProtocol);
668+
auto secondType = conformance.getAssociatedType(
669+
firstType, assocType->getDeclaredInterfaceType());
669670
reqs.push_back({Requirement(RequirementKind::SameType,
670671
firstType, secondType),
671672
SourceLoc()});

lib/AST/Type.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3746,12 +3746,13 @@ void ParameterizedProtocolType::getRequirements(
37463746
auto argTypes = getArgs();
37473747
assert(argTypes.size() <= assocTypes.size());
37483748

3749+
auto conformance = lookupConformance(baseType, protoDecl);
3750+
37493751
for (unsigned i : indices(argTypes)) {
37503752
auto argType = argTypes[i];
37513753
auto *assocType = assocTypes[i];
3752-
auto subjectType = assocType->getDeclaredInterfaceType()
3753-
->castTo<DependentMemberType>()
3754-
->substBaseType(baseType);
3754+
auto subjectType = conformance.getAssociatedType(
3755+
baseType, assocType->getDeclaredInterfaceType());
37553756
reqs.emplace_back(RequirementKind::SameType, subjectType, argType);
37563757
}
37573758
}
@@ -4457,14 +4458,13 @@ TypeBase::getAutoDiffTangentSpace(LookupConformanceFn lookupConformance) {
44574458
auto associatedTypeLookup =
44584459
differentiableProtocol->lookupDirect(ctx.Id_TangentVector);
44594460
assert(associatedTypeLookup.size() == 1);
4460-
auto *dependentType = DependentMemberType::get(
4461-
differentiableProtocol->getDeclaredInterfaceType(),
4462-
cast<AssociatedTypeDecl>(associatedTypeLookup[0]));
4461+
auto *assocDecl = cast<AssociatedTypeDecl>(associatedTypeLookup[0]);
44634462

44644463
// Try to get the `TangentVector` associated type of `base`.
44654464
// Return the associated type if it is valid.
4466-
auto assocTy =
4467-
dependentType->substBaseType(this, lookupConformance, std::nullopt);
4465+
auto conformance = swift::lookupConformance(this, differentiableProtocol);
4466+
auto assocTy = conformance.getAssociatedType(
4467+
this, assocDecl->getDeclaredInterfaceType());
44684468
if (!assocTy->hasError())
44694469
return cache(TangentSpace::getTangentVector(assocTy));
44704470

0 commit comments

Comments
 (0)