@@ -249,6 +249,212 @@ CanSILFunctionType buildThunkType(SILFunction *fn,
249
249
fn->getASTContext ());
250
250
}
251
251
252
+ SILFunction *getOrCreateReabstractionThunk (SILOptFunctionBuilder &fb,
253
+ SILModule &module , SILLocation loc,
254
+ SILFunction *caller,
255
+ CanSILFunctionType fromType,
256
+ CanSILFunctionType toType) {
257
+ assert (!fromType->getCombinedSubstitutions ());
258
+ assert (!toType->getCombinedSubstitutions ());
259
+
260
+ SubstitutionMap interfaceSubs;
261
+ GenericEnvironment *genericEnv = nullptr ;
262
+ auto thunkType =
263
+ buildThunkType (caller, fromType, toType, genericEnv, interfaceSubs,
264
+ /* withoutActuallyEscaping*/ false ,
265
+ DifferentiationThunkKind::Reabstraction);
266
+ auto thunkDeclType =
267
+ thunkType->getWithExtInfo (thunkType->getExtInfo ().withNoEscape (false ));
268
+
269
+ auto fromInterfaceType = fromType->mapTypeOutOfContext ()->getCanonicalType ();
270
+ auto toInterfaceType = toType->mapTypeOutOfContext ()->getCanonicalType ();
271
+
272
+ Mangle::ASTMangler mangler;
273
+ std::string name = mangler.mangleReabstractionThunkHelper (
274
+ thunkType, fromInterfaceType, toInterfaceType, Type (),
275
+ module .getSwiftModule ());
276
+
277
+ // FIXME(TF-989): Mark reabstraction thunks as transparent. This requires
278
+ // generating ossa reabstraction thunks so that they can be inlined during
279
+ // mandatory inlining when `-enable-strip-ownership-after-serialization` is
280
+ // true and ownership model eliminator is not run after differentiation.
281
+ auto *thunk = fb.getOrCreateSharedFunction (
282
+ loc, name, thunkDeclType, IsBare, IsNotTransparent, IsSerialized,
283
+ ProfileCounter (), IsReabstractionThunk, IsNotDynamic);
284
+ if (!thunk->empty ())
285
+ return thunk;
286
+
287
+ thunk->setGenericEnvironment (genericEnv);
288
+ thunk->setOwnershipEliminated ();
289
+ auto *entry = thunk->createBasicBlock ();
290
+ SILBuilder builder (entry);
291
+ createEntryArguments (thunk);
292
+
293
+ SILFunctionConventions fromConv (fromType, module );
294
+ SILFunctionConventions toConv (toType, module );
295
+ assert (toConv.useLoweredAddresses ());
296
+
297
+ auto *fnArg = thunk->getArgumentsWithoutIndirectResults ().back ();
298
+
299
+ SmallVector<SILValue, 4 > arguments;
300
+ auto toArgIter = thunk->getArguments ().begin ();
301
+ auto useNextArgument = [&]() { arguments.push_back (*toArgIter++); };
302
+
303
+ SmallVector<AllocStackInst *, 4 > localAllocations;
304
+ auto createAllocStack = [&](SILType type) {
305
+ auto *alloc = builder.createAllocStack (loc, type);
306
+ localAllocations.push_back (alloc);
307
+ return alloc;
308
+ };
309
+
310
+ // Handle indirect results.
311
+ assert (fromType->getNumResults () == toType->getNumResults ());
312
+ for (unsigned resIdx : range (toType->getNumResults ())) {
313
+ auto fromRes = fromConv.getResults ()[resIdx];
314
+ auto toRes = toConv.getResults ()[resIdx];
315
+ // No abstraction mismatch.
316
+ if (fromRes.isFormalIndirect () == toRes.isFormalIndirect ()) {
317
+ // If result types are indirect, directly pass as next argument.
318
+ if (toRes.isFormalIndirect ())
319
+ useNextArgument ();
320
+ continue ;
321
+ }
322
+ // Convert indirect result to direct result.
323
+ if (fromRes.isFormalIndirect ()) {
324
+ SILType resultTy = fromConv.getSILType (fromRes);
325
+ assert (resultTy.isAddress ());
326
+ auto *indRes = createAllocStack (resultTy);
327
+ arguments.push_back (indRes);
328
+ continue ;
329
+ }
330
+ // Convert direct result to indirect result.
331
+ // Increment thunk argument iterator; reabstraction handled later.
332
+ toArgIter++;
333
+ }
334
+
335
+ // Reabstract parameters.
336
+ assert (toType->getNumParameters () == fromType->getNumParameters ());
337
+ for (unsigned paramIdx : range (toType->getNumParameters ())) {
338
+ auto fromParam = fromConv.getParameters ()[paramIdx];
339
+ auto toParam = toConv.getParameters ()[paramIdx];
340
+ // No abstraction mismatch. Directly use next argument.
341
+ if (fromParam.isFormalIndirect () == toParam.isFormalIndirect ()) {
342
+ useNextArgument ();
343
+ continue ;
344
+ }
345
+ // Convert indirect parameter to direct parameter.
346
+ if (fromParam.isFormalIndirect ()) {
347
+ auto paramTy = fromConv.getSILType (fromType->getParameters ()[paramIdx]);
348
+ if (!paramTy.hasArchetype ())
349
+ paramTy = thunk->mapTypeIntoContext (paramTy);
350
+ assert (paramTy.isAddress ());
351
+ auto *toArg = *toArgIter++;
352
+ auto *buf = createAllocStack (toArg->getType ());
353
+ builder.createStore (loc, toArg, buf,
354
+ StoreOwnershipQualifier::Unqualified);
355
+ arguments.push_back (buf);
356
+ continue ;
357
+ }
358
+ // Convert direct parameter to indirect parameter.
359
+ assert (toParam.isFormalIndirect ());
360
+ auto *toArg = *toArgIter++;
361
+ auto *load =
362
+ builder.createLoad (loc, toArg, LoadOwnershipQualifier::Unqualified);
363
+ arguments.push_back (load);
364
+ }
365
+
366
+ auto *apply = builder.createApply (loc, fnArg, SubstitutionMap (), arguments,
367
+ /* isNonThrowing*/ false );
368
+
369
+ // Get return elements.
370
+ SmallVector<SILValue, 4 > results;
371
+ // Extract all direct results.
372
+ SmallVector<SILValue, 4 > directResults;
373
+ extractAllElements (apply, builder, directResults);
374
+
375
+ auto fromDirResultsIter = directResults.begin ();
376
+ auto fromIndResultsIter = apply->getIndirectSILResults ().begin ();
377
+ auto toIndResultsIter = thunk->getIndirectResults ().begin ();
378
+ // Reabstract results.
379
+ for (unsigned resIdx : range (toType->getNumResults ())) {
380
+ auto fromRes = fromConv.getResults ()[resIdx];
381
+ auto toRes = toConv.getResults ()[resIdx];
382
+ // No abstraction mismatch.
383
+ if (fromRes.isFormalIndirect () == toRes.isFormalIndirect ()) {
384
+ // If result types are direct, add call result as direct thunk result.
385
+ if (toRes.isFormalDirect ())
386
+ results.push_back (*fromDirResultsIter++);
387
+ // If result types are indirect, increment indirect result iterators.
388
+ else {
389
+ ++fromIndResultsIter;
390
+ ++toIndResultsIter;
391
+ }
392
+ continue ;
393
+ }
394
+ // Load direct results from indirect results.
395
+ if (fromRes.isFormalIndirect ()) {
396
+ auto indRes = *fromIndResultsIter++;
397
+ auto *load =
398
+ builder.createLoad (loc, indRes, LoadOwnershipQualifier::Unqualified);
399
+ results.push_back (load);
400
+ continue ;
401
+ }
402
+ // Store direct results to indirect results.
403
+ assert (toRes.isFormalIndirect ());
404
+ SILType resultTy = toConv.getSILType (toRes);
405
+ assert (resultTy.isAddress ());
406
+ auto indRes = *toIndResultsIter++;
407
+ builder.createStore (loc, *fromDirResultsIter++, indRes,
408
+ StoreOwnershipQualifier::Unqualified);
409
+ }
410
+ auto retVal = joinElements (results, builder, loc);
411
+
412
+ // Deallocate local allocations.
413
+ for (auto *alloc : llvm::reverse (localAllocations))
414
+ builder.createDeallocStack (loc, alloc);
415
+
416
+ // Create return.
417
+ builder.createReturn (loc, retVal);
418
+
419
+ LLVM_DEBUG (auto &s = getADDebugStream () << " Created reabstraction thunk.\n " ;
420
+ s << " From type: " << fromType << ' \n ' ;
421
+ s << " To type: " << toType << ' \n ' ; s << ' \n '
422
+ << *thunk);
423
+
424
+ return thunk;
425
+ }
426
+
427
+ SILValue reabstractFunction (
428
+ SILBuilder &builder, SILOptFunctionBuilder &fb, SILLocation loc,
429
+ SILValue fn, CanSILFunctionType toType,
430
+ std::function<SubstitutionMap(SubstitutionMap)> remapSubstitutions) {
431
+ auto &module = *fn->getModule ();
432
+ auto fromType = fn->getType ().getAs <SILFunctionType>();
433
+ auto unsubstFromType = fromType->getUnsubstitutedType (module );
434
+ auto unsubstToType = toType->getUnsubstitutedType (module );
435
+
436
+ auto *thunk = getOrCreateReabstractionThunk (fb, module , loc,
437
+ /* caller*/ fn->getFunction (),
438
+ unsubstFromType, unsubstToType);
439
+ auto *thunkRef = builder.createFunctionRef (loc, thunk);
440
+
441
+ if (fromType != unsubstFromType)
442
+ fn = builder.createConvertFunction (
443
+ loc, fn, SILType::getPrimitiveObjectType (unsubstFromType),
444
+ /* withoutActuallyEscaping*/ false );
445
+
446
+ fn = builder.createPartialApply (
447
+ loc, thunkRef, remapSubstitutions (thunk->getForwardingSubstitutionMap ()),
448
+ {fn}, fromType->getCalleeConvention ());
449
+
450
+ if (toType != unsubstToType)
451
+ fn = builder.createConvertFunction (loc, fn,
452
+ SILType::getPrimitiveObjectType (toType),
453
+ /* withoutActuallyEscaping*/ false );
454
+
455
+ return fn;
456
+ }
457
+
252
458
std::pair<SILFunction *, SubstitutionMap>
253
459
getOrCreateSubsetParametersThunkForLinearMap (
254
460
SILOptFunctionBuilder &fb, SILFunction *parentThunk,
0 commit comments