Skip to content

Commit 594044a

Browse files
committed
Sema: Clean up handling of protocol operators with concrete operands
In this case we would "devirtualize" the protocol requirement call by building the AST to model a direct reference to the witness. Previously this was done by recursively calling typeCheckExpression(), but the only thing this did was recover the correct substitutions for the call. Instead, we can just build the right SubstitutionMap directly. Unfortunately, while we serialize enough information in the AST to devirtualize calls at the SIL level, we do not for AST Exprs. This is because SIL devirtualization builds a reference to the witness thunk signature, which is an intermediate step between the protocol requirement and the witness. I get around this by deriving the substitutions from walking in parallel over the interface type of the witness, together with the inferred type of the call expression.
1 parent 27dad91 commit 594044a

File tree

2 files changed

+156
-38
lines changed

2 files changed

+156
-38
lines changed

lib/Sema/CSApply.cpp

Lines changed: 96 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "swift/AST/ASTWalker.h"
2727
#include "swift/AST/ExistentialLayout.h"
2828
#include "swift/AST/Initializer.h"
29+
#include "swift/AST/GenericEnvironment.h"
2930
#include "swift/AST/GenericSignature.h"
3031
#include "swift/AST/ParameterList.h"
3132
#include "swift/AST/ProtocolConformance.h"
@@ -378,6 +379,66 @@ namespace {
378379
return base.getOldType();
379380
}
380381

382+
// Returns None if the AST does not contain enough information to recover
383+
// substitutions; this is different from an Optional(SubstitutionMap()),
384+
// indicating a valid call to a non-generic operator.
385+
Optional<SubstitutionMap>
386+
getOperatorSubstitutions(ValueDecl *witness, Type refType) {
387+
// We have to recover substitutions in this hacky way because
388+
// the AST does not retain enough information to devirtualize
389+
// calls like this.
390+
auto witnessType = witness->getInterfaceType();
391+
392+
// Compute the substitutions.
393+
auto *gft = witnessType->getAs<GenericFunctionType>();
394+
if (gft == nullptr) {
395+
if (refType->isEqual(witnessType))
396+
return SubstitutionMap();
397+
return None;
398+
}
399+
400+
auto sig = gft->getGenericSignature();
401+
auto *env = sig->getGenericEnvironment();
402+
403+
witnessType = FunctionType::get(gft->getParams(),
404+
gft->getResult(),
405+
gft->getExtInfo());
406+
witnessType = env->mapTypeIntoContext(witnessType);
407+
408+
TypeSubstitutionMap subs;
409+
auto substType = witnessType->substituteBindingsTo(
410+
refType,
411+
[&](ArchetypeType *origType, CanType substType) -> CanType {
412+
if (auto gpType = dyn_cast<GenericTypeParamType>(
413+
origType->getInterfaceType()->getCanonicalType()))
414+
subs[gpType] = substType;
415+
416+
return substType;
417+
});
418+
419+
// If substitution failed, it means that the protocol requirement type
420+
// and the witness type did not match up. The only time that this
421+
// should happen is when the witness is defined in a base class and
422+
// the actual call uses a derived class. For example,
423+
//
424+
// protocol P { func +(lhs: Self, rhs: Self) }
425+
// class Base : P { func +(lhs: Base, rhs: Base) {} }
426+
// class Derived : Base {}
427+
//
428+
// If we enter this code path with two operands of type Derived,
429+
// we know we're calling the protocol requirement P.+, with a
430+
// substituted type of (Derived, Derived) -> (). But the type of
431+
// the witness is (Base, Base) -> (). Just bail out and make a
432+
// witness method call in this rare case; SIL mandatory optimizations
433+
// will likely devirtualize it anyway.
434+
if (!substType)
435+
return None;
436+
437+
return SubstitutionMap::get(sig,
438+
QueryTypeSubstitutionMap{subs},
439+
TypeChecker::LookUpConformance(cs.DC));
440+
}
441+
381442
public:
382443
/// Build a reference to the given declaration.
383444
Expr *buildDeclRef(SelectedOverload overload, DeclNameLoc loc,
@@ -400,56 +461,53 @@ namespace {
400461

401462
// Handle operator requirements found in protocols.
402463
if (auto proto = dyn_cast<ProtocolDecl>(decl->getDeclContext())) {
403-
// If we don't have an archetype or existential, we have to call the
404-
// witness.
464+
// If we have a concrete conformance, build a call to the witness.
465+
//
405466
// FIXME: This is awful. We should be able to handle this as a call to
406467
// the protocol requirement with Self == the concrete type, and SILGen
407468
// (or later) can devirtualize as appropriate.
408-
if (!baseTy->is<ArchetypeType>() && !baseTy->isAnyExistentialType()) {
409-
auto conformance =
410-
TypeChecker::conformsToProtocol(
411-
baseTy, proto, cs.DC,
412-
ConformanceCheckFlags::InExpression);
413-
if (conformance.isConcrete()) {
414-
if (auto witness =
415-
conformance.getConcrete()->getWitnessDecl(decl)) {
416-
// Hack up an AST that we can type-check (independently) to get
417-
// it into the right form.
418-
// FIXME: the hop through 'getDecl()' is because
419-
// SpecializedProtocolConformance doesn't substitute into
420-
// witnesses' ConcreteDeclRefs.
421-
Type expectedFnType = simplifyType(overload.openedType);
422-
assert(expectedFnType->isEqual(
423-
fullType->castTo<AnyFunctionType>()->getResult()) &&
424-
"Cannot handle adjustments made to the opened type");
469+
auto conformance =
470+
TypeChecker::conformsToProtocol(
471+
baseTy, proto, cs.DC,
472+
ConformanceCheckFlags::InExpression);
473+
if (conformance.isConcrete()) {
474+
if (auto witness = conformance.getConcrete()->getWitnessDecl(decl)) {
475+
// The fullType was computed by substituting the protocol
476+
// requirement so it always has a (Self) -> ... curried
477+
// application. Strip it off if the witness was a top-level
478+
// function.
479+
Type refType;
480+
if (witness->getDeclContext()->isTypeContext())
481+
refType = fullType;
482+
else
483+
refType = fullType->castTo<AnyFunctionType>()->getResult();
484+
485+
// Build the AST for the call to the witness.
486+
auto subMap = getOperatorSubstitutions(witness, refType);
487+
if (subMap) {
488+
ConcreteDeclRef witnessRef(witness, *subMap);
489+
auto declRefExpr = new (ctx) DeclRefExpr(witnessRef, loc,
490+
/*Implicit=*/false);
491+
declRefExpr->setFunctionRefKind(choice.getFunctionRefKind());
492+
cs.setType(declRefExpr, refType);
493+
425494
Expr *refExpr;
426495
if (witness->getDeclContext()->isTypeContext()) {
496+
// If the operator is a type member, add the implicit
497+
// (Self) -> ... call.
427498
Expr *base =
428499
TypeExpr::createImplicitHack(loc.getBaseNameLoc(), baseTy,
429500
ctx);
430-
refExpr = new (ctx) MemberRefExpr(base, SourceLoc(), witness,
431-
loc, /*Implicit=*/true);
501+
cs.setType(base, MetatypeType::get(baseTy));
502+
503+
refExpr = new (ctx) DotSyntaxCallExpr(declRefExpr,
504+
SourceLoc(), base);
505+
auto refType = fullType->castTo<FunctionType>()->getResult();
506+
cs.setType(refExpr, refType);
432507
} else {
433-
auto declRefExpr = new (ctx) DeclRefExpr(witness, loc,
434-
/*Implicit=*/false);
435-
declRefExpr->setFunctionRefKind(choice.getFunctionRefKind());
436508
refExpr = declRefExpr;
437509
}
438510

439-
auto resultTy = TypeChecker::typeCheckExpression(
440-
refExpr, cs.DC, TypeLoc::withoutLoc(expectedFnType),
441-
CTP_CannotFail);
442-
if (!resultTy)
443-
return nullptr;
444-
445-
cs.cacheExprTypes(refExpr);
446-
447-
// Remove an outer function-conversion expression. This
448-
// happens when we end up referring to a witness for a
449-
// superclass conformance, and 'Self' differs.
450-
if (auto fnConv = dyn_cast<FunctionConversionExpr>(refExpr))
451-
refExpr = fnConv->getSubExpr();
452-
453511
return forceUnwrapIfExpected(refExpr, choice, locator);
454512
}
455513
}

test/SILGen/protocol_operators.swift

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s
2+
3+
infix operator +++
4+
5+
protocol Twig {
6+
static func +++(lhs: Self, rhs: Self)
7+
}
8+
9+
struct Branch : Twig {
10+
@_implements(Twig, +++(_:_:))
11+
static func doIt(_: Branch, _: Branch) {}
12+
}
13+
14+
// CHECK-LABEL: sil hidden [ossa] @$s18protocol_operators9useBranchyyAA0D0VF : $@convention(thin) (Branch) -> () {
15+
// CHECK: function_ref @$s18protocol_operators6BranchV4doItyyAC_ACtFZ : $@convention(method) (Branch, Branch, @thin Branch.Type) -> ()
16+
// CHECK: return
17+
func useBranch(_ b: Branch) {
18+
b +++ b
19+
}
20+
21+
class Stick : Twig {
22+
static func +++(lhs: Stick, rhs: Stick) {}
23+
}
24+
25+
class Stuck : Stick, ExpressibleByIntegerLiteral {
26+
typealias IntegerLiteralType = Int
27+
28+
required init(integerLiteral: Int) {}
29+
}
30+
31+
// CHECK-LABEL: sil hidden [ossa] @$s18protocol_operators8useStickyyAA5StuckC_AA0D0CtF : $@convention(thin) (@guaranteed Stuck, @guaranteed Stick) -> () {
32+
// CHECK: function_ref @$s18protocol_operators5StickC3pppoiyyAC_ACtFZ : $@convention(method) (@guaranteed Stick, @guaranteed Stick, @thick Stick.Type) -> ()
33+
// CHECK: function_ref @$s18protocol_operators5StickC3pppoiyyAC_ACtFZ : $@convention(method) (@guaranteed Stick, @guaranteed Stick, @thick Stick.Type) -> ()
34+
// CHECK: witness_method $Stuck, #Twig."+++"!1 : <Self where Self : Twig> (Self.Type) -> (Self, Self) -> () : $@convention(witness_method: Twig) <τ_0_0 where τ_0_0 : Twig> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
35+
// CHECK: return
36+
func useStick(_ a: Stuck, _ b: Stick) {
37+
_ = a +++ b
38+
_ = b +++ b
39+
_ = a +++ 5
40+
}
41+
42+
class Twine<X> : Twig {
43+
static func +++(lhs: Twine, rhs: Twine) {}
44+
}
45+
46+
class Rope : Twine<Int>, ExpressibleByIntegerLiteral {
47+
typealias IntegerLiteralType = Int
48+
49+
required init(integerLiteral: Int) {}
50+
}
51+
52+
// CHECK-LABEL: sil hidden [ossa] @$s18protocol_operators7useRopeyyAA0D0C_ADtF : $@convention(thin) (@guaranteed Rope, @guaranteed Rope) -> () {
53+
// CHECK: function_ref @$s18protocol_operators5TwineC3pppoiyyACyxG_AEtFZ : $@convention(method) <τ_0_0> (@guaranteed Twine<τ_0_0>, @guaranteed Twine<τ_0_0>, @thick Twine<τ_0_0>.Type) -> ()
54+
// CHECK: function_ref @$s18protocol_operators5TwineC3pppoiyyACyxG_AEtFZ : $@convention(method) <τ_0_0> (@guaranteed Twine<τ_0_0>, @guaranteed Twine<τ_0_0>, @thick Twine<τ_0_0>.Type) -> ()
55+
// CHECK: witness_method $Rope, #Twig."+++"!1 : <Self where Self : Twig> (Self.Type) -> (Self, Self) -> () : $@convention(witness_method: Twig) <τ_0_0 where τ_0_0 : Twig> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
56+
func useRope(_ r: Rope, _ s: Rope) {
57+
_ = r +++ s
58+
_ = s +++ s
59+
_ = r +++ 5
60+
}

0 commit comments

Comments
 (0)