Skip to content

Commit ceeee45

Browse files
authored
Merge pull request swiftlang#36631 from xedin/conformance-perf-experiment
[Perf][CSSimplify] Transfer conformance requirements of a parameter to an argument
2 parents ba1fc17 + 62c5e18 commit ceeee45

File tree

14 files changed

+240
-12
lines changed

14 files changed

+240
-12
lines changed

include/swift/Sema/Constraint.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,11 @@ enum class ConstraintKind : char {
190190
/// The first type is a property wrapper with a wrapped-value type
191191
/// equal to the second type.
192192
PropertyWrapper,
193+
/// The first type (or its optional or pointer version) must conform to a
194+
/// second type (protocol type). This is not a direct requirement but one
195+
/// inferred from a conversion, so the check is more relax comparing to
196+
/// `ConformsTo`.
197+
TransitivelyConformsTo,
193198
};
194199

195200
/// Classification of the different kinds of constraints.
@@ -579,6 +584,7 @@ class Constraint final : public llvm::ilist_node<Constraint>,
579584
case ConstraintKind::OperatorArgumentConversion:
580585
case ConstraintKind::ConformsTo:
581586
case ConstraintKind::LiteralConformsTo:
587+
case ConstraintKind::TransitivelyConformsTo:
582588
case ConstraintKind::CheckedCast:
583589
case ConstraintKind::SelfObjectOfProtocol:
584590
case ConstraintKind::ApplicableFunction:

include/swift/Sema/ConstraintSystem.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4513,6 +4513,12 @@ class ConstraintSystem {
45134513
ConstraintLocatorBuilder locator,
45144514
TypeMatchOptions flags);
45154515

4516+
/// Similar to \c simplifyConformsToConstraint but also checks for
4517+
/// optional and pointer derived a given type.
4518+
SolutionKind simplifyTransitivelyConformsTo(Type type, Type protocol,
4519+
ConstraintLocatorBuilder locator,
4520+
TypeMatchOptions flags);
4521+
45164522
/// Attempt to simplify a checked-cast constraint.
45174523
SolutionKind simplifyCheckedCastConstraint(Type fromType, Type toType,
45184524
TypeMatchOptions flags,

lib/Sema/CSBindings.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1312,6 +1312,12 @@ void PotentialBindings::infer(Constraint *constraint) {
13121312
// Constraints from which we can't do anything.
13131313
break;
13141314

1315+
// For now let's avoid inferring protocol requirements from
1316+
// this constraint, but in the future we could do that to
1317+
// to filter bindings.
1318+
case ConstraintKind::TransitivelyConformsTo:
1319+
break;
1320+
13151321
case ConstraintKind::DynamicTypeOf: {
13161322
// Direct binding of the left-hand side could result
13171323
// in `DynamicTypeOf` failure if right-hand side is

lib/Sema/CSSimplify.cpp

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
//===----------------------------------------------------------------------===//
1717

1818
#include "CSDiagnostics.h"
19+
#include "swift/AST/Decl.h"
1920
#include "swift/AST/ExistentialLayout.h"
2021
#include "swift/AST/GenericEnvironment.h"
2122
#include "swift/AST/GenericSignature.h"
@@ -1433,6 +1434,34 @@ ConstraintSystem::TypeMatchResult constraints::matchCallArguments(
14331434
assert(!argsWithLabels[argIdx].isAutoClosure() ||
14341435
isSynthesizedArgument(argument));
14351436

1437+
// If parameter is a generic parameter, let's copy its
1438+
// conformance requirements (if any), to the argument
1439+
// be able to filter mismatching choices earlier.
1440+
if (auto *typeVar = paramTy->getAs<TypeVariableType>()) {
1441+
auto *locator = typeVar->getImpl().getLocator();
1442+
if (locator->isForGenericParameter()) {
1443+
auto &CG = cs.getConstraintGraph();
1444+
1445+
auto isTransferableConformance = [&typeVar](Constraint *constraint) {
1446+
if (constraint->getKind() != ConstraintKind::ConformsTo)
1447+
return false;
1448+
1449+
auto requirementTy = constraint->getFirstType();
1450+
if (!requirementTy->isEqual(typeVar))
1451+
return false;
1452+
1453+
return constraint->getSecondType()->is<ProtocolType>();
1454+
};
1455+
1456+
for (auto *constraint : CG[typeVar].getConstraints()) {
1457+
if (isTransferableConformance(constraint))
1458+
cs.addConstraint(ConstraintKind::TransitivelyConformsTo, argTy,
1459+
constraint->getSecondType(),
1460+
constraint->getLocator());
1461+
}
1462+
}
1463+
}
1464+
14361465
cs.addConstraint(
14371466
subKind, argTy, paramTy,
14381467
matchingAutoClosureResult
@@ -1542,6 +1571,7 @@ ConstraintSystem::matchTupleTypes(TupleType *tuple1, TupleType *tuple2,
15421571
case ConstraintKind::BindOverload:
15431572
case ConstraintKind::CheckedCast:
15441573
case ConstraintKind::ConformsTo:
1574+
case ConstraintKind::TransitivelyConformsTo:
15451575
case ConstraintKind::Defaultable:
15461576
case ConstraintKind::Disjunction:
15471577
case ConstraintKind::DynamicTypeOf:
@@ -1682,6 +1712,7 @@ static bool matchFunctionRepresentations(FunctionType::ExtInfo einfo1,
16821712
case ConstraintKind::BindOverload:
16831713
case ConstraintKind::CheckedCast:
16841714
case ConstraintKind::ConformsTo:
1715+
case ConstraintKind::TransitivelyConformsTo:
16851716
case ConstraintKind::Defaultable:
16861717
case ConstraintKind::Disjunction:
16871718
case ConstraintKind::DynamicTypeOf:
@@ -2072,6 +2103,7 @@ ConstraintSystem::matchFunctionTypes(FunctionType *func1, FunctionType *func2,
20722103
case ConstraintKind::BindOverload:
20732104
case ConstraintKind::CheckedCast:
20742105
case ConstraintKind::ConformsTo:
2106+
case ConstraintKind::TransitivelyConformsTo:
20752107
case ConstraintKind::Defaultable:
20762108
case ConstraintKind::Disjunction:
20772109
case ConstraintKind::DynamicTypeOf:
@@ -4983,6 +5015,7 @@ ConstraintSystem::matchTypes(Type type1, Type type2, ConstraintKind kind,
49835015
case ConstraintKind::BridgingConversion:
49845016
case ConstraintKind::CheckedCast:
49855017
case ConstraintKind::ConformsTo:
5018+
case ConstraintKind::TransitivelyConformsTo:
49865019
case ConstraintKind::Defaultable:
49875020
case ConstraintKind::Disjunction:
49885021
case ConstraintKind::DynamicTypeOf:
@@ -6208,6 +6241,128 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyConformsToConstraint(
62086241
return SolutionKind::Error;
62096242
}
62106243

6244+
ConstraintSystem::SolutionKind ConstraintSystem::simplifyTransitivelyConformsTo(
6245+
Type type, Type protocolTy, ConstraintLocatorBuilder locator,
6246+
TypeMatchOptions flags) {
6247+
auto &ctx = getASTContext();
6248+
6249+
// Since this is a performance optimization, let's ignore it
6250+
// in diagnostic mode.
6251+
if (shouldAttemptFixes())
6252+
return SolutionKind::Solved;
6253+
6254+
auto formUnsolved = [&]() {
6255+
// If we're supposed to generate constraints, do so.
6256+
if (flags.contains(TMF_GenerateConstraints)) {
6257+
auto *conformance =
6258+
Constraint::create(*this, ConstraintKind::TransitivelyConformsTo,
6259+
type, protocolTy, getConstraintLocator(locator));
6260+
6261+
addUnsolvedConstraint(conformance);
6262+
return SolutionKind::Solved;
6263+
}
6264+
6265+
return SolutionKind::Unsolved;
6266+
};
6267+
6268+
auto resolvedTy = getFixedTypeRecursive(type, /*wantRValue=*/true);
6269+
if (resolvedTy->isTypeVariableOrMember())
6270+
return formUnsolved();
6271+
6272+
// If the composition consists of a class + protocol,
6273+
// we can't check conformance of the argument because
6274+
// parameter could pick one of the components.
6275+
if (resolvedTy.findIf(
6276+
[](Type type) { return type->is<ProtocolCompositionType>(); }))
6277+
return SolutionKind::Solved;
6278+
6279+
// All bets are off for pointers, there are multiple combinations
6280+
// to check and it doesn't see worth to do that upfront.
6281+
{
6282+
PointerTypeKind pointerKind;
6283+
if (resolvedTy->getAnyPointerElementType(pointerKind))
6284+
return SolutionKind::Solved;
6285+
}
6286+
6287+
auto *protocol = protocolTy->castTo<ProtocolType>()->getDecl();
6288+
6289+
auto *M = DC->getParentModule();
6290+
6291+
// First, let's check whether the type itself conforms,
6292+
// if it does - we are done.
6293+
if (M->lookupConformance(resolvedTy, protocol))
6294+
return SolutionKind::Solved;
6295+
6296+
// If the type doesn't conform, let's check whether
6297+
// an Optional or Unsafe{Mutable}Pointer from it would.
6298+
6299+
SmallVector<Type, 4> typesToCheck;
6300+
6301+
// T -> Optional<T>
6302+
if (!resolvedTy->getOptionalObjectType())
6303+
typesToCheck.push_back(OptionalType::get(resolvedTy));
6304+
6305+
// AnyHashable
6306+
if (auto *anyHashable = ctx.getAnyHashableDecl())
6307+
typesToCheck.push_back(anyHashable->getDeclaredInterfaceType());
6308+
6309+
// Rest of the implicit conversions depend on the resolved type.
6310+
{
6311+
auto getPointerFor = [&ctx](PointerTypeKind ptrKind,
6312+
Optional<Type> elementTy = None) -> Type {
6313+
switch (ptrKind) {
6314+
case PTK_UnsafePointer:
6315+
assert(elementTy);
6316+
return BoundGenericType::get(ctx.getUnsafePointerDecl(),
6317+
/*parent=*/Type(), {*elementTy});
6318+
case PTK_UnsafeMutablePointer:
6319+
assert(elementTy);
6320+
return BoundGenericType::get(ctx.getUnsafeMutablePointerDecl(),
6321+
/*parent=*/Type(), {*elementTy});
6322+
6323+
case PTK_UnsafeRawPointer:
6324+
return ctx.getUnsafeRawPointerDecl()->getDeclaredInterfaceType();
6325+
6326+
case PTK_UnsafeMutableRawPointer:
6327+
return ctx.getUnsafeMutableRawPointerDecl()->getDeclaredInterfaceType();
6328+
6329+
case PTK_AutoreleasingUnsafeMutablePointer:
6330+
llvm_unreachable("no implicit conversion");
6331+
}
6332+
};
6333+
6334+
// String -> UnsafePointer<Void>
6335+
if (auto *string = ctx.getStringDecl()) {
6336+
if (resolvedTy->isEqual(string->getDeclaredInterfaceType())) {
6337+
typesToCheck.push_back(
6338+
getPointerFor(PTK_UnsafePointer, ctx.TheEmptyTupleType));
6339+
}
6340+
}
6341+
6342+
// Array<T> -> Unsafe{Raw}Pointer<T>
6343+
if (auto elt = isArrayType(resolvedTy)) {
6344+
typesToCheck.push_back(getPointerFor(PTK_UnsafePointer, *elt));
6345+
typesToCheck.push_back(getPointerFor(PTK_UnsafeRawPointer, *elt));
6346+
}
6347+
6348+
// inout argument -> UnsafePointer<T>, UnsafeMutablePointer<T>,
6349+
// UnsafeRawPointer, UnsafeMutableRawPointer.
6350+
if (type->is<InOutType>()) {
6351+
typesToCheck.push_back(getPointerFor(PTK_UnsafePointer, resolvedTy));
6352+
typesToCheck.push_back(getPointerFor(PTK_UnsafeMutablePointer, resolvedTy));
6353+
typesToCheck.push_back(getPointerFor(PTK_UnsafeRawPointer));
6354+
typesToCheck.push_back(getPointerFor(PTK_UnsafeMutableRawPointer));
6355+
}
6356+
}
6357+
6358+
return llvm::any_of(typesToCheck,
6359+
[&](Type type) {
6360+
return bool(M->lookupConformance(type, protocol));
6361+
})
6362+
? SolutionKind::Solved
6363+
: SolutionKind::Error;
6364+
}
6365+
62116366
/// Determine the kind of checked cast to perform from the given type to
62126367
/// the given type.
62136368
///
@@ -11284,6 +11439,10 @@ ConstraintSystem::addConstraintImpl(ConstraintKind kind, Type first,
1128411439
return simplifyConformsToConstraint(first, second, kind, locator,
1128511440
subflags);
1128611441

11442+
case ConstraintKind::TransitivelyConformsTo:
11443+
return simplifyTransitivelyConformsTo(first, second, locator,
11444+
subflags);
11445+
1128711446
case ConstraintKind::CheckedCast:
1128811447
return simplifyCheckedCastConstraint(first, second, subflags, locator);
1128911448

@@ -11752,6 +11911,13 @@ ConstraintSystem::simplifyConstraint(const Constraint &constraint) {
1175211911
constraint.getLocator(),
1175311912
None);
1175411913

11914+
case ConstraintKind::TransitivelyConformsTo:
11915+
return simplifyTransitivelyConformsTo(
11916+
constraint.getFirstType(),
11917+
constraint.getSecondType(),
11918+
constraint.getLocator(),
11919+
None);
11920+
1175511921
case ConstraintKind::CheckedCast: {
1175611922
auto result = simplifyCheckedCastConstraint(constraint.getFirstType(),
1175711923
constraint.getSecondType(),

lib/Sema/CSSolver.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2232,6 +2232,7 @@ void DisjunctionChoice::propagateConversionInfo(ConstraintSystem &cs) const {
22322232
case ConstraintKind::Defaultable:
22332233
case ConstraintKind::ConformsTo:
22342234
case ConstraintKind::LiteralConformsTo:
2235+
case ConstraintKind::TransitivelyConformsTo:
22352236
return false;
22362237

22372238
default:

lib/Sema/Constraint.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ Constraint::Constraint(ConstraintKind Kind, Type First, Type Second,
5656
case ConstraintKind::OperatorArgumentConversion:
5757
case ConstraintKind::ConformsTo:
5858
case ConstraintKind::LiteralConformsTo:
59+
case ConstraintKind::TransitivelyConformsTo:
5960
case ConstraintKind::CheckedCast:
6061
case ConstraintKind::SelfObjectOfProtocol:
6162
case ConstraintKind::DynamicTypeOf:
@@ -124,6 +125,7 @@ Constraint::Constraint(ConstraintKind Kind, Type First, Type Second, Type Third,
124125
case ConstraintKind::OperatorArgumentConversion:
125126
case ConstraintKind::ConformsTo:
126127
case ConstraintKind::LiteralConformsTo:
128+
case ConstraintKind::TransitivelyConformsTo:
127129
case ConstraintKind::CheckedCast:
128130
case ConstraintKind::SelfObjectOfProtocol:
129131
case ConstraintKind::DynamicTypeOf:
@@ -241,7 +243,8 @@ Constraint::Constraint(ConstraintKind kind, ConstraintFix *fix, Type first,
241243
ProtocolDecl *Constraint::getProtocol() const {
242244
assert((Kind == ConstraintKind::ConformsTo ||
243245
Kind == ConstraintKind::LiteralConformsTo ||
244-
Kind == ConstraintKind::SelfObjectOfProtocol)
246+
Kind == ConstraintKind::SelfObjectOfProtocol ||
247+
Kind == ConstraintKind::TransitivelyConformsTo)
245248
&& "Not a conformance constraint");
246249
return Types.Second->castTo<ProtocolType>()->getDecl();
247250
}
@@ -259,6 +262,7 @@ Constraint *Constraint::clone(ConstraintSystem &cs) const {
259262
case ConstraintKind::OperatorArgumentConversion:
260263
case ConstraintKind::ConformsTo:
261264
case ConstraintKind::LiteralConformsTo:
265+
case ConstraintKind::TransitivelyConformsTo:
262266
case ConstraintKind::CheckedCast:
263267
case ConstraintKind::DynamicTypeOf:
264268
case ConstraintKind::EscapableFunctionOf:
@@ -357,6 +361,7 @@ void Constraint::print(llvm::raw_ostream &Out, SourceManager *sm) const {
357361
Out << " operator arg conv "; break;
358362
case ConstraintKind::ConformsTo: Out << " conforms to "; break;
359363
case ConstraintKind::LiteralConformsTo: Out << " literal conforms to "; break;
364+
case ConstraintKind::TransitivelyConformsTo: Out << " transitive conformance to "; break;
360365
case ConstraintKind::CheckedCast: Out << " checked cast to "; break;
361366
case ConstraintKind::SelfObjectOfProtocol: Out << " Self type of "; break;
362367
case ConstraintKind::ApplicableFunction: Out << " applicable fn "; break;
@@ -599,6 +604,7 @@ gatherReferencedTypeVars(Constraint *constraint,
599604
case ConstraintKind::Defaultable:
600605
case ConstraintKind::ConformsTo:
601606
case ConstraintKind::LiteralConformsTo:
607+
case ConstraintKind::TransitivelyConformsTo:
602608
case ConstraintKind::SelfObjectOfProtocol:
603609
case ConstraintKind::FunctionInput:
604610
case ConstraintKind::FunctionResult:
@@ -659,6 +665,7 @@ Constraint *Constraint::create(ConstraintSystem &cs, ConstraintKind kind,
659665

660666
// Conformance constraints expect an existential on the right-hand side.
661667
assert((kind != ConstraintKind::ConformsTo &&
668+
kind != ConstraintKind::TransitivelyConformsTo &&
662669
kind != ConstraintKind::SelfObjectOfProtocol) ||
663670
second->isExistentialType());
664671

test/Constraints/protocols.swift

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,11 +441,54 @@ extension UnsafePointer : Trivial {
441441
typealias T = Int
442442
}
443443

444+
extension AnyHashable : Trivial {
445+
typealias T = Int
446+
}
447+
448+
extension UnsafeRawPointer : Trivial {
449+
typealias T = Int
450+
}
451+
452+
extension UnsafeMutableRawPointer : Trivial {
453+
typealias T = Int
454+
}
455+
444456
func test_inference_through_implicit_conversion() {
445-
class C {}
457+
struct C : Hashable {}
446458

447459
func test<T: Trivial>(_: T) -> T {}
448460

461+
var arr: [C] = []
462+
let ptr: UnsafeMutablePointer<C> = UnsafeMutablePointer(bitPattern: 0)!
463+
let rawPtr: UnsafeMutableRawPointer = UnsafeMutableRawPointer(bitPattern: 0)!
464+
449465
let _: C? = test(C()) // Ok -> argument is implicitly promoted into an optional
450466
let _: UnsafePointer<C> = test([C()]) // Ok - argument is implicitly converted to a pointer
467+
let _: UnsafeRawPointer = test([C()]) // Ok - argument is implicitly converted to a raw pointer
468+
let _: UnsafeMutableRawPointer = test(&arr) // Ok - inout Array<T> -> UnsafeMutableRawPointer
469+
let _: UnsafePointer<C> = test(ptr) // Ok - UnsafeMutablePointer<T> -> UnsafePointer<T>
470+
let _: UnsafeRawPointer = test(ptr) // Ok - UnsafeMutablePointer<T> -> UnsafeRawPointer
471+
let _: UnsafeRawPointer = test(rawPtr) // Ok - UnsafeMutableRawPointer -> UnsafeRawPointer
472+
let _: UnsafeMutableRawPointer = test(ptr) // Ok - UnsafeMutablePointer<T> -> UnsafeMutableRawPointer
473+
let _: AnyHashable = test(C()) // Ok - argument is implicitly converted to `AnyHashable` because it's Hashable
474+
}
475+
476+
// Make sure that conformances transitively checked through implicit conversions work with conditional requirements
477+
protocol TestCond {}
478+
479+
extension Optional : TestCond where Wrapped == Int? {}
480+
481+
func simple<T : TestCond>(_ x: T) -> T { x }
482+
483+
func overloaded<T: TestCond>(_ x: T) -> T { x }
484+
func overloaded<T: TestCond>(_ x: String) -> T { fatalError() }
485+
486+
func overloaded_result() -> Int { 42 }
487+
func overloaded_result() -> String { "" }
488+
489+
func test_arg_conformance_with_conditional_reqs(i: Int) {
490+
let _: Int?? = simple(i)
491+
let _: Int?? = overloaded(i)
492+
let _: Int?? = simple(overloaded_result())
493+
let _: Int?? = overloaded(overloaded_result())
451494
}

0 commit comments

Comments
 (0)