Skip to content

Commit 3342d67

Browse files
committed
Fix the creation of substituted abstraction patterns for expansions
1 parent e1c3988 commit 3342d67

File tree

4 files changed

+170
-63
lines changed

4 files changed

+170
-63
lines changed

include/swift/SIL/AbstractionPattern.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -915,7 +915,7 @@ class AbstractionPattern {
915915
// don't want to try to unique by Clang node.
916916
//
917917
// Even if we support Clang nodes someday, we *cannot* cache
918-
// by the open-coded patterns like Tuple and PackExpansion.
918+
// by the open-coded patterns like Tuple.
919919
return getKind() == Kind::Type || getKind() == Kind::Opaque
920920
|| getKind() == Kind::Discard;
921921
}
@@ -980,7 +980,8 @@ class AbstractionPattern {
980980
case Kind::Type:
981981
case Kind::ClangType:
982982
case Kind::Discard: {
983-
return getType()->isParameterPack();
983+
auto ty = getType();
984+
return isa<PackArchetypeType>(ty) || ty->isParameterPack();
984985
}
985986
default:
986987
return false;

lib/SIL/IR/AbstractionPattern.cpp

Lines changed: 161 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1218,12 +1218,15 @@ forEachFunctionParam(AnyFunctionType::CanParamArrayRef substParams,
12181218
// Honor ignoreFinalParam for the substituted parameters on all paths.
12191219
if (ignoreFinalParam) substParams = substParams.drop_back();
12201220

1221-
// If this isn't a function type, use the substituted type.
1222-
if (isTypeParameterOrOpaqueArchetype()) {
1221+
// If we don't have a function type, use the substituted type.
1222+
if (isTypeParameterOrOpaqueArchetype() ||
1223+
getKind() == Kind::OpaqueFunction ||
1224+
getKind() == Kind::OpaqueDerivativeFunction) {
12231225
for (auto substParamIndex : indices(substParams)) {
12241226
handleScalar(substParamIndex, substParamIndex,
12251227
substParams[substParamIndex].getParameterFlags(),
1226-
*this, substParams[substParamIndex]);
1228+
AbstractionPattern::getOpaque(),
1229+
substParams[substParamIndex]);
12271230
}
12281231
return;
12291232
}
@@ -1829,38 +1832,57 @@ class SubstFunctionTypePatternVisitor
18291832
SmallVector<Requirement, 2> substRequirements;
18301833
SmallVector<Type, 2> substReplacementTypes;
18311834
CanType substYieldType;
1835+
bool WithinExpansion = false;
18321836

18331837
SubstFunctionTypePatternVisitor(TypeConverter &TC)
18341838
: TC(TC) {}
1835-
1839+
18361840
// Creates and returns a fresh type parameter in the substituted generic
18371841
// signature if `pattern` is a type parameter or opaque archetype. Returns
18381842
// null otherwise.
1839-
CanType handleTypeParameterInAbstractionPattern(AbstractionPattern pattern,
1840-
CanType substTy) {
1843+
CanType handleTypeParameter(AbstractionPattern pattern, CanType substTy) {
18411844
if (!pattern.isTypeParameterOrOpaqueArchetype())
18421845
return CanType();
18431846

1844-
// If so, let's put a fresh generic parameter in the substituted signature
1845-
// here.
18461847
unsigned paramIndex = substGenericParams.size();
18471848

1848-
bool isParameterPack = false;
1849-
if (substTy->isParameterPack() || substTy->is<PackArchetypeType>())
1850-
isParameterPack = true;
1851-
else if (pattern.isTypeParameterPack())
1852-
isParameterPack = true;
1849+
// Pack parameters that aren't within expansions should just be
1850+
// abstracted as scalars.
1851+
bool isParameterPack = (WithinExpansion && pattern.isTypeParameterPack());
18531852

18541853
auto gp = GenericTypeParamType::get(isParameterPack, 0, paramIndex,
18551854
TC.Context);
18561855
substGenericParams.push_back(gp);
1857-
if (isParameterPack) {
1858-
substReplacementTypes.push_back(
1859-
PackType::getSingletonPackExpansion(substTy));
1856+
1857+
CanType replacement;
1858+
1859+
if (WithinExpansion) {
1860+
// If we're within an expansion, and there are substitutions in the
1861+
// abstraction pattern, use those instead of substTy. substTy is not
1862+
// contextually meaningful in this case; see handlePackExpansion.
1863+
if (auto subs = pattern.getGenericSubstitutions()) {
1864+
replacement = pattern.getType().subst(subs)->getCanonicalType();
1865+
1866+
// If we don't have substitutions, but we're abstracting a pack
1867+
// parameter, assume that we're lowering a function type using
1868+
// itself as its pattern or something like. The substituted type
1869+
// should be `each T` for some pack reference; wrap that in a pack.
1870+
} else if (isParameterPack) {
1871+
replacement = CanPackType::getSingletonPackExpansion(substTy);
1872+
1873+
// Otherwise, just use substTy.
1874+
} else {
1875+
replacement = substTy;
1876+
}
1877+
1878+
// Otherwise, we can just use substTy.
18601879
} else {
1861-
substReplacementTypes.push_back(substTy);
1880+
assert(!isParameterPack);
1881+
assert(!isa<PackType>(substTy));
1882+
replacement = substTy;
18621883
}
1863-
1884+
substReplacementTypes.push_back(replacement);
1885+
18641886
if (auto layout = pattern.getLayoutConstraint()) {
18651887
// Look at the layout constraint on this position in the abstraction pattern
18661888
// and carry it over, with some generalization to the point it affects
@@ -1914,7 +1936,7 @@ class SubstFunctionTypePatternVisitor
19141936
}
19151937

19161938
CanType visit(CanType t, AbstractionPattern pattern) {
1917-
if (auto gp = handleTypeParameterInAbstractionPattern(pattern, t))
1939+
if (auto gp = handleTypeParameter(pattern, t))
19181940
return gp;
19191941

19201942
return CanTypeVisitor::visit(t, pattern);
@@ -1960,7 +1982,7 @@ class SubstFunctionTypePatternVisitor
19601982
if (!orig->hasTypeParameter()
19611983
&& !orig->hasArchetype()
19621984
&& !orig->hasOpaqueArchetype()) {
1963-
return CanType(subst);
1985+
return subst;
19641986
}
19651987

19661988
// If the substituted type is a subclass of the abstraction pattern
@@ -2067,26 +2089,81 @@ class SubstFunctionTypePatternVisitor
20672089

20682090
CanType visitPackExpansionType(CanPackExpansionType pack,
20692091
AbstractionPattern pattern) {
2070-
// Avoid walking into the pattern and count type if we can help it.
2071-
if (!pack->hasTypeParameter() && !pack->hasArchetype() &&
2072-
!pack->hasOpaqueArchetype()) {
2073-
return CanType(pack);
2092+
llvm_unreachable("shouldn't encounter pack expansion by itself");
2093+
}
2094+
2095+
CanType handlePackExpansion(AbstractionPattern origExpansion,
2096+
CanType candidateSubstType) {
2097+
// When we're within a pack expansion, pack references matching that
2098+
// expansion should be abstracted as packs. The substitution will be
2099+
// the pack substitution for that parameter recorded in the pattern.
2100+
2101+
// Remember that we're within an expansion.
2102+
// FIXME: when we introduce PackReferenceType we'll need to be clear
2103+
// about which pack expansions to treat this way.
2104+
llvm::SaveAndRestore<bool> scope(WithinExpansion, true);
2105+
2106+
auto origPatternType = origExpansion.getPackExpansionPatternType();
2107+
2108+
// We only really need a subst type here if we don't have
2109+
// substitutions in the pattern, because handleTypeParameter
2110+
// will always those substitutions within an expansion if
2111+
// they're available. And if we don't have substitutions in the
2112+
// pattern, we can't map the pack expansion to a concrete set
2113+
// of expanded components, so we should have exactly one subst
2114+
// type.
2115+
CanType substPatternType;
2116+
if (origExpansion.getGenericSubstitutions()) {
2117+
substPatternType = origPatternType.getType();
2118+
} else {
2119+
assert(candidateSubstType);
2120+
substPatternType =
2121+
cast<PackExpansionType>(candidateSubstType).getPatternType();
20742122
}
20752123

2076-
auto substPatternType = visit(pack.getPatternType(),
2077-
pattern.getPackExpansionPatternType());
2078-
auto substCountType = visit(pack.getCountType(),
2079-
AbstractionPattern::getOpaque());
2124+
// Recursively visit the pattern type.
2125+
auto patternTy = visit(substPatternType, origPatternType);
20802126

2081-
SmallVector<Type> rootParameterPacks;
2082-
substPatternType->getTypeParameterPacks(rootParameterPacks);
2127+
// Find a pack parameter from the pattern to expand over.
2128+
auto countParam = findExpandedPackParameter(patternTy);
20832129

2084-
for (auto parameterPack : rootParameterPacks) {
2085-
substRequirements.emplace_back(RequirementKind::SameShape,
2086-
parameterPack, substCountType);
2087-
}
2130+
// If that didn't work, we should be able to find an expansion
2131+
// to use from either the substituted type or the subs. At worst,
2132+
// we can make one.
2133+
assert(countParam && "implementable but lazy");
20882134

2089-
return CanPackExpansionType::get(substPatternType, substCountType);
2135+
return CanPackExpansionType::get(patternTy, countParam);
2136+
}
2137+
2138+
static CanType findExpandedPackParameter(CanType patternType) {
2139+
struct Walker : public TypeWalker {
2140+
CanType Result;
2141+
Action walkToTypePre(Type _ty) override {
2142+
auto ty = CanType(_ty);
2143+
2144+
// Don't recurse inside pack expansions.
2145+
if (isa<PackExpansionType>(ty)) {
2146+
return Action::SkipChildren;
2147+
}
2148+
2149+
// Consider type parameters.
2150+
if (ty->isTypeParameter()) {
2151+
auto param = ty->getRootGenericParam();
2152+
if (param->isParameterPack()) {
2153+
Result = CanType(param);
2154+
return Action::Stop;
2155+
}
2156+
return Action::SkipChildren;
2157+
}
2158+
2159+
// Otherwise continue.
2160+
return Action::Continue;
2161+
}
2162+
};
2163+
2164+
Walker walker;
2165+
patternType.walk(walker);
2166+
return walker.Result;
20902167
}
20912168

20922169
CanType visitExistentialType(CanExistentialType exist,
@@ -2121,14 +2198,31 @@ class SubstFunctionTypePatternVisitor
21212198
}
21222199

21232200
CanType visitTupleType(CanTupleType tuple, AbstractionPattern pattern) {
2124-
// Break down the tuple.
2201+
assert(pattern.isTuple());
2202+
2203+
// It's pretty weird for us to end up in this case with an
2204+
// open-coded tuple pattern, but it happens with opaque derivative
2205+
// functions in autodiff.
2206+
CanTupleType origTupleTypeForLabels = pattern.getAs<TupleType>();
2207+
if (!origTupleTypeForLabels) origTupleTypeForLabels = tuple;
2208+
21252209
SmallVector<TupleTypeElt, 4> tupleElts;
2126-
for (unsigned i = 0; i < tuple->getNumElements(); ++i) {
2127-
auto elt = tuple->getElement(i);
2128-
auto substEltTy = visit(tuple.getElementType(i),
2129-
pattern.getTupleElementType(i));
2130-
tupleElts.emplace_back(substEltTy, elt.getName());
2131-
}
2210+
pattern.forEachTupleElement(tuple,
2211+
[&](unsigned origEltIndex, unsigned substEltIndex,
2212+
AbstractionPattern origEltType, CanType substEltType) {
2213+
auto eltTy = visit(substEltType, origEltType);
2214+
auto &origElt = origTupleTypeForLabels->getElement(origEltIndex);
2215+
tupleElts.push_back(origElt.getWithType(eltTy));
2216+
}, [&](unsigned origEltIndex, unsigned substEltIndex,
2217+
AbstractionPattern origExpansionType,
2218+
CanTupleEltTypeArrayRef substEltTypes) {
2219+
CanType candidateSubstType;
2220+
if (!substEltTypes.empty())
2221+
candidateSubstType = substEltTypes[0];
2222+
auto eltTy = handlePackExpansion(origExpansionType, candidateSubstType);
2223+
auto &origElt = origTupleTypeForLabels->getElement(origEltIndex);
2224+
tupleElts.push_back(origElt.getWithType(eltTy));
2225+
});
21322226

21332227
return CanType(TupleType::get(tupleElts, TC.Context));
21342228
}
@@ -2138,19 +2232,29 @@ class SubstFunctionTypePatternVisitor
21382232
CanType yieldType,
21392233
AbstractionPattern yieldPattern) {
21402234
SmallVector<FunctionType::Param, 4> newParams;
2141-
2142-
for (unsigned i = 0; i < func->getParams().size(); ++i) {
2143-
auto param = func->getParams()[i];
2144-
// Lower the formal type of the argument binding, eliminating variadicity.
2145-
auto newParamTy = visit(CanType(param.getParameterType(true)),
2146-
pattern.getFunctionParamType(i));
2147-
auto newParam = FunctionType::Param(newParamTy,
2148-
param.getLabel(),
2149-
param.getParameterFlags()
2150-
.withVariadic(false),
2151-
param.getInternalLabel());
2152-
newParams.push_back(newParam);
2153-
}
2235+
auto addParam = [&](ParameterTypeFlags oldFlags, CanType newType) {
2236+
newParams.push_back(FunctionType::Param(
2237+
newType, /*label*/ Identifier(), oldFlags.withVariadic(false),
2238+
/*internal label*/ Identifier()));
2239+
};
2240+
2241+
pattern.forEachFunctionParam(func.getParams(), /*ignore self*/ false,
2242+
[&](unsigned origParamIndex, unsigned substParamIndex,
2243+
ParameterTypeFlags origFlags, AbstractionPattern origParamType,
2244+
AnyFunctionType::CanParam substParam) {
2245+
auto newParamTy = visit(substParam.getParameterType(), origParamType);
2246+
addParam(origFlags, newParamTy);
2247+
}, [&](unsigned origParamIndex, unsigned substParamIndex,
2248+
ParameterTypeFlags origFlags,
2249+
AbstractionPattern origExpansionType,
2250+
AnyFunctionType::CanParamArrayRef substParams) {
2251+
CanType candidateSubstType;
2252+
if (!substParams.empty())
2253+
candidateSubstType = substParams[0].getParameterType();
2254+
auto expansionType =
2255+
handlePackExpansion(origExpansionType, candidateSubstType);
2256+
addParam(origFlags, expansionType);
2257+
});
21542258

21552259
if (yieldType) {
21562260
substYieldType = visit(yieldType, yieldPattern);
@@ -2229,9 +2333,9 @@ const {
22292333
yieldType = yieldType->getReducedType(substSig);
22302334

22312335
return std::make_tuple(
2232-
AbstractionPattern(substSig, substTy->getReducedType(substSig)),
2336+
AbstractionPattern(subMap, substSig, substTy->getReducedType(substSig)),
22332337
subMap,
22342338
yieldType
2235-
? AbstractionPattern(substSig, yieldType)
2339+
? AbstractionPattern(subMap, substSig, yieldType)
22362340
: AbstractionPattern::getInvalid());
22372341
}

test/SILGen/pack_expansion_type.swift

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88
public func variadicFunction<each T, each U>(t: repeat each T, u: repeat each U) -> (repeat (each T, each U)) {}
99

1010
public struct VariadicType<each T> {
11-
// CHECK-LABEL: sil [ossa] @$s19pack_expansion_type12VariadicTypeV14variadicMethod1t1ux_qd__txQp_txxQp_qd__xQptqd__RhzRvd__lF : $@convention(method) <each T><each U where (repeat (each T, each U)) : Any> (@pack_guaranteed Pack{repeat each T}, @pack_guaranteed Pack{repeat each U}, VariadicType<repeat each T>) -> @pack_out Pack{repeat (each T, each U)} {
11+
// CHECK-LABEL: sil [ossa] @$s19pack_expansion_type12VariadicTypeV14variadicMethod1t1ux_qd__txQp_txxQp_qd__xQptqd__RhzRvd__lF :
12+
// CHECK-SAME: $@convention(method) <each T><each U where (repeat (each T, each U)) : Any> (@pack_guaranteed Pack{repeat each T}, @pack_guaranteed Pack{repeat each U}, VariadicType<repeat each T>) -> @pack_out Pack{repeat (each T, each U)} {
1213
// CHECK: bb0(%0 : $*Pack{repeat (each T, each U)}, %1 : $*Pack{repeat each T}, %2 : $*Pack{repeat each U}, %3 : $VariadicType<repeat each T>):
1314
public func variadicMethod<each U>(t: repeat each T, u: repeat each U) -> (repeat (each T, each U)) {}
1415

15-
// CHECK-LABEL: sil [ossa] @$s19pack_expansion_type12VariadicTypeV13takesFunction1tyqd__qd__Qp_txxQpXE_tRvd__lF : $@convention(method) <each T><each U> (@guaranteed @noescape @callee_guaranteed @substituted <each τ_0_0, each τ_0_1, each τ_0_2, each τ_0_3 where (repeat (each τ_0_0, each τ_0_1)) : Any, (repeat (each τ_0_2, each τ_0_3)) : Any> (@pack_guaranteed Pack{repeat each τ_0_0}) -> @pack_out Pack{repeat each τ_0_2} for <Pack{repeat each T}, Pack{repeat each T}, Pack{repeat each U}, Pack{repeat each U}>, VariadicType<repeat each T>) -> () {
16-
// CHECK: bb0(%0 : @guaranteed $@noescape @callee_guaranteed @substituted <each τ_0_0, each τ_0_1, each τ_0_2, each τ_0_3 where (repeat (each τ_0_0, each τ_0_1)) : Any, (repeat (each τ_0_2, each τ_0_3)) : Any> (@pack_guaranteed Pack{repeat each τ_0_0}) -> @pack_out Pack{repeat each τ_0_2} for <Pack{repeat each T}, Pack{repeat each T}, Pack{repeat each U}, Pack{repeat each U}>, %1 : $VariadicType<repeat each T>):
16+
// CHECK-LABEL: sil [ossa] @$s19pack_expansion_type12VariadicTypeV13takesFunction1tyqd__qd__Qp_txxQpXE_tRvd__lF :
17+
// CHECK-SAME: $@convention(method) <each T><each U> (@guaranteed @noescape @callee_guaranteed @substituted <each τ_0_0, each τ_0_1> (@pack_guaranteed Pack{repeat each τ_0_0}) -> @pack_out Pack{repeat each τ_0_1} for <Pack{repeat each T}, Pack{repeat each U}>, VariadicType<repeat each T>) -> () {
18+
// CHECK: bb0(%0 : @guaranteed $@noescape @callee_guaranteed @substituted <each τ_0_0, each τ_0_1> (@pack_guaranteed Pack{repeat each τ_0_0}) -> @pack_out Pack{repeat each τ_0_1} for <Pack{repeat each T}, Pack{repeat each U}>, %1 : $VariadicType<repeat each T>):
1719
public func takesFunction<each U>(t: (repeat each T) -> (repeat each U)) {}
1820
}
1921

test/SILGen/variadic-generic-closures.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
public struct G<T> {}
77

8-
// CHECK-LABEL: sil [ossa] @$s4main6caller2fnyyAA1GVyxGxQpXE_tRvzlF : $@convention(thin) <each T> (@guaranteed @noescape @callee_guaranteed @substituted <each τ_0_0, each τ_0_1 where (repeat (each τ_0_0, each τ_0_1)) : Any> (@pack_guaranteed Pack{repeat G<each τ_0_0>}) -> () for <Pack{repeat each T}, Pack{repeat each T}>) -> () {
8+
// CHECK-LABEL: sil [ossa] @$s4main6caller2fnyyAA1GVyxGxQpXE_tRvzlF : $@convention(thin) <each T> (@guaranteed @noescape @callee_guaranteed @substituted <each τ_0_0> (@pack_guaranteed Pack{repeat G<each τ_0_0>}) -> () for <Pack{repeat each T}>) -> () {
99
public func caller<each T>(fn: (repeat G<each T>) -> ()) {
1010
fn(repeat G<each T>())
1111
}

0 commit comments

Comments
 (0)