Skip to content

Commit 78996db

Browse files
authored
Merge pull request #64476 from rjmccall/arity-reabstraction
Implement parameter arity reabstraction
2 parents 048bcc1 + a05fef5 commit 78996db

12 files changed

+2093
-815
lines changed

include/swift/SIL/AbstractionPattern.h

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ namespace clang {
3535

3636
namespace swift {
3737
namespace Lowering {
38+
class FunctionParamGenerator;
3839

3940
/// A pattern for the abstraction of a value.
4041
///
@@ -1501,29 +1502,20 @@ class AbstractionPattern {
15011502
/// parameters in the pattern.
15021503
unsigned getNumFunctionParams() const;
15031504

1504-
/// Perform a parallel visitation of the parameters of a function.
1505-
///
1506-
/// If this is a function pattern, calls handleScalar or
1507-
/// handleExpansion as appropriate for each parameter of the
1508-
/// original function, in order.
1505+
/// Traverses the parameters of a function, where this is the
1506+
/// abstraction pattern for the function (its "original type")
1507+
/// and the given parameters are the substituted formal parameters.
1508+
/// Calls the callback once for each parameter in the abstraction
1509+
/// pattern.
15091510
///
15101511
/// If this is not a function pattern, calls handleScalar for each
1511-
/// parameter of the substituted function type. Functions with
1512-
/// pack expansions cannot be abstracted legally this way.
1512+
/// parameter of the substituted function type. Note that functions
1513+
/// with pack expansions cannot be legally abstracted this way; it
1514+
/// is not possible in Swift's ABI to support this without some sort
1515+
/// of dynamic argument-forwarding thunk.
15131516
void forEachFunctionParam(AnyFunctionType::CanParamArrayRef substParams,
15141517
bool ignoreFinalParam,
1515-
llvm::function_ref<void(unsigned origParamIndex,
1516-
unsigned substParamIndex,
1517-
ParameterTypeFlags origFlags,
1518-
AbstractionPattern origParamType,
1519-
AnyFunctionType::CanParam substParam)>
1520-
handleScalar,
1521-
llvm::function_ref<void(unsigned origParamIndex,
1522-
unsigned substParamIndex,
1523-
ParameterTypeFlags origFlags,
1524-
AbstractionPattern origExpansionType,
1525-
AnyFunctionType::CanParamArrayRef substParams)>
1526-
handleExpansion) const;
1518+
llvm::function_ref<void(FunctionParamGenerator &param)> function) const;
15271519

15281520
/// Given that the value being abstracted is optional, return the
15291521
/// abstraction pattern for its object type.
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
//===--- AbstractionPatternGenerators.h -------------------------*- C++ -*-===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
//
13+
// This file defines "generators" that can be used with an AbstractionPattern
14+
// to do certain kinds of traversal without using callbacks.
15+
// This can be useful when a traversal is required in parallel with
16+
// some other traversal.
17+
//
18+
//===----------------------------------------------------------------------===//
19+
20+
#ifndef SWIFT_SIL_ABSTRACTIONPATTERNGENERATORS_H
21+
#define SWIFT_SIL_ABSTRACTIONPATTERNGENERATORS_H
22+
23+
#include "swift/SIL/AbstractionPattern.h"
24+
25+
namespace swift {
26+
namespace Lowering {
27+
28+
/// A generator for traversing the formal function parameters of a type
29+
/// while properly respecting variadic generics.
30+
class FunctionParamGenerator {
31+
// The steady state of the generator.
32+
33+
/// The abstraction pattern of the entire function type. Set once
34+
/// during construction.
35+
AbstractionPattern origFunctionType;
36+
37+
/// The list of all substituted parameters to traverse. Set once
38+
/// during construction.
39+
AnyFunctionType::CanParamArrayRef allSubstParams;
40+
41+
/// The number of orig parameters to traverse. Set once during
42+
/// construction.
43+
unsigned numOrigParams;
44+
45+
/// The index of the current orig parameter.
46+
/// Incremented during advance().
47+
unsigned origParamIndex = 0;
48+
49+
/// The (start) index of the current subst parameters.
50+
/// Incremented during advance().
51+
unsigned substParamIndex = 0;
52+
53+
/// The number of subst parameters corresponding to the current
54+
/// subst parameter.
55+
unsigned numSubstParamsForOrigParam;
56+
57+
/// Whether the orig function type is opaque, i.e. does not permit us to
58+
/// call getNumFunctionParams() and similar accessors. Set once during
59+
/// construction.
60+
bool origFunctionTypeIsOpaque;
61+
62+
/// Whether the current orig parameter is a pack expansion.
63+
bool origParamIsExpansion;
64+
65+
/// The abstraction pattern of the current orig parameter.
66+
/// If it is a pack expansion, this is the expansion type, not the
67+
/// pattern type.
68+
AbstractionPattern origParamType = AbstractionPattern::getInvalid();
69+
70+
/// Load the informaton for the current orig parameter into the
71+
/// fields above for it.
72+
void loadParameter() {
73+
origParamType = origFunctionType.getFunctionParamType(origParamIndex);
74+
origParamIsExpansion = origParamType.isPackExpansion();
75+
numSubstParamsForOrigParam =
76+
(origParamIsExpansion
77+
? origParamType.getNumPackExpandedComponents()
78+
: 1);
79+
}
80+
81+
public:
82+
FunctionParamGenerator(AbstractionPattern origFunctionType,
83+
AnyFunctionType::CanParamArrayRef substParams,
84+
bool ignoreFinalOrigParam);
85+
86+
/// Is the traversal finished? If so, none of the getters below
87+
/// are allowed to be called.
88+
bool isFinished() const {
89+
return origParamIndex == numOrigParams;
90+
}
91+
92+
/// Advance to the next orig parameter.
93+
void advance() {
94+
assert(!isFinished());
95+
origParamIndex++;
96+
substParamIndex += numSubstParamsForOrigParam;
97+
if (!isFinished()) loadParameter();
98+
}
99+
100+
/// Return the index of the current orig parameter.
101+
unsigned getOrigIndex() const {
102+
assert(!isFinished());
103+
return origParamIndex;
104+
}
105+
106+
/// Return the index of the (first) subst parameter corresponding
107+
/// to the current orig parameter.
108+
unsigned getSubstIndex() const {
109+
assert(!isFinished());
110+
return origParamIndex;
111+
}
112+
113+
/// Return the parameter flags for the current orig parameter.
114+
ParameterTypeFlags getOrigFlags() const {
115+
assert(!isFinished());
116+
return (origFunctionTypeIsOpaque
117+
? allSubstParams[substParamIndex].getParameterFlags()
118+
: origFunctionType.getFunctionParamFlags(origParamIndex));
119+
}
120+
121+
/// Return the type of the current orig parameter.
122+
const AbstractionPattern &getOrigType() const {
123+
assert(!isFinished());
124+
return origParamType;
125+
}
126+
127+
/// Return whether the current orig parameter type is a pack expansion.
128+
bool isPackExpansion() const {
129+
assert(!isFinished());
130+
return origParamIsExpansion;
131+
}
132+
133+
/// Return the substituted parameters corresponding to the current
134+
/// orig parameter type. If the current orig parameter is not a
135+
/// pack expansion, this will have exactly one element.
136+
AnyFunctionType::CanParamArrayRef getSubstParams() const {
137+
assert(!isFinished());
138+
return allSubstParams.slice(substParamIndex, numSubstParamsForOrigParam);
139+
}
140+
141+
/// Call this to finalize the traversal and assert that it was done
142+
/// properly.
143+
void finish() {
144+
assert(isFinished() && "didn't finish the traversal");
145+
assert(substParamIndex == allSubstParams.size() &&
146+
"didn't exhaust subst parameters; possible missing subs on "
147+
"orig function type");
148+
}
149+
};
150+
151+
} // end namespace Lowering
152+
} // end namespace swift
153+
154+
#endif

include/swift/SIL/SILType.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,13 @@ class SILType {
537537
return SILType(castTo<TupleType>().getElementType(index), getCategory());
538538
}
539539

540+
/// Given that this is a pack type, return the lowered type of the
541+
/// given pack element. The result will have the same value
542+
/// category as the base type.
543+
SILType getPackElementType(unsigned index) const {
544+
return SILType(castTo<SILPackType>()->getElementType(index), getCategory());
545+
}
546+
540547
/// Given that this is a pack expansion type, return the lowered type
541548
/// of the pattern type. The result will have the same value category
542549
/// as the base type.

lib/SIL/IR/AbstractionPattern.cpp

Lines changed: 40 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "swift/AST/TypeCheckRequests.h"
2727
#include "swift/AST/CanTypeVisitor.h"
2828
#include "swift/SIL/TypeLowering.h"
29+
#include "swift/SIL/AbstractionPatternGenerators.h"
2930
#include "clang/AST/ASTContext.h"
3031
#include "clang/AST/Attr.h"
3132
#include "clang/AST/DeclCXX.h"
@@ -1202,55 +1203,33 @@ unsigned AbstractionPattern::getNumFunctionParams() const {
12021203

12031204
void AbstractionPattern::
12041205
forEachFunctionParam(AnyFunctionType::CanParamArrayRef substParams,
1205-
bool ignoreFinalParam,
1206-
llvm::function_ref<void(unsigned origParamIndex,
1207-
unsigned substParamIndex,
1208-
ParameterTypeFlags origFlags,
1209-
AbstractionPattern origParamType,
1210-
AnyFunctionType::CanParam substParam)>
1211-
handleScalar,
1212-
llvm::function_ref<void(unsigned origParamIndex,
1213-
unsigned substParamIndex,
1214-
ParameterTypeFlags origFlags,
1215-
AbstractionPattern origExpansionType,
1216-
AnyFunctionType::CanParamArrayRef substParams)>
1217-
handleExpansion) const {
1218-
// Honor ignoreFinalParam for the substituted parameters on all paths.
1219-
if (ignoreFinalParam) substParams = substParams.drop_back();
1220-
1221-
// If we don't have a function type, use the substituted type.
1222-
if (isTypeParameterOrOpaqueArchetype() ||
1223-
getKind() == Kind::OpaqueFunction ||
1224-
getKind() == Kind::OpaqueDerivativeFunction) {
1225-
for (auto substParamIndex : indices(substParams)) {
1226-
handleScalar(substParamIndex, substParamIndex,
1227-
substParams[substParamIndex].getParameterFlags(),
1228-
AbstractionPattern::getOpaque(),
1229-
substParams[substParamIndex]);
1230-
}
1231-
return;
1206+
bool ignoreFinalOrigParam,
1207+
llvm::function_ref<void(FunctionParamGenerator &param)> function) const {
1208+
FunctionParamGenerator generator(*this, substParams, ignoreFinalOrigParam);
1209+
for (; !generator.isFinished(); generator.advance()) {
1210+
function(generator);
1211+
}
1212+
generator.finish();
1213+
}
1214+
1215+
FunctionParamGenerator::FunctionParamGenerator(
1216+
AbstractionPattern origFunctionType,
1217+
AnyFunctionType::CanParamArrayRef substParams,
1218+
bool ignoreFinalOrigParam)
1219+
: origFunctionType(origFunctionType), allSubstParams(substParams) {
1220+
origFunctionTypeIsOpaque =
1221+
(origFunctionType.isTypeParameterOrOpaqueArchetype() ||
1222+
origFunctionType.isOpaqueFunctionOrOpaqueDerivativeFunction());
1223+
1224+
if (origFunctionTypeIsOpaque) {
1225+
numOrigParams = allSubstParams.size();
1226+
} else {
1227+
numOrigParams = origFunctionType.getNumFunctionParams();
1228+
if (ignoreFinalOrigParam)
1229+
numOrigParams--;
12321230
}
12331231

1234-
size_t numOrigParams = getNumFunctionParams();
1235-
if (ignoreFinalParam) numOrigParams--;
1236-
1237-
size_t substParamIndex = 0;
1238-
for (auto origParamIndex : range(numOrigParams)) {
1239-
auto origParamType = getFunctionParamType(origParamIndex);
1240-
if (origParamType.isPackExpansion()) {
1241-
unsigned numComponents = origParamType.getNumPackExpandedComponents();
1242-
handleExpansion(origParamIndex, substParamIndex,
1243-
getFunctionParamFlags(origParamIndex), origParamType,
1244-
substParams.slice(substParamIndex, numComponents));
1245-
substParamIndex += numComponents;
1246-
} else {
1247-
handleScalar(origParamIndex, substParamIndex,
1248-
getFunctionParamFlags(origParamIndex), origParamType,
1249-
substParams[substParamIndex]);
1250-
substParamIndex++;
1251-
}
1252-
}
1253-
assert(substParamIndex == substParams.size());
1232+
if (!isFinished()) loadParameter();
12541233
}
12551234

12561235
static CanType getOptionalObjectType(CanType type) {
@@ -2239,21 +2218,20 @@ class SubstFunctionTypePatternVisitor
22392218
};
22402219

22412220
pattern.forEachFunctionParam(func.getParams(), /*ignore self*/ false,
2242-
[&](unsigned origParamIndex, unsigned substParamIndex,
2243-
ParameterTypeFlags origFlags, AbstractionPattern origParamType,
2244-
AnyFunctionType::CanParam substParam) {
2245-
auto newParamTy = visit(substParam.getParameterType(), origParamType);
2246-
addParam(origFlags, newParamTy);
2247-
}, [&](unsigned origParamIndex, unsigned substParamIndex,
2248-
ParameterTypeFlags origFlags,
2249-
AbstractionPattern origExpansionType,
2250-
AnyFunctionType::CanParamArrayRef substParams) {
2251-
CanType candidateSubstType;
2252-
if (!substParams.empty())
2253-
candidateSubstType = substParams[0].getParameterType();
2254-
auto expansionType =
2255-
handlePackExpansion(origExpansionType, candidateSubstType);
2256-
addParam(origFlags, expansionType);
2221+
[&](FunctionParamGenerator &param) {
2222+
if (!param.isPackExpansion()) {
2223+
auto newParamTy = visit(param.getSubstParams()[0].getParameterType(),
2224+
param.getOrigType());
2225+
addParam(param.getOrigFlags(), newParamTy);
2226+
} else {
2227+
auto substParams = param.getSubstParams();
2228+
CanType candidateSubstType;
2229+
if (!substParams.empty())
2230+
candidateSubstType = substParams[0].getParameterType();
2231+
auto expansionType =
2232+
handlePackExpansion(param.getOrigType(), candidateSubstType);
2233+
addParam(param.getOrigFlags(), expansionType);
2234+
}
22572235
});
22582236

22592237
if (yieldType) {

lib/SIL/IR/SILFunctionType.cpp

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "swift/ClangImporter/ClangImporter.h"
3333
#include "swift/SIL/SILModule.h"
3434
#include "swift/SIL/SILType.h"
35+
#include "swift/SIL/AbstractionPatternGenerators.h"
3536
#include "clang/AST/ASTContext.h"
3637
#include "clang/AST/Attr.h"
3738
#include "clang/AST/DeclCXX.h"
@@ -1569,22 +1570,22 @@ class DestructureInputs {
15691570
maybeAddForeignParameters();
15701571

15711572
// Process all the non-self parameters.
1572-
origType.forEachFunctionParam(params, hasSelf,
1573-
[&](unsigned origParamIndex, unsigned substParamIndex,
1574-
ParameterTypeFlags origFlags,
1575-
AbstractionPattern origParamType,
1576-
AnyFunctionType::CanParam substParam) {
1573+
origType.forEachFunctionParam(params.drop_back(hasSelf ? 1 : 0),
1574+
/*ignore final orig param*/ hasSelf,
1575+
[&](FunctionParamGenerator &param) {
15771576
// If the parameter is not a pack expansion, just pull off the
15781577
// next parameter and destructure it in parallel with the abstraction
15791578
// pattern for the type.
1580-
visit(origParamType, substParam, /*forSelf*/false);
1581-
}, [&](unsigned origParamIndex, unsigned substParamIndex,
1582-
ParameterTypeFlags origFlags,
1583-
AbstractionPattern origExpansionType,
1584-
AnyFunctionType::CanParamArrayRef substParams) {
1579+
if (!param.isPackExpansion()) {
1580+
visit(param.getOrigType(), param.getSubstParams()[0],
1581+
/*forSelf*/false);
1582+
return;
1583+
}
1584+
15851585
// Otherwise, collect the substituted components into a pack.
1586+
auto origExpansionType = param.getOrigType();
15861587
SmallVector<CanType, 8> packElts;
1587-
for (auto substParam : substParams) {
1588+
for (auto substParam : param.getSubstParams()) {
15881589
auto substParamType = substParam.getParameterType();
15891590
auto origParamType =
15901591
origExpansionType.getPackExpansionComponentType(substParamType);
@@ -1597,6 +1598,7 @@ class DestructureInputs {
15971598
SILPackType::ExtInfo extInfo(/*address*/ indirect);
15981599
auto packTy = SILPackType::get(TC.Context, extInfo, packElts);
15991600

1601+
auto origFlags = param.getOrigFlags();
16001602
addPackParameter(packTy, origFlags.getValueOwnership(),
16011603
origFlags.isNoDerivative());
16021604
});

0 commit comments

Comments
 (0)