Skip to content

Commit 390fda6

Browse files
authored
Merge pull request #60517 from slavapestov/rqm-occurs-check-5.7
RequirementMachine: Fix crash-on-invalid with recursive same-type requirements [5.7]
2 parents 8faef1a + 1daea42 commit 390fda6

13 files changed

+274
-60
lines changed

lib/AST/RequirementMachine/ConcreteContraction.cpp

Lines changed: 52 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -154,14 +154,23 @@ using namespace rewriting;
154154
/// Strip associated types from types used as keys to erase differences between
155155
/// resolved types coming from the parent generic signature and unresolved types
156156
/// coming from user-written requirements.
157-
static CanType stripBoundDependentMemberTypes(Type t) {
157+
static Type stripBoundDependentMemberTypes(Type t) {
158158
if (auto *depMemTy = t->getAs<DependentMemberType>()) {
159-
return CanType(DependentMemberType::get(
159+
return DependentMemberType::get(
160160
stripBoundDependentMemberTypes(depMemTy->getBase()),
161-
depMemTy->getName()));
161+
depMemTy->getName());
162162
}
163163

164-
return t->getCanonicalType();
164+
return t;
165+
}
166+
167+
/// Returns true if \p lhs appears as the base of a member type in \p rhs.
168+
static bool typeOccursIn(Type lhs, Type rhs) {
169+
return rhs.findIf([lhs](Type t) -> bool {
170+
if (auto *memberType = t->getAs<DependentMemberType>())
171+
return memberType->getBase()->isEqual(lhs);
172+
return false;
173+
});
165174
}
166175

167176
namespace {
@@ -232,17 +241,18 @@ Optional<Type> ConcreteContraction::substTypeParameterRec(
232241
// losing the requirement.
233242
if (position == Position::BaseType ||
234243
position == Position::ConformanceRequirement) {
244+
auto key = stripBoundDependentMemberTypes(type)->getCanonicalType();
235245

236246
Type concreteType;
237247
{
238-
auto found = ConcreteTypes.find(stripBoundDependentMemberTypes(type));
248+
auto found = ConcreteTypes.find(key);
239249
if (found != ConcreteTypes.end() && found->second.size() == 1)
240250
concreteType = *found->second.begin();
241251
}
242252

243253
Type superclass;
244254
{
245-
auto found = Superclasses.find(stripBoundDependentMemberTypes(type));
255+
auto found = Superclasses.find(key);
246256
if (found != Superclasses.end() && found->second.size() == 1)
247257
superclass = *found->second.begin();
248258
}
@@ -392,7 +402,8 @@ ConcreteContraction::substRequirement(const Requirement &req) const {
392402
// 'T : Sendable' would be incorrect; we want to ensure that we only admit
393403
// subclasses of 'C' which are 'Sendable'.
394404
bool allowMissing = false;
395-
if (ConcreteTypes.count(stripBoundDependentMemberTypes(firstType)) > 0)
405+
auto key = stripBoundDependentMemberTypes(firstType)->getCanonicalType();
406+
if (ConcreteTypes.count(key) > 0)
396407
allowMissing = true;
397408

398409
if (!substFirstType->isTypeParameter()) {
@@ -449,17 +460,18 @@ hasResolvedMemberTypeOfInterestingParameter(Type type) const {
449460
if (memberTy->getAssocType() == nullptr)
450461
return false;
451462

452-
auto baseTy = memberTy->getBase();
463+
auto key = stripBoundDependentMemberTypes(memberTy->getBase())
464+
->getCanonicalType();
453465
Type concreteType;
454466
{
455-
auto found = ConcreteTypes.find(stripBoundDependentMemberTypes(baseTy));
467+
auto found = ConcreteTypes.find(key);
456468
if (found != ConcreteTypes.end() && found->second.size() == 1)
457469
return true;
458470
}
459471

460472
Type superclass;
461473
{
462-
auto found = Superclasses.find(stripBoundDependentMemberTypes(baseTy));
474+
auto found = Superclasses.find(key);
463475
if (found != Superclasses.end() && found->second.size() == 1)
464476
return true;
465477
}
@@ -496,14 +508,14 @@ bool ConcreteContraction::preserveSameTypeRequirement(
496508

497509
// One of the parent types of this type parameter should be subject
498510
// to a superclass requirement.
499-
auto type = req.getFirstType();
511+
auto type = stripBoundDependentMemberTypes(req.getFirstType())
512+
->getCanonicalType();
500513
while (true) {
501-
if (Superclasses.find(stripBoundDependentMemberTypes(type))
502-
!= Superclasses.end())
514+
if (Superclasses.find(type) != Superclasses.end())
503515
break;
504516

505-
if (auto *memberType = type->getAs<DependentMemberType>()) {
506-
type = memberType->getBase();
517+
if (auto memberType = dyn_cast<DependentMemberType>(type)) {
518+
type = memberType.getBase();
507519
continue;
508520
}
509521

@@ -546,23 +558,41 @@ bool ConcreteContraction::performConcreteContraction(
546558
if (constraintType->isTypeParameter())
547559
break;
548560

549-
ConcreteTypes[stripBoundDependentMemberTypes(subjectType)]
550-
.insert(constraintType);
561+
subjectType = stripBoundDependentMemberTypes(subjectType);
562+
if (typeOccursIn(subjectType,
563+
stripBoundDependentMemberTypes(constraintType))) {
564+
if (Debug) {
565+
llvm::dbgs() << "@ Subject type of same-type requirement "
566+
<< subjectType << " == " << constraintType << " "
567+
<< "occurs in the constraint type, skipping\n";
568+
}
569+
break;
570+
}
571+
ConcreteTypes[subjectType->getCanonicalType()].insert(constraintType);
551572
break;
552573
}
553574
case RequirementKind::Superclass: {
554575
auto constraintType = req.req.getSecondType();
555576
assert(!constraintType->isTypeParameter() &&
556577
"You forgot to call desugarRequirement()");
557578

558-
Superclasses[stripBoundDependentMemberTypes(subjectType)]
559-
.insert(constraintType);
579+
subjectType = stripBoundDependentMemberTypes(subjectType);
580+
if (typeOccursIn(subjectType,
581+
stripBoundDependentMemberTypes(constraintType))) {
582+
if (Debug) {
583+
llvm::dbgs() << "@ Subject type of superclass requirement "
584+
<< subjectType << " : " << constraintType << " "
585+
<< "occurs in the constraint type, skipping\n";
586+
}
587+
break;
588+
}
589+
Superclasses[subjectType->getCanonicalType()].insert(constraintType);
560590
break;
561591
}
562592
case RequirementKind::Conformance: {
563593
auto *protoDecl = req.req.getProtocolDecl();
564-
Conformances[stripBoundDependentMemberTypes(subjectType)]
565-
.push_back(protoDecl);
594+
subjectType = stripBoundDependentMemberTypes(subjectType);
595+
Conformances[subjectType->getCanonicalType()].push_back(protoDecl);
566596

567597
break;
568598
}
@@ -588,7 +618,7 @@ bool ConcreteContraction::performConcreteContraction(
588618
if (auto otherSuperclassTy = proto->getSuperclass()) {
589619
if (Debug) {
590620
llvm::dbgs() << "@ Subject type of superclass requirement "
591-
<< "τ_" << subjectType << " : " << superclassTy
621+
<< subjectType << " : " << superclassTy
592622
<< " conforms to "<< proto->getName()
593623
<< " which has a superclass bound "
594624
<< otherSuperclassTy << "\n";

lib/AST/RequirementMachine/Diagnostics.cpp

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,26 @@ bool swift::rewriting::diagnoseRequirementErrors(
137137
break;
138138
}
139139

140+
case RequirementError::Kind::RecursiveRequirement: {
141+
auto requirement = error.requirement;
142+
143+
if (requirement.hasError())
144+
break;
145+
146+
assert(requirement.getKind() == RequirementKind::SameType ||
147+
requirement.getKind() == RequirementKind::Superclass);
148+
149+
ctx.Diags.diagnose(loc,
150+
(requirement.getKind() == RequirementKind::SameType ?
151+
diag::recursive_same_type_constraint :
152+
diag::recursive_superclass_constraint),
153+
requirement.getFirstType(),
154+
requirement.getSecondType());
155+
156+
diagnosedError = true;
157+
break;
158+
}
159+
140160
case RequirementError::Kind::RedundantRequirement: {
141161
// We only emit redundant requirement warnings if the user passed
142162
// the -warn-redundant-requirements frontend flag.
@@ -390,7 +410,7 @@ getRequirementForDiagnostics(Type subject, Symbol property,
390410
}
391411
}
392412

393-
void RewriteSystem::computeConflictDiagnostics(
413+
void RewriteSystem::computeConflictingRequirementDiagnostics(
394414
SmallVectorImpl<RequirementError> &errors, SourceLoc signatureLoc,
395415
const PropertyMap &propertyMap,
396416
TypeArrayView<GenericTypeParamType> genericParams) {
@@ -427,11 +447,30 @@ void RewriteSystem::computeConflictDiagnostics(
427447
}
428448
}
429449

450+
void RewriteSystem::computeRecursiveRequirementDiagnostics(
451+
SmallVectorImpl<RequirementError> &errors, SourceLoc signatureLoc,
452+
const PropertyMap &propertyMap,
453+
TypeArrayView<GenericTypeParamType> genericParams) {
454+
for (unsigned ruleID : RecursiveRules) {
455+
const auto &rule = getRule(ruleID);
456+
457+
assert(isInMinimizationDomain(rule.getRHS()[0].getRootProtocol()));
458+
459+
Type subjectType = propertyMap.getTypeForTerm(rule.getRHS(), genericParams);
460+
errors.push_back(RequirementError::forRecursiveRequirement(
461+
getRequirementForDiagnostics(subjectType, *rule.isPropertyRule(),
462+
propertyMap, genericParams, MutableTerm()),
463+
signatureLoc));
464+
}
465+
}
466+
430467
void RequirementMachine::computeRequirementDiagnostics(
431468
SmallVectorImpl<RequirementError> &errors, SourceLoc signatureLoc) {
432469
System.computeRedundantRequirementDiagnostics(errors);
433-
System.computeConflictDiagnostics(errors, signatureLoc, Map,
434-
getGenericParams());
470+
System.computeConflictingRequirementDiagnostics(errors, signatureLoc, Map,
471+
getGenericParams());
472+
System.computeRecursiveRequirementDiagnostics(errors, signatureLoc, Map,
473+
getGenericParams());
435474
}
436475

437476
std::string RequirementMachine::getRuleAsStringForDiagnostics(

lib/AST/RequirementMachine/Diagnostics.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ struct RequirementError {
3636
InvalidRequirementSubject,
3737
/// A pair of conflicting requirements, T == Int, T == String
3838
ConflictingRequirement,
39+
/// A recursive requirement, e.g. T == G<T.A>.
40+
RecursiveRequirement,
3941
/// A redundant requirement, e.g. T == T.
4042
RedundantRequirement,
4143
} kind;
@@ -86,6 +88,11 @@ struct RequirementError {
8688
SourceLoc loc) {
8789
return {Kind::RedundantRequirement, req, loc};
8890
}
91+
92+
static RequirementError forRecursiveRequirement(Requirement req,
93+
SourceLoc loc) {
94+
return {Kind::RecursiveRequirement, req, loc};
95+
}
8996
};
9097

9198
/// Policy for the fixit that transforms 'T : S' where 'S' is not a protocol

lib/AST/RequirementMachine/HomotopyReduction.cpp

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,41 @@ void RewriteSystem::propagateRedundantRequirementIDs() {
179179
}
180180
}
181181

182+
/// Find concrete type or superclass rules where the right hand side occurs as a
183+
/// proper prefix of one of its substitutions.
184+
///
185+
/// eg, (T.[concrete: G<T.[P:A]>] => T).
186+
void RewriteSystem::computeRecursiveRules() {
187+
for (unsigned ruleID = FirstLocalRule, e = Rules.size();
188+
ruleID < e; ++ruleID) {
189+
auto &rule = getRule(ruleID);
190+
191+
if (rule.isPermanent() ||
192+
rule.isRedundant())
193+
continue;
194+
195+
auto optSymbol = rule.isPropertyRule();
196+
if (!optSymbol)
197+
continue;
198+
199+
auto kind = optSymbol->getKind();
200+
if (kind != Symbol::Kind::ConcreteType &&
201+
kind != Symbol::Kind::Superclass) {
202+
continue;
203+
}
204+
205+
auto rhs = rule.getRHS();
206+
for (auto term : optSymbol->getSubstitutions()) {
207+
if (term.size() > rhs.size() &&
208+
std::equal(rhs.begin(), rhs.end(), term.begin())) {
209+
RecursiveRules.push_back(ruleID);
210+
rule.markRecursive();
211+
break;
212+
}
213+
}
214+
}
215+
}
216+
182217
/// Find a rule to delete by looking through all loops for rewrite rules appearing
183218
/// once in empty context. Returns a pair consisting of a loop ID and a rule ID,
184219
/// otherwise returns None.
@@ -579,6 +614,7 @@ void RewriteSystem::minimizeRewriteSystem() {
579614
});
580615

581616
propagateRedundantRequirementIDs();
617+
computeRecursiveRules();
582618

583619
// Check invariants after homotopy reduction.
584620
verifyRewriteLoops();
@@ -628,7 +664,7 @@ GenericSignatureErrors RewriteSystem::getErrors() const {
628664
rule.containsUnresolvedSymbols())
629665
result |= GenericSignatureErrorFlags::HasInvalidRequirements;
630666

631-
if (rule.isConflicting())
667+
if (rule.isConflicting() || rule.isRecursive())
632668
result |= GenericSignatureErrorFlags::HasInvalidRequirements;
633669

634670
if (!rule.isRedundant())

lib/AST/RequirementMachine/RequirementBuilder.cpp

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,7 @@ void ConnectedComponent::buildRequirements(Type subjectType,
8686
subjectType = constraintType;
8787
}
8888

89-
// For compatibility with the old GenericSignatureBuilder, drop requirements
90-
// containing ErrorTypes.
91-
} else if (!ConcreteType->hasError()) {
89+
} else {
9290
// If there are multiple protocol typealiases in the connected component,
9391
// lower them all to a series of identical concrete-type aliases.
9492
for (auto name : Aliases) {
@@ -163,6 +161,14 @@ class RequirementBuilder {
163161

164162
} // end namespace
165163

164+
static Type replaceTypeParametersWithErrorTypes(Type type) {
165+
return type.transformRec([](Type t) -> Optional<Type> {
166+
if (t->isTypeParameter())
167+
return ErrorType::get(t->getASTContext());
168+
return None;
169+
});
170+
}
171+
166172
void RequirementBuilder::addRequirementRules(ArrayRef<unsigned> rules) {
167173
// Convert a rewrite rule into a requirement.
168174
auto createRequirementFromRule = [&](const Rule &rule) {
@@ -190,15 +196,12 @@ void RequirementBuilder::addRequirementRules(ArrayRef<unsigned> rules) {
190196
return;
191197
}
192198

193-
// Requirements containing error types originate from invalid code
194-
// and should not appear in the generic signature.
195-
if (prop->getConcreteType()->hasError())
196-
return;
197-
198199
Type superclassType = Map.getTypeFromSubstitutionSchema(
199200
prop->getConcreteType(),
200201
prop->getSubstitutions(),
201202
GenericParams, MutableTerm());
203+
if (rule.isRecursive())
204+
superclassType = replaceTypeParametersWithErrorTypes(superclassType);
202205

203206
if (ReconstituteSugar)
204207
superclassType = superclassType->reconstituteSugar(/*recursive=*/true);
@@ -216,15 +219,12 @@ void RequirementBuilder::addRequirementRules(ArrayRef<unsigned> rules) {
216219
return;
217220
}
218221

219-
// Requirements containing error types originate from invalid code
220-
// and should not appear in the generic signature.
221-
if (prop->getConcreteType()->hasError())
222-
return;
223-
224222
Type concreteType = Map.getTypeFromSubstitutionSchema(
225223
prop->getConcreteType(),
226224
prop->getSubstitutions(),
227225
GenericParams, MutableTerm());
226+
if (rule.isRecursive())
227+
concreteType = replaceTypeParametersWithErrorTypes(concreteType);
228228

229229
if (ReconstituteSugar)
230230
concreteType = concreteType->reconstituteSugar(/*recursive=*/true);
@@ -291,15 +291,12 @@ void RequirementBuilder::addTypeAliasRules(ArrayRef<unsigned> rules) {
291291
continue;
292292
}
293293

294-
// Requirements containing error types originate from invalid code
295-
// and should not appear in the generic signature.
296-
if (prop->getConcreteType()->hasError())
297-
continue;
298-
299294
Type concreteType = Map.getTypeFromSubstitutionSchema(
300295
prop->getConcreteType(),
301296
prop->getSubstitutions(),
302297
GenericParams, MutableTerm());
298+
if (rule.isRecursive())
299+
concreteType = replaceTypeParametersWithErrorTypes(concreteType);
303300

304301
if (ReconstituteSugar)
305302
concreteType = concreteType->reconstituteSugar(/*recursive=*/true);

0 commit comments

Comments
 (0)