@@ -1218,12 +1218,15 @@ forEachFunctionParam(AnyFunctionType::CanParamArrayRef substParams,
1218
1218
// Honor ignoreFinalParam for the substituted parameters on all paths.
1219
1219
if (ignoreFinalParam) substParams = substParams.drop_back ();
1220
1220
1221
- // If this isn't a function type, use the substituted type.
1222
- if (isTypeParameterOrOpaqueArchetype ()) {
1221
+ // If we don't have a function type, use the substituted type.
1222
+ if (isTypeParameterOrOpaqueArchetype () ||
1223
+ getKind () == Kind::OpaqueFunction ||
1224
+ getKind () == Kind::OpaqueDerivativeFunction) {
1223
1225
for (auto substParamIndex : indices (substParams)) {
1224
1226
handleScalar (substParamIndex, substParamIndex,
1225
1227
substParams[substParamIndex].getParameterFlags (),
1226
- *this , substParams[substParamIndex]);
1228
+ AbstractionPattern::getOpaque (),
1229
+ substParams[substParamIndex]);
1227
1230
}
1228
1231
return ;
1229
1232
}
@@ -1829,38 +1832,57 @@ class SubstFunctionTypePatternVisitor
1829
1832
SmallVector<Requirement, 2 > substRequirements;
1830
1833
SmallVector<Type, 2 > substReplacementTypes;
1831
1834
CanType substYieldType;
1835
+ bool WithinExpansion = false ;
1832
1836
1833
1837
SubstFunctionTypePatternVisitor (TypeConverter &TC)
1834
1838
: TC(TC) {}
1835
-
1839
+
1836
1840
// Creates and returns a fresh type parameter in the substituted generic
1837
1841
// signature if `pattern` is a type parameter or opaque archetype. Returns
1838
1842
// null otherwise.
1839
- CanType handleTypeParameterInAbstractionPattern (AbstractionPattern pattern,
1840
- CanType substTy) {
1843
+ CanType handleTypeParameter (AbstractionPattern pattern, CanType substTy) {
1841
1844
if (!pattern.isTypeParameterOrOpaqueArchetype ())
1842
1845
return CanType ();
1843
1846
1844
- // If so, let's put a fresh generic parameter in the substituted signature
1845
- // here.
1846
1847
unsigned paramIndex = substGenericParams.size ();
1847
1848
1848
- bool isParameterPack = false ;
1849
- if (substTy->isParameterPack () || substTy->is <PackArchetypeType>())
1850
- isParameterPack = true ;
1851
- else if (pattern.isTypeParameterPack ())
1852
- isParameterPack = true ;
1849
+ // Pack parameters that aren't within expansions should just be
1850
+ // abstracted as scalars.
1851
+ bool isParameterPack = (WithinExpansion && pattern.isTypeParameterPack ());
1853
1852
1854
1853
auto gp = GenericTypeParamType::get (isParameterPack, 0 , paramIndex,
1855
1854
TC.Context );
1856
1855
substGenericParams.push_back (gp);
1857
- if (isParameterPack) {
1858
- substReplacementTypes.push_back (
1859
- PackType::getSingletonPackExpansion (substTy));
1856
+
1857
+ CanType replacement;
1858
+
1859
+ if (WithinExpansion) {
1860
+ // If we're within an expansion, and there are substitutions in the
1861
+ // abstraction pattern, use those instead of substTy. substTy is not
1862
+ // contextually meaningful in this case; see handlePackExpansion.
1863
+ if (auto subs = pattern.getGenericSubstitutions ()) {
1864
+ replacement = pattern.getType ().subst (subs)->getCanonicalType ();
1865
+
1866
+ // If we don't have substitutions, but we're abstracting a pack
1867
+ // parameter, assume that we're lowering a function type using
1868
+ // itself as its pattern or something like. The substituted type
1869
+ // should be `each T` for some pack reference; wrap that in a pack.
1870
+ } else if (isParameterPack) {
1871
+ replacement = CanPackType::getSingletonPackExpansion (substTy);
1872
+
1873
+ // Otherwise, just use substTy.
1874
+ } else {
1875
+ replacement = substTy;
1876
+ }
1877
+
1878
+ // Otherwise, we can just use substTy.
1860
1879
} else {
1861
- substReplacementTypes.push_back (substTy);
1880
+ assert (!isParameterPack);
1881
+ assert (!isa<PackType>(substTy));
1882
+ replacement = substTy;
1862
1883
}
1863
-
1884
+ substReplacementTypes.push_back (replacement);
1885
+
1864
1886
if (auto layout = pattern.getLayoutConstraint ()) {
1865
1887
// Look at the layout constraint on this position in the abstraction pattern
1866
1888
// and carry it over, with some generalization to the point it affects
@@ -1914,7 +1936,7 @@ class SubstFunctionTypePatternVisitor
1914
1936
}
1915
1937
1916
1938
CanType visit (CanType t, AbstractionPattern pattern) {
1917
- if (auto gp = handleTypeParameterInAbstractionPattern (pattern, t))
1939
+ if (auto gp = handleTypeParameter (pattern, t))
1918
1940
return gp;
1919
1941
1920
1942
return CanTypeVisitor::visit (t, pattern);
@@ -1960,7 +1982,7 @@ class SubstFunctionTypePatternVisitor
1960
1982
if (!orig->hasTypeParameter ()
1961
1983
&& !orig->hasArchetype ()
1962
1984
&& !orig->hasOpaqueArchetype ()) {
1963
- return CanType ( subst) ;
1985
+ return subst;
1964
1986
}
1965
1987
1966
1988
// If the substituted type is a subclass of the abstraction pattern
@@ -2067,26 +2089,81 @@ class SubstFunctionTypePatternVisitor
2067
2089
2068
2090
CanType visitPackExpansionType (CanPackExpansionType pack,
2069
2091
AbstractionPattern pattern) {
2070
- // Avoid walking into the pattern and count type if we can help it.
2071
- if (!pack->hasTypeParameter () && !pack->hasArchetype () &&
2072
- !pack->hasOpaqueArchetype ()) {
2073
- return CanType (pack);
2092
+ llvm_unreachable (" shouldn't encounter pack expansion by itself" );
2093
+ }
2094
+
2095
+ CanType handlePackExpansion (AbstractionPattern origExpansion,
2096
+ CanType candidateSubstType) {
2097
+ // When we're within a pack expansion, pack references matching that
2098
+ // expansion should be abstracted as packs. The substitution will be
2099
+ // the pack substitution for that parameter recorded in the pattern.
2100
+
2101
+ // Remember that we're within an expansion.
2102
+ // FIXME: when we introduce PackReferenceType we'll need to be clear
2103
+ // about which pack expansions to treat this way.
2104
+ llvm::SaveAndRestore<bool > scope (WithinExpansion, true );
2105
+
2106
+ auto origPatternType = origExpansion.getPackExpansionPatternType ();
2107
+
2108
+ // We only really need a subst type here if we don't have
2109
+ // substitutions in the pattern, because handleTypeParameter
2110
+ // will always those substitutions within an expansion if
2111
+ // they're available. And if we don't have substitutions in the
2112
+ // pattern, we can't map the pack expansion to a concrete set
2113
+ // of expanded components, so we should have exactly one subst
2114
+ // type.
2115
+ CanType substPatternType;
2116
+ if (origExpansion.getGenericSubstitutions ()) {
2117
+ substPatternType = origPatternType.getType ();
2118
+ } else {
2119
+ assert (candidateSubstType);
2120
+ substPatternType =
2121
+ cast<PackExpansionType>(candidateSubstType).getPatternType ();
2074
2122
}
2075
2123
2076
- auto substPatternType = visit (pack.getPatternType (),
2077
- pattern.getPackExpansionPatternType ());
2078
- auto substCountType = visit (pack.getCountType (),
2079
- AbstractionPattern::getOpaque ());
2124
+ // Recursively visit the pattern type.
2125
+ auto patternTy = visit (substPatternType, origPatternType);
2080
2126
2081
- SmallVector<Type> rootParameterPacks;
2082
- substPatternType-> getTypeParameterPacks (rootParameterPacks );
2127
+ // Find a pack parameter from the pattern to expand over.
2128
+ auto countParam = findExpandedPackParameter (patternTy );
2083
2129
2084
- for ( auto parameterPack : rootParameterPacks) {
2085
- substRequirements. emplace_back (RequirementKind::SameShape ,
2086
- parameterPack, substCountType);
2087
- }
2130
+ // If that didn't work, we should be able to find an expansion
2131
+ // to use from either the substituted type or the subs. At worst ,
2132
+ // we can make one.
2133
+ assert (countParam && " implementable but lazy " );
2088
2134
2089
- return CanPackExpansionType::get (substPatternType, substCountType);
2135
+ return CanPackExpansionType::get (patternTy, countParam);
2136
+ }
2137
+
2138
+ static CanType findExpandedPackParameter (CanType patternType) {
2139
+ struct Walker : public TypeWalker {
2140
+ CanType Result;
2141
+ Action walkToTypePre (Type _ty) override {
2142
+ auto ty = CanType (_ty);
2143
+
2144
+ // Don't recurse inside pack expansions.
2145
+ if (isa<PackExpansionType>(ty)) {
2146
+ return Action::SkipChildren;
2147
+ }
2148
+
2149
+ // Consider type parameters.
2150
+ if (ty->isTypeParameter ()) {
2151
+ auto param = ty->getRootGenericParam ();
2152
+ if (param->isParameterPack ()) {
2153
+ Result = CanType (param);
2154
+ return Action::Stop;
2155
+ }
2156
+ return Action::SkipChildren;
2157
+ }
2158
+
2159
+ // Otherwise continue.
2160
+ return Action::Continue;
2161
+ }
2162
+ };
2163
+
2164
+ Walker walker;
2165
+ patternType.walk (walker);
2166
+ return walker.Result ;
2090
2167
}
2091
2168
2092
2169
CanType visitExistentialType (CanExistentialType exist,
@@ -2121,14 +2198,31 @@ class SubstFunctionTypePatternVisitor
2121
2198
}
2122
2199
2123
2200
CanType visitTupleType (CanTupleType tuple, AbstractionPattern pattern) {
2124
- // Break down the tuple.
2201
+ assert (pattern.isTuple ());
2202
+
2203
+ // It's pretty weird for us to end up in this case with an
2204
+ // open-coded tuple pattern, but it happens with opaque derivative
2205
+ // functions in autodiff.
2206
+ CanTupleType origTupleTypeForLabels = pattern.getAs <TupleType>();
2207
+ if (!origTupleTypeForLabels) origTupleTypeForLabels = tuple;
2208
+
2125
2209
SmallVector<TupleTypeElt, 4 > tupleElts;
2126
- for (unsigned i = 0 ; i < tuple->getNumElements (); ++i) {
2127
- auto elt = tuple->getElement (i);
2128
- auto substEltTy = visit (tuple.getElementType (i),
2129
- pattern.getTupleElementType (i));
2130
- tupleElts.emplace_back (substEltTy, elt.getName ());
2131
- }
2210
+ pattern.forEachTupleElement (tuple,
2211
+ [&](unsigned origEltIndex, unsigned substEltIndex,
2212
+ AbstractionPattern origEltType, CanType substEltType) {
2213
+ auto eltTy = visit (substEltType, origEltType);
2214
+ auto &origElt = origTupleTypeForLabels->getElement (origEltIndex);
2215
+ tupleElts.push_back (origElt.getWithType (eltTy));
2216
+ }, [&](unsigned origEltIndex, unsigned substEltIndex,
2217
+ AbstractionPattern origExpansionType,
2218
+ CanTupleEltTypeArrayRef substEltTypes) {
2219
+ CanType candidateSubstType;
2220
+ if (!substEltTypes.empty ())
2221
+ candidateSubstType = substEltTypes[0 ];
2222
+ auto eltTy = handlePackExpansion (origExpansionType, candidateSubstType);
2223
+ auto &origElt = origTupleTypeForLabels->getElement (origEltIndex);
2224
+ tupleElts.push_back (origElt.getWithType (eltTy));
2225
+ });
2132
2226
2133
2227
return CanType (TupleType::get (tupleElts, TC.Context ));
2134
2228
}
@@ -2138,19 +2232,29 @@ class SubstFunctionTypePatternVisitor
2138
2232
CanType yieldType,
2139
2233
AbstractionPattern yieldPattern) {
2140
2234
SmallVector<FunctionType::Param, 4 > newParams;
2141
-
2142
- for (unsigned i = 0 ; i < func->getParams ().size (); ++i) {
2143
- auto param = func->getParams ()[i];
2144
- // Lower the formal type of the argument binding, eliminating variadicity.
2145
- auto newParamTy = visit (CanType (param.getParameterType (true )),
2146
- pattern.getFunctionParamType (i));
2147
- auto newParam = FunctionType::Param (newParamTy,
2148
- param.getLabel (),
2149
- param.getParameterFlags ()
2150
- .withVariadic (false ),
2151
- param.getInternalLabel ());
2152
- newParams.push_back (newParam);
2153
- }
2235
+ auto addParam = [&](ParameterTypeFlags oldFlags, CanType newType) {
2236
+ newParams.push_back (FunctionType::Param (
2237
+ newType, /* label*/ Identifier (), oldFlags.withVariadic (false ),
2238
+ /* internal label*/ Identifier ()));
2239
+ };
2240
+
2241
+ pattern.forEachFunctionParam (func.getParams (), /* ignore self*/ false ,
2242
+ [&](unsigned origParamIndex, unsigned substParamIndex,
2243
+ ParameterTypeFlags origFlags, AbstractionPattern origParamType,
2244
+ AnyFunctionType::CanParam substParam) {
2245
+ auto newParamTy = visit (substParam.getParameterType (), origParamType);
2246
+ addParam (origFlags, newParamTy);
2247
+ }, [&](unsigned origParamIndex, unsigned substParamIndex,
2248
+ ParameterTypeFlags origFlags,
2249
+ AbstractionPattern origExpansionType,
2250
+ AnyFunctionType::CanParamArrayRef substParams) {
2251
+ CanType candidateSubstType;
2252
+ if (!substParams.empty ())
2253
+ candidateSubstType = substParams[0 ].getParameterType ();
2254
+ auto expansionType =
2255
+ handlePackExpansion (origExpansionType, candidateSubstType);
2256
+ addParam (origFlags, expansionType);
2257
+ });
2154
2258
2155
2259
if (yieldType) {
2156
2260
substYieldType = visit (yieldType, yieldPattern);
@@ -2229,9 +2333,9 @@ const {
2229
2333
yieldType = yieldType->getReducedType (substSig);
2230
2334
2231
2335
return std::make_tuple (
2232
- AbstractionPattern (substSig, substTy->getReducedType (substSig)),
2336
+ AbstractionPattern (subMap, substSig, substTy->getReducedType (substSig)),
2233
2337
subMap,
2234
2338
yieldType
2235
- ? AbstractionPattern (substSig, yieldType)
2339
+ ? AbstractionPattern (subMap, substSig, yieldType)
2236
2340
: AbstractionPattern::getInvalid ());
2237
2341
}
0 commit comments