@@ -180,44 +180,108 @@ static ValueDecl *rewriteIntegerTypes(SubstitutionMap subst, ValueDecl *oldDecl,
180
180
return newDecl;
181
181
}
182
182
183
- // Derive a concrete function type for fdecl by substituting the generic args
184
- // and use that to derive the corresponding function type and parameter list.
185
- static std::pair<FunctionType *, ParameterList *>
186
- substituteFunctionTypeAndParamList (ASTContext &ctx, AbstractFunctionDecl *fdecl,
187
- SubstitutionMap subst) {
188
- FunctionType *newFnType = nullptr ;
189
- // Create a new ParameterList with the substituted type.
190
- if (auto oldFnType = dyn_cast<GenericFunctionType>(
191
- fdecl->getInterfaceType ().getPointer ())) {
192
- newFnType = oldFnType->substGenericArgs (subst);
193
- } else {
194
- newFnType = cast<FunctionType>(fdecl->getInterfaceType ().getPointer ());
183
+ // Synthesize a thunk body for the function created in
184
+ // "addThunkForDependentTypes". This will just cast all params and forward them
185
+ // along to the specialized function. It will also cast the result before
186
+ // returning it.
187
+ static std::pair<BraceStmt *, bool >
188
+ synthesizeDependentTypeThunkParamForwarding (AbstractFunctionDecl *afd, void *context) {
189
+ ASTContext &ctx = afd->getASTContext ();
190
+
191
+ auto thunkDecl = cast<FuncDecl>(afd);
192
+ auto specializedFuncDecl = static_cast <FuncDecl *>(context);
193
+
194
+ SmallVector<Argument, 8 > forwardingParams;
195
+ unsigned paramIndex = 0 ;
196
+ for (auto param : *thunkDecl->getParameters ()) {
197
+ if (isa<MetatypeType>(param->getType ().getPointer ())) {
198
+ paramIndex++;
199
+ continue ;
200
+ }
201
+
202
+ auto paramRefExpr = new (ctx) DeclRefExpr (param, DeclNameLoc (),
203
+ /* Implicit=*/ true );
204
+ paramRefExpr->setType (param->getType ());
205
+
206
+ auto specParamTy = specializedFuncDecl->getParameters ()->get (paramIndex)->getType ();
207
+ auto cast = ForcedCheckedCastExpr::createImplicit (
208
+ ctx, paramRefExpr, specParamTy);
209
+
210
+ forwardingParams.push_back (Argument (SourceLoc (), Identifier (), cast));
211
+ paramIndex++;
195
212
}
196
- // The constructor type is a function type as follows:
197
- // (CType.Type) -> (Generic) -> CType
198
- // And a method's function type is as follows:
199
- // (inout CType) -> (Generic) -> Void
200
- // In either case, we only want the result of that function type because that
201
- // is the function type with the generic params that need to be substituted:
202
- // (Generic) -> CType
203
- if (isa<ConstructorDecl>(fdecl) || fdecl->isInstanceMember () ||
204
- fdecl->isStatic ())
205
- newFnType = cast<FunctionType>(newFnType->getResult ().getPointer ());
206
- SmallVector<ParamDecl *, 4 > newParams;
207
- unsigned i = 0 ;
208
213
209
- for (auto paramTy : newFnType->getParams ()) {
210
- auto *oldParamDecl = fdecl->getParameters ()->get (i);
211
- auto *newParamDecl =
212
- ParamDecl::cloneWithoutType (fdecl->getASTContext (), oldParamDecl);
213
- newParamDecl->setInterfaceType (paramTy.getParameterType ());
214
- newParams.push_back (newParamDecl);
215
- (void )++i;
214
+ auto *specializedFuncDeclRef = new (ctx) DeclRefExpr (ConcreteDeclRef (specializedFuncDecl),
215
+ DeclNameLoc (), true );
216
+ specializedFuncDeclRef->setType (specializedFuncDecl->getInterfaceType ());
217
+
218
+ auto argList = ArgumentList::createImplicit (ctx, forwardingParams);
219
+ auto *specializedFuncCallExpr = CallExpr::createImplicit (ctx, specializedFuncDeclRef, argList);
220
+ specializedFuncCallExpr->setType (specializedFuncDecl->getResultInterfaceType ());
221
+ specializedFuncCallExpr->setThrows (false );
222
+
223
+ auto cast = ForcedCheckedCastExpr::createImplicit (
224
+ ctx, specializedFuncCallExpr, thunkDecl->getResultInterfaceType ());
225
+
226
+ auto returnStmt = new (ctx) ReturnStmt (SourceLoc (), cast, /* implicit=*/ true );
227
+ auto body = BraceStmt::create (ctx, SourceLoc (), {returnStmt}, SourceLoc (),
228
+ /* implicit=*/ true );
229
+ return {body, /* isTypeChecked=*/ true };
230
+ }
231
+
232
+ // Create a thunk to map functions with dependent types to their specialized
233
+ // version. For example, create a thunk with type (Any) -> Any to wrap a
234
+ // specialized function template with type (Dependent<T>) -> Dependent<T>.
235
+ static ValueDecl *addThunkForDependentTypes (FuncDecl *oldDecl,
236
+ FuncDecl *newDecl) {
237
+ bool updatedAnyParams = false ;
238
+
239
+ SmallVector<ParamDecl *, 4 > fixedParameters;
240
+ unsigned parameterIndex = 0 ;
241
+ for (auto *newFnParam : *newDecl->getParameters ()) {
242
+ // If the un-specialized function had a parameter with type "Any" preserve
243
+ // that parameter. Otherwise, use the new function parameter.
244
+ auto oldParamType = oldDecl->getParameters ()->get (parameterIndex)->getType ();
245
+ if (oldParamType->isEqual (newDecl->getASTContext ().TheAnyType )) {
246
+ updatedAnyParams = true ;
247
+ auto newParam =
248
+ ParamDecl::cloneWithoutType (newDecl->getASTContext (), newFnParam);
249
+ newParam->setInterfaceType (oldParamType);
250
+ fixedParameters.push_back (newParam);
251
+ } else {
252
+ fixedParameters.push_back (newFnParam);
253
+ }
254
+ parameterIndex++;
216
255
}
217
- auto *newParamList =
218
- ParameterList::create (ctx, SourceLoc (), newParams, SourceLoc ());
219
256
220
- return {newFnType, newParamList};
257
+ // If we don't need this thunk, bail out.
258
+ if (!updatedAnyParams &&
259
+ !oldDecl->getResultInterfaceType ()->isEqual (
260
+ oldDecl->getASTContext ().TheAnyType ))
261
+ return newDecl;
262
+
263
+ auto fixedParams =
264
+ ParameterList::create (newDecl->getASTContext (), fixedParameters);
265
+
266
+ Type fixedResultType;
267
+ if (oldDecl->getResultInterfaceType ()->isEqual (
268
+ oldDecl->getASTContext ().TheAnyType ))
269
+ fixedResultType = oldDecl->getASTContext ().TheAnyType ;
270
+ else
271
+ fixedResultType = newDecl->getResultInterfaceType ();
272
+
273
+ // We have to rebuild the whole function.
274
+ auto newFnDecl = FuncDecl::createImplicit (
275
+ newDecl->getASTContext (), newDecl->getStaticSpelling (),
276
+ newDecl->getName (), newDecl->getNameLoc (), newDecl->hasAsync (),
277
+ newDecl->hasThrows (), /* genericParams=*/ nullptr , fixedParams,
278
+ fixedResultType, newDecl->getDeclContext ());
279
+ newFnDecl->copyFormalAccessFrom (newDecl);
280
+ newFnDecl->setBodySynthesizer (synthesizeDependentTypeThunkParamForwarding, newDecl);
281
+ newFnDecl->setSelfAccessKind (newDecl->getSelfAccessKind ());
282
+ newFnDecl->getAttrs ().add (
283
+ new (newDecl->getASTContext ()) TransparentAttr (/* IsImplicit=*/ true ));
284
+ return newFnDecl;
221
285
}
222
286
223
287
// Synthesizes the body of a thunk that takes extra metatype arguments and
@@ -268,16 +332,55 @@ static ValueDecl *generateThunkForExtraMetatypes(SubstitutionMap subst,
268
332
// specialization, which are no longer now that we've specialized
269
333
// this function. Create a thunk that only forwards the original
270
334
// parameters along to the clang function.
271
- auto thunkTypeAndParamList = substituteFunctionTypeAndParamList (oldDecl->getASTContext (),
272
- oldDecl, subst);
335
+ SmallVector<ParamDecl *, 4 > newParams;
336
+
337
+ for (auto param : *newDecl->getParameters ()) {
338
+ auto *newParamDecl = ParamDecl::clone (newDecl->getASTContext (), param);
339
+ newParams.push_back (newParamDecl);
340
+ }
341
+
342
+ auto originalFnSubst = cast<AbstractFunctionDecl>(oldDecl)
343
+ ->getInterfaceType ()
344
+ ->getAs <GenericFunctionType>()
345
+ ->substGenericArgs (subst);
346
+ // The constructor type is a function type as follows:
347
+ // (CType.Type) -> (Generic) -> CType
348
+ // And a method's function type is as follows:
349
+ // (inout CType) -> (Generic) -> Void
350
+ // In either case, we only want the result of that function type because that
351
+ // is the function type with the generic params that need to be substituted:
352
+ // (Generic) -> CType
353
+ if (isa<ConstructorDecl>(oldDecl) || oldDecl->isInstanceMember () ||
354
+ oldDecl->isStatic ())
355
+ originalFnSubst = cast<FunctionType>(originalFnSubst->getResult ().getPointer ());
356
+
357
+ for (auto paramTy : originalFnSubst->getParams ()) {
358
+ if (!paramTy.getPlainType ()->is <MetatypeType>())
359
+ continue ;
360
+
361
+ auto dc = newDecl->getDeclContext ();
362
+ auto paramVarDecl =
363
+ new (newDecl->getASTContext ()) ParamDecl (
364
+ SourceLoc (), SourceLoc (), Identifier (), SourceLoc (),
365
+ newDecl->getASTContext ().getIdentifier (" _" ), dc);
366
+ paramVarDecl->setInterfaceType (paramTy.getPlainType ());
367
+ paramVarDecl->setSpecifier (ParamSpecifier::Default);
368
+ newParams.push_back (paramVarDecl);
369
+ }
370
+
371
+ auto *newParamList =
372
+ ParameterList::create (newDecl->getASTContext (), SourceLoc (), newParams, SourceLoc ());
373
+
273
374
auto thunk = FuncDecl::createImplicit (
274
- oldDecl ->getASTContext (), oldDecl ->getStaticSpelling (), oldDecl->getName (),
275
- oldDecl ->getNameLoc (), oldDecl ->hasAsync (), oldDecl ->hasThrows (),
276
- /* genericParams=*/ nullptr , thunkTypeAndParamList. second ,
277
- thunkTypeAndParamList. first -> getResult (), oldDecl ->getDeclContext ());
278
- thunk->copyFormalAccessFrom (oldDecl );
375
+ newDecl ->getASTContext (), newDecl ->getStaticSpelling (), oldDecl->getName (),
376
+ newDecl ->getNameLoc (), newDecl ->hasAsync (), newDecl ->hasThrows (),
377
+ /* genericParams=*/ nullptr , newParamList ,
378
+ newDecl-> getResultInterfaceType (), newDecl ->getDeclContext ());
379
+ thunk->copyFormalAccessFrom (newDecl );
279
380
thunk->setBodySynthesizer (synthesizeForwardingThunkBody, newDecl);
280
- thunk->setSelfAccessKind (oldDecl->getSelfAccessKind ());
381
+ thunk->setSelfAccessKind (newDecl->getSelfAccessKind ());
382
+ thunk->getAttrs ().add (
383
+ new (newDecl->getASTContext ()) TransparentAttr (/* IsImplicit=*/ true ));
281
384
282
385
return thunk;
283
386
}
@@ -327,6 +430,10 @@ static ConcreteDeclRef getCXXFunctionTemplateSpecialization(SubstitutionMap subs
327
430
}
328
431
}
329
432
433
+ if (auto fn = dyn_cast<FuncDecl>(decl)) {
434
+ newDecl = addThunkForDependentTypes (fn, cast<FuncDecl>(newDecl));
435
+ }
436
+
330
437
if (auto fn = dyn_cast<FuncDecl>(decl)) {
331
438
if (newFn->getNumParams () != fn->getParameters ()->size ()) {
332
439
newDecl = generateThunkForExtraMetatypes (subst, fn,
0 commit comments