Skip to content

Commit 55ac2c0

Browse files
committed
[AutoDiff upstream] Add common differentiation thunking utilities.
1 parent 8081482 commit 55ac2c0

File tree

2 files changed

+221
-0
lines changed
  • include/swift/SILOptimizer/Utils/Differentiation
  • lib/SILOptimizer/Utils/Differentiation

2 files changed

+221
-0
lines changed

include/swift/SILOptimizer/Utils/Differentiation/Thunk.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,21 @@ CanSILFunctionType buildThunkType(SILFunction *fn,
8181
bool withoutActuallyEscaping,
8282
DifferentiationThunkKind thunkKind);
8383

84+
/// Get or create a reabstraction thunk from `fromType` to `toType`, to be
85+
/// called in `caller`.
86+
SILFunction *getOrCreateReabstractionThunk(SILOptFunctionBuilder &fb,
87+
SILModule &module, SILLocation loc,
88+
SILFunction *caller,
89+
CanSILFunctionType fromType,
90+
CanSILFunctionType toType);
91+
92+
/// Reabstracts the given function-typed value `fn` to the target type `toType`.
93+
/// Remaps substitutions using `remapSubstitutions`.
94+
SILValue reabstractFunction(
95+
SILBuilder &builder, SILOptFunctionBuilder &fb, SILLocation loc,
96+
SILValue fn, CanSILFunctionType toType,
97+
std::function<SubstitutionMap(SubstitutionMap)> remapSubstitutions);
98+
8499
/// Get or create a derivative function parameter index subset thunk from
85100
/// `actualIndices` to `desiredIndices` for the given associated function
86101
/// value and original function operand. Returns a pair of the parameter

lib/SILOptimizer/Utils/Differentiation/Thunk.cpp

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,212 @@ CanSILFunctionType buildThunkType(SILFunction *fn,
249249
fn->getASTContext());
250250
}
251251

252+
SILFunction *getOrCreateReabstractionThunk(SILOptFunctionBuilder &fb,
253+
SILModule &module, SILLocation loc,
254+
SILFunction *caller,
255+
CanSILFunctionType fromType,
256+
CanSILFunctionType toType) {
257+
assert(!fromType->getCombinedSubstitutions());
258+
assert(!toType->getCombinedSubstitutions());
259+
260+
SubstitutionMap interfaceSubs;
261+
GenericEnvironment *genericEnv = nullptr;
262+
auto thunkType =
263+
buildThunkType(caller, fromType, toType, genericEnv, interfaceSubs,
264+
/*withoutActuallyEscaping*/ false,
265+
DifferentiationThunkKind::Reabstraction);
266+
auto thunkDeclType =
267+
thunkType->getWithExtInfo(thunkType->getExtInfo().withNoEscape(false));
268+
269+
auto fromInterfaceType = fromType->mapTypeOutOfContext()->getCanonicalType();
270+
auto toInterfaceType = toType->mapTypeOutOfContext()->getCanonicalType();
271+
272+
Mangle::ASTMangler mangler;
273+
std::string name = mangler.mangleReabstractionThunkHelper(
274+
thunkType, fromInterfaceType, toInterfaceType, Type(),
275+
module.getSwiftModule());
276+
277+
// FIXME(TF-989): Mark reabstraction thunks as transparent. This requires
278+
// generating ossa reabstraction thunks so that they can be inlined during
279+
// mandatory inlining when `-enable-strip-ownership-after-serialization` is
280+
// true and ownership model eliminator is not run after differentiation.
281+
auto *thunk = fb.getOrCreateSharedFunction(
282+
loc, name, thunkDeclType, IsBare, IsNotTransparent, IsSerialized,
283+
ProfileCounter(), IsReabstractionThunk, IsNotDynamic);
284+
if (!thunk->empty())
285+
return thunk;
286+
287+
thunk->setGenericEnvironment(genericEnv);
288+
thunk->setOwnershipEliminated();
289+
auto *entry = thunk->createBasicBlock();
290+
SILBuilder builder(entry);
291+
createEntryArguments(thunk);
292+
293+
SILFunctionConventions fromConv(fromType, module);
294+
SILFunctionConventions toConv(toType, module);
295+
assert(toConv.useLoweredAddresses());
296+
297+
auto *fnArg = thunk->getArgumentsWithoutIndirectResults().back();
298+
299+
SmallVector<SILValue, 4> arguments;
300+
auto toArgIter = thunk->getArguments().begin();
301+
auto useNextArgument = [&]() { arguments.push_back(*toArgIter++); };
302+
303+
SmallVector<AllocStackInst *, 4> localAllocations;
304+
auto createAllocStack = [&](SILType type) {
305+
auto *alloc = builder.createAllocStack(loc, type);
306+
localAllocations.push_back(alloc);
307+
return alloc;
308+
};
309+
310+
// Handle indirect results.
311+
assert(fromType->getNumResults() == toType->getNumResults());
312+
for (unsigned resIdx : range(toType->getNumResults())) {
313+
auto fromRes = fromConv.getResults()[resIdx];
314+
auto toRes = toConv.getResults()[resIdx];
315+
// No abstraction mismatch.
316+
if (fromRes.isFormalIndirect() == toRes.isFormalIndirect()) {
317+
// If result types are indirect, directly pass as next argument.
318+
if (toRes.isFormalIndirect())
319+
useNextArgument();
320+
continue;
321+
}
322+
// Convert indirect result to direct result.
323+
if (fromRes.isFormalIndirect()) {
324+
SILType resultTy = fromConv.getSILType(fromRes);
325+
assert(resultTy.isAddress());
326+
auto *indRes = createAllocStack(resultTy);
327+
arguments.push_back(indRes);
328+
continue;
329+
}
330+
// Convert direct result to indirect result.
331+
// Increment thunk argument iterator; reabstraction handled later.
332+
toArgIter++;
333+
}
334+
335+
// Reabstract parameters.
336+
assert(toType->getNumParameters() == fromType->getNumParameters());
337+
for (unsigned paramIdx : range(toType->getNumParameters())) {
338+
auto fromParam = fromConv.getParameters()[paramIdx];
339+
auto toParam = toConv.getParameters()[paramIdx];
340+
// No abstraction mismatch. Directly use next argument.
341+
if (fromParam.isFormalIndirect() == toParam.isFormalIndirect()) {
342+
useNextArgument();
343+
continue;
344+
}
345+
// Convert indirect parameter to direct parameter.
346+
if (fromParam.isFormalIndirect()) {
347+
auto paramTy = fromConv.getSILType(fromType->getParameters()[paramIdx]);
348+
if (!paramTy.hasArchetype())
349+
paramTy = thunk->mapTypeIntoContext(paramTy);
350+
assert(paramTy.isAddress());
351+
auto *toArg = *toArgIter++;
352+
auto *buf = createAllocStack(toArg->getType());
353+
builder.createStore(loc, toArg, buf,
354+
StoreOwnershipQualifier::Unqualified);
355+
arguments.push_back(buf);
356+
continue;
357+
}
358+
// Convert direct parameter to indirect parameter.
359+
assert(toParam.isFormalIndirect());
360+
auto *toArg = *toArgIter++;
361+
auto *load =
362+
builder.createLoad(loc, toArg, LoadOwnershipQualifier::Unqualified);
363+
arguments.push_back(load);
364+
}
365+
366+
auto *apply = builder.createApply(loc, fnArg, SubstitutionMap(), arguments,
367+
/*isNonThrowing*/ false);
368+
369+
// Get return elements.
370+
SmallVector<SILValue, 4> results;
371+
// Extract all direct results.
372+
SmallVector<SILValue, 4> directResults;
373+
extractAllElements(apply, builder, directResults);
374+
375+
auto fromDirResultsIter = directResults.begin();
376+
auto fromIndResultsIter = apply->getIndirectSILResults().begin();
377+
auto toIndResultsIter = thunk->getIndirectResults().begin();
378+
// Reabstract results.
379+
for (unsigned resIdx : range(toType->getNumResults())) {
380+
auto fromRes = fromConv.getResults()[resIdx];
381+
auto toRes = toConv.getResults()[resIdx];
382+
// No abstraction mismatch.
383+
if (fromRes.isFormalIndirect() == toRes.isFormalIndirect()) {
384+
// If result types are direct, add call result as direct thunk result.
385+
if (toRes.isFormalDirect())
386+
results.push_back(*fromDirResultsIter++);
387+
// If result types are indirect, increment indirect result iterators.
388+
else {
389+
++fromIndResultsIter;
390+
++toIndResultsIter;
391+
}
392+
continue;
393+
}
394+
// Load direct results from indirect results.
395+
if (fromRes.isFormalIndirect()) {
396+
auto indRes = *fromIndResultsIter++;
397+
auto *load =
398+
builder.createLoad(loc, indRes, LoadOwnershipQualifier::Unqualified);
399+
results.push_back(load);
400+
continue;
401+
}
402+
// Store direct results to indirect results.
403+
assert(toRes.isFormalIndirect());
404+
SILType resultTy = toConv.getSILType(toRes);
405+
assert(resultTy.isAddress());
406+
auto indRes = *toIndResultsIter++;
407+
builder.createStore(loc, *fromDirResultsIter++, indRes,
408+
StoreOwnershipQualifier::Unqualified);
409+
}
410+
auto retVal = joinElements(results, builder, loc);
411+
412+
// Deallocate local allocations.
413+
for (auto *alloc : llvm::reverse(localAllocations))
414+
builder.createDeallocStack(loc, alloc);
415+
416+
// Create return.
417+
builder.createReturn(loc, retVal);
418+
419+
LLVM_DEBUG(auto &s = getADDebugStream() << "Created reabstraction thunk.\n";
420+
s << " From type: " << fromType << '\n';
421+
s << " To type: " << toType << '\n'; s << '\n'
422+
<< *thunk);
423+
424+
return thunk;
425+
}
426+
427+
SILValue reabstractFunction(
428+
SILBuilder &builder, SILOptFunctionBuilder &fb, SILLocation loc,
429+
SILValue fn, CanSILFunctionType toType,
430+
std::function<SubstitutionMap(SubstitutionMap)> remapSubstitutions) {
431+
auto &module = *fn->getModule();
432+
auto fromType = fn->getType().getAs<SILFunctionType>();
433+
auto unsubstFromType = fromType->getUnsubstitutedType(module);
434+
auto unsubstToType = toType->getUnsubstitutedType(module);
435+
436+
auto *thunk = getOrCreateReabstractionThunk(fb, module, loc,
437+
/*caller*/ fn->getFunction(),
438+
unsubstFromType, unsubstToType);
439+
auto *thunkRef = builder.createFunctionRef(loc, thunk);
440+
441+
if (fromType != unsubstFromType)
442+
fn = builder.createConvertFunction(
443+
loc, fn, SILType::getPrimitiveObjectType(unsubstFromType),
444+
/*withoutActuallyEscaping*/ false);
445+
446+
fn = builder.createPartialApply(
447+
loc, thunkRef, remapSubstitutions(thunk->getForwardingSubstitutionMap()),
448+
{fn}, fromType->getCalleeConvention());
449+
450+
if (toType != unsubstToType)
451+
fn = builder.createConvertFunction(loc, fn,
452+
SILType::getPrimitiveObjectType(toType),
453+
/*withoutActuallyEscaping*/ false);
454+
455+
return fn;
456+
}
457+
252458
std::pair<SILFunction *, SubstitutionMap>
253459
getOrCreateSubsetParametersThunkForLinearMap(
254460
SILOptFunctionBuilder &fb, SILFunction *parentThunk,

0 commit comments

Comments
 (0)