Skip to content

Commit 28c1178

Browse files
committed
RequirementMachine: Introduce TypeAliasRequirementsRequest
This is a verbatim copy of the GenericSignatureBuilder's somewhat questionable (but necessary for source compatibility) logic where protocol typealiases with the same name as some other associated type imply a same-type requirement. The related diagnostics are there too, but only emitted when -requirement-machine-protocol-signatures=on; in 'verify' mode, the GSB will emit the same diagnostics.
1 parent 17f1cf9 commit 28c1178

File tree

6 files changed

+339
-4
lines changed

6 files changed

+339
-4
lines changed

include/swift/AST/Decl.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4225,6 +4225,7 @@ class ProtocolDecl final : public NominalTypeDecl {
42254225
friend class SuperclassDeclRequest;
42264226
friend class SuperclassTypeRequest;
42274227
friend class StructuralRequirementsRequest;
4228+
friend class TypeAliasRequirementsRequest;
42284229
friend class ProtocolDependenciesRequest;
42294230
friend class RequirementSignatureRequest;
42304231
friend class RequirementSignatureRequestRQM;
@@ -4421,6 +4422,11 @@ class ProtocolDecl final : public NominalTypeDecl {
44214422
/// instead.
44224423
ArrayRef<StructuralRequirement> getStructuralRequirements() const;
44234424

4425+
/// Retrieve same-type requirements implied by protocol typealiases with the
4426+
/// same name as associated types, and diagnose cases that are better expressed
4427+
/// via a 'where' clause.
4428+
ArrayRef<Requirement> getTypeAliasRequirements() const;
4429+
44244430
/// Get the list of protocols appearing on the right hand side of conformance
44254431
/// requirements. Computed from the structural requirements, above.
44264432
ArrayRef<ProtocolDecl *> getProtocolDependencies() const;

include/swift/AST/TypeCheckRequests.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,25 @@ class StructuralRequirementsRequest :
387387
bool isCached() const { return true; }
388388
};
389389

390+
class TypeAliasRequirementsRequest :
391+
public SimpleRequest<TypeAliasRequirementsRequest,
392+
ArrayRef<Requirement>(ProtocolDecl *),
393+
RequestFlags::Cached> {
394+
public:
395+
using SimpleRequest::SimpleRequest;
396+
397+
private:
398+
friend SimpleRequest;
399+
400+
// Evaluation.
401+
ArrayRef<Requirement>
402+
evaluate(Evaluator &evaluator, ProtocolDecl *proto) const;
403+
404+
public:
405+
// Caching.
406+
bool isCached() const { return true; }
407+
};
408+
390409
class ProtocolDependenciesRequest :
391410
public SimpleRequest<ProtocolDependenciesRequest,
392411
ArrayRef<ProtocolDecl *>(ProtocolDecl *),

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,9 @@ SWIFT_REQUEST(TypeChecker, RequirementRequest,
227227
SWIFT_REQUEST(TypeChecker, StructuralRequirementsRequest,
228228
ArrayRef<StructuralRequirement>(ProtocolDecl *), Cached,
229229
HasNearestLocation)
230+
SWIFT_REQUEST(TypeChecker, TypeAliasRequirementsRequest,
231+
ArrayRef<Requirement>(ProtocolDecl *), Cached,
232+
HasNearestLocation)
230233
SWIFT_REQUEST(TypeChecker, ProtocolDependenciesRequest,
231234
ArrayRef<ProtocolDecl *>(ProtocolDecl *), Cached,
232235
HasNearestLocation)

lib/AST/Decl.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5260,6 +5260,13 @@ ProtocolDecl::getStructuralRequirements() const {
52605260
None);
52615261
}
52625262

5263+
ArrayRef<Requirement>
5264+
ProtocolDecl::getTypeAliasRequirements() const {
5265+
return evaluateOrDefault(getASTContext().evaluator,
5266+
TypeAliasRequirementsRequest { const_cast<ProtocolDecl *>(this) },
5267+
None);
5268+
}
5269+
52635270
ArrayRef<ProtocolDecl *>
52645271
ProtocolDecl::getProtocolDependencies() const {
52655272
return evaluateOrDefault(getASTContext().evaluator,

lib/AST/RequirementMachine/RequirementLowering.cpp

Lines changed: 259 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "RequirementLowering.h"
2626
#include "swift/AST/ASTContext.h"
2727
#include "swift/AST/Decl.h"
28+
#include "swift/AST/DiagnosticsSema.h"
2829
#include "swift/AST/ExistentialLayout.h"
2930
#include "swift/AST/Requirement.h"
3031
#include "swift/AST/TypeCheckRequests.h"
@@ -343,6 +344,258 @@ StructuralRequirementsRequest::evaluate(Evaluator &evaluator,
343344
return ctx.AllocateCopy(result);
344345
}
345346

347+
ArrayRef<Requirement>
348+
TypeAliasRequirementsRequest::evaluate(Evaluator &evaluator,
349+
ProtocolDecl *proto) const {
350+
// @objc protocols don't have associated types, so all of the below
351+
// becomes a trivial no-op.
352+
if (proto->isObjC())
353+
return ArrayRef<Requirement>();
354+
355+
assert(!proto->hasLazyRequirementSignature());
356+
357+
SmallVector<Requirement, 2> result;
358+
359+
auto &ctx = proto->getASTContext();
360+
361+
// In Verify mode, the GenericSignatureBuilder will emit the same diagnostics.
362+
bool emitDiagnostics =
363+
(ctx.LangOpts.RequirementMachineProtocolSignatures ==
364+
RequirementMachineMode::Enabled);
365+
366+
// Collect all typealiases from inherited protocols recursively.
367+
llvm::MapVector<Identifier, TinyPtrVector<TypeDecl *>> inheritedTypeDecls;
368+
for (auto *inheritedProto : ctx.getRewriteContext().getInheritedProtocols(proto)) {
369+
for (auto req : inheritedProto->getMembers()) {
370+
if (auto *typeReq = dyn_cast<TypeDecl>(req)) {
371+
// Ignore generic typealiases.
372+
if (auto typeAliasReq = dyn_cast<TypeAliasDecl>(typeReq))
373+
if (typeAliasReq->isGeneric())
374+
continue;
375+
376+
inheritedTypeDecls[typeReq->getName()].push_back(typeReq);
377+
}
378+
}
379+
}
380+
381+
auto getStructuralType = [](TypeDecl *typeDecl) -> Type {
382+
if (auto typealias = dyn_cast<TypeAliasDecl>(typeDecl)) {
383+
if (typealias->getUnderlyingTypeRepr() != nullptr) {
384+
auto type = typealias->getStructuralType();
385+
if (auto *aliasTy = cast<TypeAliasType>(type.getPointer()))
386+
return aliasTy->getSinglyDesugaredType();
387+
return type;
388+
}
389+
return typealias->getUnderlyingType();
390+
}
391+
392+
return typeDecl->getDeclaredInterfaceType();
393+
};
394+
395+
// An inferred same-type requirement between the two type declarations
396+
// within this protocol or a protocol it inherits.
397+
auto recordInheritedTypeRequirement = [&](TypeDecl *first, TypeDecl *second) {
398+
desugarSameTypeRequirement(getStructuralType(first),
399+
getStructuralType(second), result);
400+
};
401+
402+
// Local function to find the insertion point for the protocol's "where"
403+
// clause, as well as the string to start the insertion ("where" or ",");
404+
auto getProtocolWhereLoc = [&]() -> Located<const char *> {
405+
// Already has a trailing where clause.
406+
if (auto trailing = proto->getTrailingWhereClause())
407+
return { ", ", trailing->getRequirements().back().getSourceRange().End };
408+
409+
// Inheritance clause.
410+
return { " where ", proto->getInherited().back().getSourceRange().End };
411+
};
412+
413+
// Retrieve the set of requirements that a given associated type declaration
414+
// produces, in the form that would be seen in the where clause.
415+
const auto getAssociatedTypeReqs = [&](const AssociatedTypeDecl *assocType,
416+
const char *start) {
417+
std::string result;
418+
{
419+
llvm::raw_string_ostream out(result);
420+
out << start;
421+
interleave(assocType->getInherited(), [&](TypeLoc inheritedType) {
422+
out << assocType->getName() << ": ";
423+
if (auto inheritedTypeRepr = inheritedType.getTypeRepr())
424+
inheritedTypeRepr->print(out);
425+
else
426+
inheritedType.getType().print(out);
427+
}, [&] {
428+
out << ", ";
429+
});
430+
431+
if (const auto whereClause = assocType->getTrailingWhereClause()) {
432+
if (!assocType->getInherited().empty())
433+
out << ", ";
434+
435+
whereClause->print(out, /*printWhereKeyword*/false);
436+
}
437+
}
438+
return result;
439+
};
440+
441+
// Retrieve the requirement that a given typealias introduces when it
442+
// overrides an inherited associated type with the same name, as a string
443+
// suitable for use in a where clause.
444+
auto getConcreteTypeReq = [&](TypeDecl *type, const char *start) {
445+
std::string result;
446+
{
447+
llvm::raw_string_ostream out(result);
448+
out << start;
449+
out << type->getName() << " == ";
450+
if (auto typealias = dyn_cast<TypeAliasDecl>(type)) {
451+
if (auto underlyingTypeRepr = typealias->getUnderlyingTypeRepr())
452+
underlyingTypeRepr->print(out);
453+
else
454+
typealias->getUnderlyingType().print(out);
455+
} else {
456+
type->print(out);
457+
}
458+
}
459+
return result;
460+
};
461+
462+
for (auto assocTypeDecl : proto->getAssociatedTypeMembers()) {
463+
// Check whether we inherited any types with the same name.
464+
auto knownInherited =
465+
inheritedTypeDecls.find(assocTypeDecl->getName());
466+
if (knownInherited == inheritedTypeDecls.end()) continue;
467+
468+
bool shouldWarnAboutRedeclaration =
469+
emitDiagnostics &&
470+
!assocTypeDecl->getAttrs().hasAttribute<NonOverrideAttr>() &&
471+
!assocTypeDecl->getAttrs().hasAttribute<OverrideAttr>() &&
472+
!assocTypeDecl->hasDefaultDefinitionType() &&
473+
(!assocTypeDecl->getInherited().empty() ||
474+
assocTypeDecl->getTrailingWhereClause() ||
475+
ctx.LangOpts.WarnImplicitOverrides);
476+
for (auto inheritedType : knownInherited->second) {
477+
// If we have inherited associated type...
478+
if (auto inheritedAssocTypeDecl =
479+
dyn_cast<AssociatedTypeDecl>(inheritedType)) {
480+
// Complain about the first redeclaration.
481+
if (shouldWarnAboutRedeclaration) {
482+
auto inheritedFromProto = inheritedAssocTypeDecl->getProtocol();
483+
auto fixItWhere = getProtocolWhereLoc();
484+
ctx.Diags.diagnose(assocTypeDecl,
485+
diag::inherited_associated_type_redecl,
486+
assocTypeDecl->getName(),
487+
inheritedFromProto->getDeclaredInterfaceType())
488+
.fixItInsertAfter(
489+
fixItWhere.Loc,
490+
getAssociatedTypeReqs(assocTypeDecl, fixItWhere.Item))
491+
.fixItRemove(assocTypeDecl->getSourceRange());
492+
493+
ctx.Diags.diagnose(inheritedAssocTypeDecl, diag::decl_declared_here,
494+
inheritedAssocTypeDecl->getName());
495+
496+
shouldWarnAboutRedeclaration = false;
497+
}
498+
499+
continue;
500+
}
501+
502+
if (emitDiagnostics) {
503+
// We inherited a type; this associated type will be identical
504+
// to that typealias.
505+
auto inheritedOwningDecl =
506+
inheritedType->getDeclContext()->getSelfNominalTypeDecl();
507+
ctx.Diags.diagnose(assocTypeDecl,
508+
diag::associated_type_override_typealias,
509+
assocTypeDecl->getName(),
510+
inheritedOwningDecl->getDescriptiveKind(),
511+
inheritedOwningDecl->getDeclaredInterfaceType());
512+
}
513+
514+
recordInheritedTypeRequirement(assocTypeDecl, inheritedType);
515+
}
516+
517+
inheritedTypeDecls.erase(knownInherited);
518+
}
519+
520+
// Check all remaining inherited type declarations to determine if
521+
// this protocol has a non-associated-type type with the same name.
522+
inheritedTypeDecls.remove_if(
523+
[&](const std::pair<Identifier, TinyPtrVector<TypeDecl *>> &inherited) {
524+
const auto name = inherited.first;
525+
for (auto found : proto->lookupDirect(name)) {
526+
// We only want concrete type declarations.
527+
auto type = dyn_cast<TypeDecl>(found);
528+
if (!type || isa<AssociatedTypeDecl>(type)) continue;
529+
530+
// Ignore nominal types. They're always invalid declarations.
531+
if (isa<NominalTypeDecl>(type))
532+
continue;
533+
534+
// ... from the same module as the protocol.
535+
if (type->getModuleContext() != proto->getModuleContext()) continue;
536+
537+
// Ignore types defined in constrained extensions; their equivalence
538+
// to the associated type would have to be conditional, which we cannot
539+
// model.
540+
if (auto ext = dyn_cast<ExtensionDecl>(type->getDeclContext())) {
541+
if (ext->isConstrainedExtension()) continue;
542+
}
543+
544+
// We found something.
545+
bool shouldWarnAboutRedeclaration = emitDiagnostics;
546+
547+
for (auto inheritedType : inherited.second) {
548+
// If we have inherited associated type...
549+
if (auto inheritedAssocTypeDecl =
550+
dyn_cast<AssociatedTypeDecl>(inheritedType)) {
551+
// Infer a same-type requirement between the typealias' underlying
552+
// type and the inherited associated type.
553+
recordInheritedTypeRequirement(inheritedAssocTypeDecl, type);
554+
555+
// Warn that one should use where clauses for this.
556+
if (shouldWarnAboutRedeclaration) {
557+
auto inheritedFromProto = inheritedAssocTypeDecl->getProtocol();
558+
auto fixItWhere = getProtocolWhereLoc();
559+
ctx.Diags.diagnose(type,
560+
diag::typealias_override_associated_type,
561+
name,
562+
inheritedFromProto->getDeclaredInterfaceType())
563+
.fixItInsertAfter(fixItWhere.Loc,
564+
getConcreteTypeReq(type, fixItWhere.Item))
565+
.fixItRemove(type->getSourceRange());
566+
ctx.Diags.diagnose(inheritedAssocTypeDecl, diag::decl_declared_here,
567+
inheritedAssocTypeDecl->getName());
568+
569+
shouldWarnAboutRedeclaration = false;
570+
}
571+
572+
continue;
573+
}
574+
575+
// Two typealiases that should be the same.
576+
recordInheritedTypeRequirement(inheritedType, type);
577+
}
578+
579+
// We can remove this entry.
580+
return true;
581+
}
582+
583+
return false;
584+
});
585+
586+
// Infer same-type requirements among inherited type declarations.
587+
for (auto &entry : inheritedTypeDecls) {
588+
if (entry.second.size() < 2) continue;
589+
590+
auto firstDecl = entry.second.front();
591+
for (auto otherDecl : ArrayRef<TypeDecl *>(entry.second).slice(1)) {
592+
recordInheritedTypeRequirement(firstDecl, otherDecl);
593+
}
594+
}
595+
596+
return ctx.AllocateCopy(result);
597+
}
598+
346599
ArrayRef<ProtocolDecl *>
347600
ProtocolDependenciesRequest::evaluate(Evaluator &evaluator,
348601
ProtocolDecl *proto) const {
@@ -635,11 +888,13 @@ void RuleBuilder::collectRulesFromReferencedProtocols() {
635888
// we can trigger the computation of the requirement signatures of the
636889
// next component recursively.
637890
if (ProtocolMap[proto]) {
638-
for (auto req : proto->getStructuralRequirements()) {
639-
// FIXME: Keep source location information around for redundancy
640-
// diagnostics.
891+
// FIXME: Keep source location information around for redundancy
892+
// diagnostics.
893+
for (auto req : proto->getStructuralRequirements())
641894
addRequirement(req.req.getCanonical(), proto);
642-
}
895+
896+
for (auto req : proto->getTypeAliasRequirements())
897+
addRequirement(req.getCanonical(), proto);
643898
} else {
644899
for (auto req : proto->getRequirementSignature())
645900
addRequirement(req.getCanonical(), proto);

0 commit comments

Comments
 (0)