Skip to content

Commit 157be34

Browse files
committed
Implement the callee side of returning a tuple containing a pack expansion.
This required quite a bit of infrastructure for emitting this kind of tuple expression, although I'm not going to claim they really work yet; in particular, I know the RValue constructor is going to try to explode them, which it really shouldn't. It also doesn't include the caller side of returns, for which I'll need to teach ResultPlan to do the new abstraction-pattern walk. But that's next.
1 parent 2276a8d commit 157be34

19 files changed

+979
-74
lines changed

include/swift/AST/Types.h

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2451,6 +2451,20 @@ BEGIN_CAN_TYPE_WRAPPER(TupleType, Type)
24512451
CanTupleEltTypeArrayRef getElementTypes() const {
24522452
return CanTupleEltTypeArrayRef(getPointer()->getElements());
24532453
}
2454+
2455+
bool containsPackExpansionType() const {
2456+
return containsPackExpansionTypeImpl(*this);
2457+
}
2458+
2459+
/// Induce a pack type from a range of the elements of this tuple type.
2460+
inline CanTypeWrapper<PackType>
2461+
getInducedPackType(unsigned start, unsigned count) const;
2462+
2463+
private:
2464+
static bool containsPackExpansionTypeImpl(CanTupleType tuple);
2465+
2466+
static CanTypeWrapper<PackType>
2467+
getInducedPackTypeImpl(CanTupleType tuple, unsigned start, unsigned count);
24542468
END_CAN_TYPE_WRAPPER(TupleType, Type)
24552469

24562470
/// UnboundGenericType - Represents a generic type where the type arguments have
@@ -4233,6 +4247,11 @@ class SILResultInfo {
42334247
return !isIndirectFormalResult(getConvention());
42344248
}
42354249

4250+
/// Is this a pack result? Pack results are always indirect.
4251+
bool isPack() const {
4252+
return getConvention() == ResultConvention::Pack;
4253+
}
4254+
42364255
/// Transform this SILResultInfo by applying the user-provided
42374256
/// function to its type.
42384257
///
@@ -4408,8 +4427,9 @@ class SILFunctionType final
44084427

44094428
// These are *normal* results if this is not a coroutine and *yield* results
44104429
// otherwise.
4411-
unsigned NumAnyResults : 16; // Not including the ErrorResult.
4412-
unsigned NumAnyIndirectFormalResults : 16; // Subset of NumAnyResults.
4430+
unsigned NumAnyResults; // Not including the ErrorResult.
4431+
unsigned NumAnyIndirectFormalResults; // Subset of NumAnyResults.
4432+
unsigned NumPackResults; // Subset of NumAnyIndirectFormalResults.
44134433

44144434
// [NOTE: SILFunctionType-layout]
44154435
// The layout of a SILFunctionType in memory is:
@@ -4589,6 +4609,9 @@ class SILFunctionType final
45894609
unsigned getNumDirectFormalResults() const {
45904610
return isCoroutine() ? 0 : NumAnyResults - NumAnyIndirectFormalResults;
45914611
}
4612+
unsigned getNumPackResults() const {
4613+
return isCoroutine() ? 0 : NumPackResults;
4614+
}
45924615

45934616
struct IndirectFormalResultFilter {
45944617
bool operator()(SILResultInfo result) const {
@@ -4618,6 +4641,21 @@ class SILFunctionType final
46184641
return llvm::make_filter_range(getResults(), DirectFormalResultFilter());
46194642
}
46204643

4644+
struct PackResultFilter {
4645+
bool operator()(SILResultInfo result) const {
4646+
return result.isPack();
4647+
}
4648+
};
4649+
using PackResultIter =
4650+
llvm::filter_iterator<const SILResultInfo *, PackResultFilter>;
4651+
using PackResultRange = iterator_range<PackResultIter>;
4652+
4653+
/// A range of SILResultInfo for all pack results. Pack results are also
4654+
/// included in the set of indirect results.
4655+
PackResultRange getPackResults() const {
4656+
return llvm::make_filter_range(getResults(), PackResultFilter());
4657+
}
4658+
46214659
/// Get a single non-address SILType that represents all formal direct
46224660
/// results. The actual SIL result type of an apply instruction that calls
46234661
/// this function depends on the current SIL stage and is known by
@@ -4636,6 +4674,9 @@ class SILFunctionType final
46364674
unsigned getNumDirectFormalYields() const {
46374675
return isCoroutine() ? NumAnyResults - NumAnyIndirectFormalResults : 0;
46384676
}
4677+
unsigned getNumPackYields() const {
4678+
return isCoroutine() ? NumPackResults : 0;
4679+
}
46394680

46404681
struct IndirectFormalYieldFilter {
46414682
bool operator()(SILYieldInfo yield) const {
@@ -6798,6 +6839,11 @@ BEGIN_CAN_TYPE_WRAPPER(PackType, Type)
67986839
}
67996840
END_CAN_TYPE_WRAPPER(PackType, Type)
68006841

6842+
inline CanPackType
6843+
CanTupleType::getInducedPackType(unsigned start, unsigned end) const {
6844+
return getInducedPackTypeImpl(*this, start, end);
6845+
}
6846+
68016847
/// PackExpansionType - The interface type of the explicit expansion of a
68026848
/// corresponding set of variadic generic parameters.
68036849
///

include/swift/SIL/SILFunction.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,6 +725,10 @@ class SILFunction
725725

726726
SILType getLoweredType(Type t) const;
727727

728+
CanType getLoweredRValueType(Lowering::AbstractionPattern orig, Type subst) const;
729+
730+
CanType getLoweredRValueType(Type t) const;
731+
728732
SILType getLoweredLoadableType(Type t) const;
729733

730734
SILType getLoweredType(SILType t) const;

include/swift/SIL/SILFunctionConventions.h

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -214,30 +214,38 @@ class SILFunctionConventions {
214214

215215
/// Get the number of SIL results passed as address-typed arguments.
216216
unsigned getNumIndirectSILResults() const {
217-
return silConv.loweredAddresses ? funcTy->getNumIndirectFormalResults() : 0;
217+
// TODO: Return packs directly in lowered-address mode
218+
return silConv.loweredAddresses ? funcTy->getNumIndirectFormalResults()
219+
: funcTy->getNumPackResults();
218220
}
219221

220222
/// Are any SIL results passed as address-typed arguments?
221223
bool hasIndirectSILResults() const { return getNumIndirectSILResults() != 0; }
222224

223-
using IndirectSILResultIter = SILFunctionType::IndirectFormalResultIter;
224-
using IndirectSILResultRange = SILFunctionType::IndirectFormalResultRange;
225+
struct IndirectSILResultFilter {
226+
bool loweredAddresses;
227+
IndirectSILResultFilter(bool loweredAddresses)
228+
: loweredAddresses(loweredAddresses) {}
229+
bool operator()(SILResultInfo result) const {
230+
return (loweredAddresses ? result.isFormalIndirect() : result.isPack());
231+
}
232+
};
233+
using IndirectSILResultIter =
234+
llvm::filter_iterator<const SILResultInfo *, IndirectSILResultFilter>;
235+
using IndirectSILResultRange = iterator_range<IndirectSILResultIter>;
225236

226237
/// Return a range of indirect result information for results passed as
227238
/// address-typed SIL arguments.
228239
IndirectSILResultRange getIndirectSILResults() const {
229-
if (silConv.loweredAddresses)
230-
return funcTy->getIndirectFormalResults();
231-
232240
return llvm::make_filter_range(
233-
llvm::make_range((const SILResultInfo *)0, (const SILResultInfo *)0),
234-
SILFunctionType::IndirectFormalResultFilter());
241+
funcTy->getResults(),
242+
IndirectSILResultFilter(silConv.loweredAddresses));
235243
}
236244

237245
struct SILResultTypeFunc;
238246

239247
// Gratuitous template parameter is to delay instantiating `mapped_iterator`
240-
// on the incomplete type SILParameterTypeFunc.
248+
// on the incomplete type SILResultTypeFunc.
241249
template<bool _ = false>
242250
using IndirectSILResultTypeIter = typename delay_template_expansion<_,
243251
llvm::mapped_iterator, IndirectSILResultIter, SILResultTypeFunc>::type;
@@ -253,7 +261,7 @@ class SILFunctionConventions {
253261
/// Get the number of SIL results directly returned by SIL value.
254262
unsigned getNumDirectSILResults() const {
255263
return silConv.loweredAddresses ? funcTy->getNumDirectFormalResults()
256-
: funcTy->getNumResults();
264+
: funcTy->getNumResults() - funcTy->getNumPackResults();
257265
}
258266

259267
/// Like getNumDirectSILResults but @out tuples, which are not flattened in
@@ -266,7 +274,7 @@ class SILFunctionConventions {
266274
DirectSILResultFilter(bool loweredAddresses)
267275
: loweredAddresses(loweredAddresses) {}
268276
bool operator()(SILResultInfo result) const {
269-
return !(loweredAddresses && result.isFormalIndirect());
277+
return (loweredAddresses ? !result.isFormalIndirect() : !result.isPack());
270278
}
271279
};
272280
using DirectSILResultIter =

lib/AST/ASTContext.cpp

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4342,20 +4342,27 @@ SILFunctionType::SILFunctionType(
43424342
if (coroutineKind == SILCoroutineKind::None) {
43434343
assert(yields.empty());
43444344
NumAnyResults = normalResults.size();
4345-
NumAnyIndirectFormalResults =
4346-
std::count_if(normalResults.begin(), normalResults.end(),
4347-
[](const SILResultInfo &resultInfo) {
4348-
return resultInfo.isFormalIndirect();
4349-
});
4345+
NumAnyIndirectFormalResults = 0;
4346+
NumPackResults = 0;
4347+
for (auto &resultInfo : normalResults) {
4348+
if (resultInfo.isFormalIndirect())
4349+
NumAnyIndirectFormalResults++;
4350+
if (resultInfo.isPack())
4351+
NumPackResults++;
4352+
}
43504353
memcpy(getMutableResults().data(), normalResults.data(),
43514354
normalResults.size() * sizeof(SILResultInfo));
43524355
} else {
4353-
assert(normalResults.empty());
4356+
assert(normalResults.empty());
43544357
NumAnyResults = yields.size();
4355-
NumAnyIndirectFormalResults = std::count_if(
4356-
yields.begin(), yields.end(), [](const SILYieldInfo &yieldInfo) {
4357-
return yieldInfo.isFormalIndirect();
4358-
});
4358+
NumAnyIndirectFormalResults = 0;
4359+
NumPackResults = 0;
4360+
for (auto &yieldInfo : yields) {
4361+
if (yieldInfo.isFormalIndirect())
4362+
NumAnyIndirectFormalResults++;
4363+
if (yieldInfo.isPack())
4364+
NumPackResults++;
4365+
}
43594366
memcpy(getMutableYields().data(), yields.data(),
43604367
yields.size() * sizeof(SILYieldInfo));
43614368
}

lib/AST/ParameterPack.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,15 @@ bool TupleType::containsPackExpansionType() const {
165165
return false;
166166
}
167167

168+
bool CanTupleType::containsPackExpansionTypeImpl(CanTupleType tuple) {
169+
for (auto eltType : tuple.getElementTypes()) {
170+
if (isa<PackExpansionType>(eltType))
171+
return true;
172+
}
173+
174+
return false;
175+
}
176+
168177
/// (W, {X, Y}..., Z) => (W, X, Y, Z)
169178
Type TupleType::flattenPackTypes() {
170179
bool anyChanged = false;
@@ -465,3 +474,17 @@ bool SILPackType::containsPackExpansionType() const {
465474

466475
return false;
467476
}
477+
478+
CanPackType
479+
CanTupleType::getInducedPackTypeImpl(CanTupleType tuple, unsigned start, unsigned count) {
480+
assert(start + count <= tuple->getNumElements() && "range out of range");
481+
482+
auto &ctx = tuple->getASTContext();
483+
if (count == 0) return CanPackType::get(ctx, {});
484+
485+
SmallVector<CanType, 4> eltTypes;
486+
eltTypes.reserve(count);
487+
for (unsigned i = start, e = start + count; i != e; ++i)
488+
eltTypes.push_back(tuple.getElementType(i));
489+
return CanPackType::get(ctx, eltTypes);
490+
}

lib/SIL/IR/SILFunction.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,16 @@ SILType SILFunction::getLoweredType(Type t) const {
472472
return getModule().Types.getLoweredType(t, TypeExpansionContext(*this));
473473
}
474474

475+
CanType
476+
SILFunction::getLoweredRValueType(AbstractionPattern orig, Type subst) const {
477+
return getModule().Types.getLoweredRValueType(TypeExpansionContext(*this),
478+
orig, subst);
479+
}
480+
481+
CanType SILFunction::getLoweredRValueType(Type t) const {
482+
return getModule().Types.getLoweredRValueType(TypeExpansionContext(*this), t);
483+
}
484+
475485
SILType SILFunction::getLoweredLoadableType(Type t) const {
476486
auto &M = getModule();
477487
return M.Types.getLoweredLoadableType(t, TypeExpansionContext(*this), M);

lib/SILGen/Conversion.h

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,15 @@ class ConvertingInitialization final : public Initialization {
250250
Finished,
251251

252252
/// The converted value has been extracted.
253-
Extracted
253+
Extracted,
254+
255+
/// We're doing pack initialization instead of the normal state
256+
/// transition, and we haven't been finished yet.
257+
PackExpanding,
258+
259+
/// We're doing pack initialization instead of the normal state
260+
/// transition, and finishInitialization has been called.
261+
FinishedPackExpanding,
254262
};
255263

256264
StateTy State;
@@ -280,6 +288,7 @@ class ConvertingInitialization final : public Initialization {
280288
FinalContext(SGFContext(subInitialization.get())) {
281289
OwnedSubInitialization = std::move(subInitialization);
282290
}
291+
283292

284293
/// Return the conversion to apply to the unconverted value.
285294
const Conversion &getConversion() const {
@@ -345,9 +354,26 @@ class ConvertingInitialization final : public Initialization {
345354

346355
// Bookkeeping.
347356
void finishInitialization(SILGenFunction &SGF) override {
348-
assert(getState() == Initialized);
349-
State = Finished;
357+
if (getState() == PackExpanding) {
358+
FinalContext.getEmitInto()->finishInitialization(SGF);
359+
State = FinishedPackExpanding;
360+
} else {
361+
assert(getState() == Initialized);
362+
State = Finished;
363+
}
350364
}
365+
366+
// Support pack-expansion initialization.
367+
bool canPerformPackExpansionInitialization() const override {
368+
if (auto finalInit = FinalContext.getEmitInto())
369+
return finalInit->canPerformPackExpansionInitialization();
370+
return false;
371+
}
372+
373+
void performPackExpansionInitialization(SILGenFunction &SGF,
374+
SILLocation loc,
375+
SILValue indexWithinComponent,
376+
llvm::function_ref<void(Initialization *into)> fn) override;
351377
};
352378

353379
} // end namespace Lowering

0 commit comments

Comments
 (0)