Skip to content

Commit b080190

Browse files
authored
Merge pull request #42079 from slavapestov/rqm-simplify-type-for-term
RequirementMachine: Simplify getTypeForSymbolRange() a bit
2 parents 66999ce + 3cfbe03 commit b080190

File tree

4 files changed

+60
-145
lines changed

4 files changed

+60
-145
lines changed

lib/AST/RequirementMachine/GenericSignatureQueries.cpp

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,21 @@ bool RequirementMachine::isCanonicalTypeInContext(Type type) const {
307307
return !type.walk(Walker(*this));
308308
}
309309

310+
/// Given a type parameter 'T.A1.A2...An', a suffix length m where m <= n,
311+
/// and a replacement type U, produce the type 'U.A(n-m)...An' by replacing
312+
/// 'T.A1...A(n-m-1)' with 'U'.
313+
static Type substPrefixType(Type type, unsigned suffixLength, Type prefixType,
314+
GenericSignature sig) {
315+
if (suffixLength == 0)
316+
return prefixType;
317+
318+
auto *memberType = type->castTo<DependentMemberType>();
319+
auto substBaseType = substPrefixType(memberType->getBase(), suffixLength - 1,
320+
prefixType, sig);
321+
return memberType->substBaseType(substBaseType,
322+
LookUpConformanceInSignature(sig.getPointer()));
323+
}
324+
310325
/// Unlike most other queries, the input type can be any type, not just a
311326
/// type parameter.
312327
///
@@ -416,31 +431,9 @@ Type RequirementMachine::getCanonicalTypeInContext(
416431
abort();
417432
}
418433

419-
// Compute the type of the unresolved suffix term V, rooted in the
420-
// generic parameter τ_0_0.
421-
auto origType = Map.getRelativeTypeForTerm(term, prefix);
422-
423-
// Substitute τ_0_0 in the above relative type with the concrete type
424-
// for U.
425-
//
426-
// Example: if T == A.B.C and the longest valid prefix is A.B which
427-
// maps to a concrete type Foo<Int>, then we have:
428-
//
429-
// U == A.B
430-
// V == C
431-
//
432-
// prefixType == Foo<Int>
433-
// origType == τ_0_0.C
434-
// substType == Foo<Int>.C
435-
//
436-
auto substType = origType.subst(
437-
[&](SubstitutableType *type) -> Type {
438-
assert(cast<GenericTypeParamType>(type)->getDepth() == 0);
439-
assert(cast<GenericTypeParamType>(type)->getIndex() == 0);
440-
441-
return prefixType;
442-
},
443-
LookUpConformanceInSignature(Sig.getPointer()));
434+
// Compute the type of the unresolved suffix term V.
435+
auto substType = substPrefixType(t, term.size() - prefix.size(),
436+
prefixType, Sig);
444437

445438
// FIXME: Recursion guard is needed here
446439
return getCanonicalTypeInContext(substType, genericParams);

lib/AST/RequirementMachine/InterfaceType.cpp

Lines changed: 42 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -143,57 +143,6 @@ MutableTerm RewriteContext::getMutableTermForType(CanType paramType,
143143
return MutableTerm(symbols);
144144
}
145145

146-
/// Map an associated type symbol to an associated type declaration.
147-
///
148-
/// Note that the protocol graph is not part of the caching key; each
149-
/// protocol graph is a subgraph of the global inheritance graph, so
150-
/// the specific choice of subgraph does not change the result.
151-
AssociatedTypeDecl *RewriteContext::getAssociatedTypeForSymbol(Symbol symbol) {
152-
auto found = AssocTypes.find(symbol);
153-
if (found != AssocTypes.end())
154-
return found->second;
155-
156-
assert(symbol.getKind() == Symbol::Kind::AssociatedType);
157-
auto name = symbol.getName();
158-
159-
AssociatedTypeDecl *assocType = nullptr;
160-
161-
// An associated type symbol [P1:A] stores a protocol 'P' and an
162-
// identifier 'A'.
163-
//
164-
// We map it back to a AssociatedTypeDecl by looking for associated
165-
// types named 'A' in 'P' and all protocols 'P' inherits. If there
166-
// are multiple candidates, we discard overrides, and then pick the
167-
// candidate that is minimal with respect to the linear order
168-
// defined by TypeDecl::compare().
169-
auto *proto = symbol.getProtocol();
170-
171-
auto checkOtherAssocType = [&](AssociatedTypeDecl *otherAssocType) {
172-
otherAssocType = otherAssocType->getAssociatedTypeAnchor();
173-
174-
if (otherAssocType->getName() == name &&
175-
(assocType == nullptr ||
176-
TypeDecl::compare(otherAssocType->getProtocol(),
177-
assocType->getProtocol()) < 0)) {
178-
assocType = otherAssocType;
179-
}
180-
};
181-
182-
for (auto *otherAssocType : proto->getAssociatedTypeMembers()) {
183-
checkOtherAssocType(otherAssocType);
184-
}
185-
186-
for (auto *inheritedProto : getInheritedProtocols(proto)) {
187-
for (auto *otherAssocType : inheritedProto->getAssociatedTypeMembers()) {
188-
checkOtherAssocType(otherAssocType);
189-
}
190-
}
191-
192-
assert(assocType && "Need to look harder");
193-
AssocTypes[symbol] = assocType;
194-
return assocType;
195-
}
196-
197146
/// Find the most canonical associated type declaration with the given
198147
/// name among a set of conforming protocols stored in this property map
199148
/// entry.
@@ -230,18 +179,13 @@ AssociatedTypeDecl *PropertyBag::getAssociatedType(Identifier name) {
230179
return assocType;
231180
}
232181

233-
/// Compute the interface type for a range of symbols, with an optional
234-
/// root type.
235-
///
236-
/// If the root type is specified, we wrap it in a series of
237-
/// DependentMemberTypes. Otherwise, the root is computed from
238-
/// the first symbol of the range.
182+
/// Compute the interface type for a range of symbols.
239183
static Type
240-
getTypeForSymbolRange(const Symbol *begin, const Symbol *end, Type root,
184+
getTypeForSymbolRange(const Symbol *begin, const Symbol *end,
241185
TypeArrayView<GenericTypeParamType> genericParams,
242186
const PropertyMap &map) {
243187
auto &ctx = map.getRewriteContext();
244-
Type result = root;
188+
Type result;
245189

246190
auto handleRoot = [&](GenericTypeParamType *genericParam) {
247191
assert(genericParam->isCanonical());
@@ -328,12 +272,13 @@ getTypeForSymbolRange(const Symbol *begin, const Symbol *end, Type root,
328272
continue;
329273
}
330274

331-
// We should have a resolved type at this point.
332-
AssociatedTypeDecl *assocType;
275+
assert(symbol.getKind() == Symbol::Kind::AssociatedType);
333276

277+
MutableTerm prefix;
334278
if (begin == iter) {
335-
// FIXME: Eliminate this case once merged associated types are gone.
336-
assocType = ctx.getAssociatedTypeForSymbol(symbol);
279+
// If the term begins with an associated type symbol, look for the
280+
// associated type in the protocol itself.
281+
prefix.add(Symbol::forProtocol(symbol.getProtocol(), ctx));
337282
} else {
338283
// The protocol stored in an associated type symbol appearing in a
339284
// canonical term is not necessarily the right protocol to look for
@@ -343,41 +288,40 @@ getTypeForSymbolRange(const Symbol *begin, const Symbol *end, Type root,
343288
//
344289
// Instead, find all protocols that the prefix conforms to, and look
345290
// for an associated type in those protocols.
346-
MutableTerm prefix(begin, iter);
347-
assert(prefix.size() > 0);
291+
prefix.append(begin, iter);
292+
}
348293

349-
auto *props = map.lookUpProperties(prefix.rbegin(), prefix.rend());
350-
if (props == nullptr) {
351-
llvm::errs() << "Cannot build interface type for term "
352-
<< MutableTerm(begin, end) << "\n";
353-
llvm::errs() << "Prefix does not conform to any protocols: "
354-
<< prefix << "\n\n";
355-
map.dump(llvm::errs());
356-
abort();
357-
}
294+
auto *props = map.lookUpProperties(prefix.rbegin(), prefix.rend());
295+
if (props == nullptr) {
296+
llvm::errs() << "Cannot build interface type for term "
297+
<< MutableTerm(begin, end) << "\n";
298+
llvm::errs() << "Prefix does not conform to any protocols: "
299+
<< prefix << "\n\n";
300+
map.dump(llvm::errs());
301+
abort();
302+
}
358303

359-
// Assert that the associated type's protocol appears among the set
360-
// of protocols that the prefix conforms to.
361-
#ifndef NDEBUG
362-
auto conformsTo = props->getConformsTo();
363-
assert(std::find(conformsTo.begin(), conformsTo.end(),
364-
symbol.getProtocol())
365-
!= conformsTo.end());
366-
#endif
367-
368-
assocType = props->getAssociatedType(symbol.getName());
369-
if (assocType == nullptr) {
370-
llvm::errs() << "Cannot build interface type for term "
371-
<< MutableTerm(begin, end) << "\n";
372-
llvm::errs() << "Prefix term does not not have a nested type named "
373-
<< symbol.getName() << ": "
374-
<< prefix << "\n";
375-
llvm::errs() << "Property map entry: ";
376-
props->dump(llvm::errs());
377-
llvm::errs() << "\n\n";
378-
map.dump(llvm::errs());
379-
abort();
380-
}
304+
// Assert that the associated type's protocol appears among the set
305+
// of protocols that the prefix conforms to.
306+
#ifndef NDEBUG
307+
auto conformsTo = props->getConformsTo();
308+
assert(std::find(conformsTo.begin(), conformsTo.end(),
309+
symbol.getProtocol())
310+
!= conformsTo.end());
311+
#endif
312+
313+
auto *assocType = props->getAssociatedType(symbol.getName());
314+
if (assocType == nullptr) {
315+
llvm::errs() << "Cannot build interface type for term "
316+
<< MutableTerm(begin, end) << "\n";
317+
llvm::errs() << "Prefix term does not not have a nested type named "
318+
<< symbol.getName() << ": "
319+
<< prefix << "\n";
320+
llvm::errs() << "Property map entry: ";
321+
props->dump(llvm::errs());
322+
llvm::errs() << "\n\n";
323+
map.dump(llvm::errs());
324+
abort();
381325
}
382326

383327
result = DependentMemberType::get(result, assocType);
@@ -388,26 +332,12 @@ getTypeForSymbolRange(const Symbol *begin, const Symbol *end, Type root,
388332

389333
Type PropertyMap::getTypeForTerm(Term term,
390334
TypeArrayView<GenericTypeParamType> genericParams) const {
391-
return getTypeForSymbolRange(term.begin(), term.end(), Type(),
392-
genericParams, *this);
335+
return getTypeForSymbolRange(term.begin(), term.end(), genericParams, *this);
393336
}
394337

395338
Type PropertyMap::getTypeForTerm(const MutableTerm &term,
396339
TypeArrayView<GenericTypeParamType> genericParams) const {
397-
return getTypeForSymbolRange(term.begin(), term.end(), Type(),
398-
genericParams, *this);
399-
}
400-
401-
Type PropertyMap::getRelativeTypeForTerm(
402-
const MutableTerm &term, const MutableTerm &prefix) const {
403-
assert(std::equal(prefix.begin(), prefix.end(), term.begin()));
404-
405-
auto genericParam =
406-
CanGenericTypeParamType::get(/*type sequence*/ false, 0, 0,
407-
Context.getASTContext());
408-
return getTypeForSymbolRange(
409-
term.begin() + prefix.size(), term.end(), genericParam,
410-
{ }, *this);
340+
return getTypeForSymbolRange(term.begin(), term.end(), genericParams, *this);
411341
}
412342

413343
/// Concrete type terms are written in terms of generic parameter types that

lib/AST/RequirementMachine/PropertyMap.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -232,9 +232,6 @@ class PropertyMap {
232232
Type getTypeForTerm(const MutableTerm &term,
233233
TypeArrayView<GenericTypeParamType> genericParams) const;
234234

235-
Type getRelativeTypeForTerm(
236-
const MutableTerm &term, const MutableTerm &prefix) const;
237-
238235
Type getTypeFromSubstitutionSchema(
239236
Type schema,
240237
ArrayRef<Term> substitutions,

lib/AST/RequirementMachine/RewriteContext.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,6 @@ class RewriteContext final {
5252
llvm::DenseMap<const ProtocolDecl *,
5353
llvm::TinyPtrVector<const ProtocolDecl *>> AllInherited;
5454

55-
/// Cache for associated type declarations.
56-
llvm::DenseMap<Symbol, AssociatedTypeDecl *> AssocTypes;
57-
5855
/// Requirement machines built from generic signatures.
5956
llvm::DenseMap<GenericSignature, RequirementMachine *> Machines;
6057

@@ -194,8 +191,6 @@ class RewriteContext final {
194191
ArrayRef<Term> substitutions,
195192
SmallVectorImpl<Term> &result);
196193

197-
AssociatedTypeDecl *getAssociatedTypeForSymbol(Symbol symbol);
198-
199194
//////////////////////////////////////////////////////////////////////////////
200195
///
201196
/// Construction of requirement machines for connected components in the

0 commit comments

Comments
 (0)