Skip to content

Commit a5e8381

Browse files
asldan-zheng
andauthored
Reapply "Implement several peephole optimizations to unblock further optimizations of autodiff code" with correctness fix (#62012)
* Implement several peephole optimizations to unblock further optimizations of autodiff code 1. Simplify differentiable_function_extract of differentiable_function. Before: %x = differentiable_function(%orig, %jvp, %vjp) %y = differentiable_function_extract [original] %x After: %y = %orig 2. Push conversion instructions inside of differentiable_function. This unblocks inlining and specialization. Before: %x = differentiable_function(%orig, %jvp, %vjp) %y = convert_escape_to_noescape %x After: %orig' = convert_escape_to_noescape %orig %jvp' = convert_escape_to_noescape %jvp %vjp' = convert_escape_to_noescape %vjp %y = differentiable_function(%orig', %jvp', %vjp') 3. Another peephole is needed for reordering function conversion instructions to enable full inlining: (convert_escape_to_noescape (convert_function (thin_to_thick_function x))) => (convert_escape_to_noescape (thin_to_thick_function (convert_function x))) Co-authored-by: Dan Zheng <[email protected]>
1 parent f11ba15 commit a5e8381

File tree

10 files changed

+514
-78
lines changed

10 files changed

+514
-78
lines changed

include/swift/AST/Types.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4699,6 +4699,17 @@ class SILFunctionType final
46994699
/// differentiability from all parameters.
47004700
CanSILFunctionType getWithoutDifferentiability();
47014701

4702+
/// Given that `this` is a `@differentiable` function type, returns the type
4703+
/// of the given `@differentiable` function type component.
4704+
CanSILFunctionType getDifferentiableComponentType(
4705+
NormalDifferentiableFunctionTypeComponent component, SILModule &module);
4706+
4707+
/// Given that `this` is a `@differentiable(linear)` function type, returns
4708+
/// the type of the given `@differentiable(linear)` function type component.
4709+
CanSILFunctionType
4710+
getLinearComponentType(LinearDifferentiableFunctionTypeComponent component,
4711+
SILModule &module);
4712+
47024713
/// Returns the type of the derivative function for the given parameter
47034714
/// indices, result indices, derivative function kind, derivative function
47044715
/// generic signature (optional), and other auxiliary parameters.

include/swift/SIL/SILInstruction.h

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9628,6 +9628,34 @@ class DifferentiableFunctionInst final
96289628
}
96299629
llvm_unreachable("invalid derivative kind");
96309630
}
9631+
9632+
9633+
/// Returns true iff the operand corresponding to the given extractee kind
9634+
/// exists.
9635+
bool hasExtractee(NormalDifferentiableFunctionTypeComponent extractee) const {
9636+
switch (extractee) {
9637+
case NormalDifferentiableFunctionTypeComponent::Original:
9638+
return true;
9639+
case NormalDifferentiableFunctionTypeComponent::JVP:
9640+
case NormalDifferentiableFunctionTypeComponent::VJP:
9641+
return hasDerivativeFunctions();
9642+
}
9643+
llvm_unreachable("invalid extractee kind");
9644+
}
9645+
9646+
/// Returns the operand corresponding to the given extractee kind.
9647+
SILValue
9648+
getExtractee(NormalDifferentiableFunctionTypeComponent extractee) const {
9649+
switch (extractee) {
9650+
case NormalDifferentiableFunctionTypeComponent::Original:
9651+
return getOriginalFunction();
9652+
case NormalDifferentiableFunctionTypeComponent::JVP:
9653+
return getJVPFunction();
9654+
case NormalDifferentiableFunctionTypeComponent::VJP:
9655+
return getVJPFunction();
9656+
}
9657+
llvm_unreachable("invalid extractee kind");
9658+
}
96319659
};
96329660

96339661
/// LinearFunctionInst - given a function, its derivative and transpose functions,
@@ -9668,6 +9696,31 @@ class LinearFunctionInst final
96689696
assert(HasTransposeFunction);
96699697
return getOperand(1);
96709698
}
9699+
9700+
9701+
/// Returns true iff the operand corresponding to the given extractee kind
9702+
/// exists.
9703+
bool hasExtractee(LinearDifferentiableFunctionTypeComponent extractee) const {
9704+
switch (extractee) {
9705+
case LinearDifferentiableFunctionTypeComponent::Original:
9706+
return true;
9707+
case LinearDifferentiableFunctionTypeComponent::Transpose:
9708+
return hasTransposeFunction();
9709+
}
9710+
llvm_unreachable("invalid extractee kind");
9711+
}
9712+
9713+
/// Returns the operand corresponding to the given extractee kind.
9714+
SILValue
9715+
getExtractee(LinearDifferentiableFunctionTypeComponent extractee) const {
9716+
switch (extractee) {
9717+
case LinearDifferentiableFunctionTypeComponent::Original:
9718+
return getOriginalFunction();
9719+
case LinearDifferentiableFunctionTypeComponent::Transpose:
9720+
return getTransposeFunction();
9721+
}
9722+
llvm_unreachable("invalid extractee kind");
9723+
}
96719724
};
96729725

96739726
/// DifferentiableFunctionExtractInst - extracts either the original or

lib/SIL/IR/SILFunctionType.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,35 @@ IndexSubset *SILFunctionType::getDifferentiabilityResultIndices() {
263263
return IndexSubset::get(getASTContext(), numSemanticResults, resultIndices);
264264
}
265265

266+
CanSILFunctionType SILFunctionType::getDifferentiableComponentType(
267+
NormalDifferentiableFunctionTypeComponent component, SILModule &module) {
268+
assert(getDifferentiabilityKind() == DifferentiabilityKind::Reverse &&
269+
"Must be a `@differentiable(reverse)` function");
270+
auto originalFnTy = getWithoutDifferentiability();
271+
if (auto derivativeKind = component.getAsDerivativeFunctionKind()) {
272+
return originalFnTy->getAutoDiffDerivativeFunctionType(
273+
getDifferentiabilityParameterIndices(),
274+
getDifferentiabilityResultIndices(), *derivativeKind, module.Types,
275+
LookUpConformanceInModule(module.getSwiftModule()));
276+
}
277+
return originalFnTy;
278+
}
279+
280+
CanSILFunctionType SILFunctionType::getLinearComponentType(
281+
LinearDifferentiableFunctionTypeComponent component, SILModule &module) {
282+
assert(getDifferentiabilityKind() == DifferentiabilityKind::Linear &&
283+
"Must be a `@differentiable(linear)` function");
284+
auto originalFnTy = getWithoutDifferentiability();
285+
switch (component) {
286+
case LinearDifferentiableFunctionTypeComponent::Original:
287+
return originalFnTy;
288+
case LinearDifferentiableFunctionTypeComponent::Transpose:
289+
return originalFnTy->getAutoDiffTransposeFunctionType(
290+
getDifferentiabilityParameterIndices(), module.Types,
291+
LookUpConformanceInModule(module.getSwiftModule()));
292+
}
293+
}
294+
266295
CanSILFunctionType
267296
SILFunctionType::getWithDifferentiability(DifferentiabilityKind kind,
268297
IndexSubset *parameterIndices,

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 0 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,6 @@ using llvm::SmallSet;
6262
static llvm::cl::opt<bool> EnableExperimentalLinearMapTransposition(
6363
"enable-experimental-linear-map-transposition", llvm::cl::init(false));
6464

65-
/// This flag is used to disable `differentiable_function_extract` instruction
66-
/// folding for SIL testing purposes.
67-
static llvm::cl::opt<bool> SkipFoldingDifferentiableFunctionExtraction(
68-
"differentiation-skip-folding-differentiable-function-extraction",
69-
llvm::cl::init(true));
70-
7165
//===----------------------------------------------------------------------===//
7266
// Helpers
7367
//===----------------------------------------------------------------------===//
@@ -127,17 +121,6 @@ class DifferentiationTransformer {
127121
/// Process the given `linear_function` instruction, filling in the missing
128122
/// transpose function if necessary.
129123
bool processLinearFunctionInst(LinearFunctionInst *lfi);
130-
131-
/// Fold `differentiable_function_extract` users of the given
132-
/// `differentiable_function` instruction, directly replacing them with
133-
/// `differentiable_function` instruction operands. If the
134-
/// `differentiable_function` instruction has no remaining uses, delete the
135-
/// instruction itself after folding.
136-
///
137-
/// Folding can be disabled by the
138-
/// `SkipFoldingDifferentiableFunctionExtraction` flag for SIL testing
139-
/// purposes.
140-
void foldDifferentiableFunctionExtraction(DifferentiableFunctionInst *source);
141124
};
142125

143126
} // end anonymous namespace
@@ -1239,51 +1222,6 @@ SILValue DifferentiationTransformer::promoteToLinearFunction(
12391222
return newLinearFn;
12401223
}
12411224

1242-
/// Fold `differentiable_function_extract` users of the given
1243-
/// `differentiable_function` instruction, directly replacing them with
1244-
/// `differentiable_function` instruction operands. If the
1245-
/// `differentiable_function` instruction has no remaining uses, delete the
1246-
/// instruction itself after folding.
1247-
///
1248-
/// Folding can be disabled by the `SkipFoldingDifferentiableFunctionExtraction`
1249-
/// flag for SIL testing purposes.
1250-
// FIXME: This function is not correctly detecting the foldable pattern and
1251-
// needs to be rewritten.
1252-
void DifferentiationTransformer::foldDifferentiableFunctionExtraction(
1253-
DifferentiableFunctionInst *source) {
1254-
// Iterate through all `differentiable_function` instruction uses.
1255-
for (auto use : source->getUses()) {
1256-
auto *dfei = dyn_cast<DifferentiableFunctionExtractInst>(use->getUser());
1257-
// If user is not an `differentiable_function_extract` instruction, set flag
1258-
// to false.
1259-
if (!dfei)
1260-
continue;
1261-
// Fold original function extractors.
1262-
if (dfei->getExtractee() ==
1263-
NormalDifferentiableFunctionTypeComponent::Original) {
1264-
auto originalFnValue = source->getOriginalFunction();
1265-
dfei->replaceAllUsesWith(originalFnValue);
1266-
dfei->eraseFromParent();
1267-
continue;
1268-
}
1269-
// Fold derivative function extractors.
1270-
auto derivativeFnValue =
1271-
source->getDerivativeFunction(dfei->getDerivativeFunctionKind());
1272-
dfei->replaceAllUsesWith(derivativeFnValue);
1273-
dfei->eraseFromParent();
1274-
}
1275-
// If the `differentiable_function` instruction has no remaining uses, erase
1276-
// it.
1277-
if (isInstructionTriviallyDead(source)) {
1278-
SILBuilder builder(source);
1279-
builder.emitDestroyAddrAndFold(source->getLoc(), source->getJVPFunction());
1280-
builder.emitDestroyAddrAndFold(source->getLoc(), source->getVJPFunction());
1281-
source->eraseFromParent();
1282-
}
1283-
// Mark `source` as processed so that it won't be reprocessed after deletion.
1284-
context.markDifferentiableFunctionInstAsProcessed(source);
1285-
}
1286-
12871225
bool DifferentiationTransformer::processDifferentiableFunctionInst(
12881226
DifferentiableFunctionInst *dfi) {
12891227
PrettyStackTraceSILNode dfiTrace("canonicalizing `differentiable_function`",
@@ -1312,14 +1250,6 @@ bool DifferentiationTransformer::processDifferentiableFunctionInst(
13121250
// Destroy the original operand.
13131251
builder.emitDestroyValueOperation(loc, dfi->getOriginalFunction());
13141252
dfi->eraseFromParent();
1315-
// If the promoted `@differentiable` function-typed value is an
1316-
// `differentiable_function` instruction, fold
1317-
// `differentiable_function_extract` instructions. If
1318-
// `differentiable_function_extract` folding is disabled, return.
1319-
if (!SkipFoldingDifferentiableFunctionExtraction)
1320-
if (auto *newDFI =
1321-
dyn_cast<DifferentiableFunctionInst>(differentiableFnValue))
1322-
foldDifferentiableFunctionExtraction(newDFI);
13231253
transform.invalidateAnalysis(parent,
13241254
SILAnalysis::InvalidationKind::FunctionBody);
13251255
return false;

lib/SILOptimizer/SILCombiner/SILCombiner.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,9 @@ class SILCombiner :
294294
SILInstruction *visitConvertFunctionInst(ConvertFunctionInst *CFI);
295295
SILInstruction *
296296
visitConvertEscapeToNoEscapeInst(ConvertEscapeToNoEscapeInst *Cvt);
297-
297+
SILInstruction *
298+
visitDifferentiableFunctionExtractInst(DifferentiableFunctionExtractInst *DFEI);
299+
298300
SILInstruction *legacyVisitGlobalValueInst(GlobalValueInst *globalValue);
299301

300302
#define PASS(ID, TAG, DESCRIPTION)

lib/SILOptimizer/SILCombiner/SILCombinerCastVisitors.cpp

Lines changed: 133 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,16 +1065,94 @@ visitCheckedCastAddrBranchInst(CheckedCastAddrBranchInst *CCABI) {
10651065

10661066
SILInstruction *SILCombiner::visitConvertEscapeToNoEscapeInst(
10671067
ConvertEscapeToNoEscapeInst *Cvt) {
1068-
auto *OrigThinToThick =
1069-
dyn_cast<ThinToThickFunctionInst>(Cvt->getConverted());
1070-
if (!OrigThinToThick)
1071-
return nullptr;
1072-
auto origFunType = OrigThinToThick->getType().getAs<SILFunctionType>();
1073-
auto NewTy = origFunType->getWithExtInfo(origFunType->getExtInfo().withNoEscape(true));
1068+
// Rewrite conversion of `convert_function` of `thin_to_thick_function` as
1069+
// conversion of `thin_to_thick_function` of `convert_function`.
1070+
//
1071+
// (convert_escape_to_noescape (convert_function (thin_to_thick_function x)))
1072+
// =>
1073+
// (convert_escape_to_noescape (thin_to_thick_function (convert_function x)))
1074+
//
1075+
// This unblocks the `thin_to_thick_function` peephole optimization below.
1076+
if (auto *CFI = dyn_cast<ConvertFunctionInst>(Cvt->getConverted())) {
1077+
if (CFI->getSingleUse()) {
1078+
if (auto *TTTFI = dyn_cast<ThinToThickFunctionInst>(CFI->getConverted())) {
1079+
if (TTTFI->getSingleUse()) {
1080+
auto convertedThickType = CFI->getType().castTo<SILFunctionType>();
1081+
auto convertedThinType = convertedThickType->getWithRepresentation(
1082+
SILFunctionTypeRepresentation::Thin);
1083+
auto *newCFI = Builder.createConvertFunction(
1084+
CFI->getLoc(), TTTFI->getConverted(),
1085+
SILType::getPrimitiveObjectType(convertedThinType),
1086+
CFI->withoutActuallyEscaping());
1087+
auto *newTTTFI = Builder.createThinToThickFunction(
1088+
TTTFI->getLoc(), newCFI, CFI->getType());
1089+
replaceInstUsesWith(*CFI, newTTTFI);
1090+
}
1091+
}
1092+
}
1093+
}
1094+
1095+
// Rewrite conversion of `thin_to_thick_function` as `thin_to_thick_function`
1096+
// with a noescape function type.
1097+
//
1098+
// (convert_escape_to_noescape (thin_to_thick_function x)) =>
1099+
// (thin_to_thick_function [noescape] x)
1100+
if (auto *OrigThinToThick = dyn_cast<ThinToThickFunctionInst>(Cvt->getConverted())) {
1101+
auto origFunType = OrigThinToThick->getType().getAs<SILFunctionType>();
1102+
auto NewTy = origFunType->getWithExtInfo(origFunType->getExtInfo().withNoEscape(true));
10741103

1075-
return Builder.createThinToThickFunction(
1104+
return Builder.createThinToThickFunction(
10761105
OrigThinToThick->getLoc(), OrigThinToThick->getOperand(),
10771106
SILType::getPrimitiveObjectType(NewTy));
1107+
}
1108+
1109+
// Push conversion instructions inside `differentiable_function`. This
1110+
// unblocks more optimizations.
1111+
//
1112+
// Before:
1113+
// %x = differentiable_function(%orig, %jvp, %vjp)
1114+
// %y = convert_escape_to_noescape %x
1115+
//
1116+
// After:
1117+
// %orig' = convert_escape_to_noescape %orig
1118+
// %jvp' = convert_escape_to_noescape %jvp
1119+
// %vjp' = convert_escape_to_noescape %vjp
1120+
// %y = differentiable_function(%orig', %jvp', %vjp')
1121+
if (auto *DFI = dyn_cast<DifferentiableFunctionInst>(Cvt->getConverted())) {
1122+
auto createConvertEscapeToNoEscape = [&](NormalDifferentiableFunctionTypeComponent extractee) {
1123+
if (!DFI->hasExtractee(extractee))
1124+
return SILValue();
1125+
1126+
auto operand = DFI->getExtractee(extractee);
1127+
auto fnType = operand->getType().castTo<SILFunctionType>();
1128+
auto noEscapeFnType =
1129+
fnType->getWithExtInfo(fnType->getExtInfo().withNoEscape());
1130+
auto noEscapeType = SILType::getPrimitiveObjectType(noEscapeFnType);
1131+
return Builder.createConvertEscapeToNoEscape(
1132+
operand.getLoc(), operand, noEscapeType, Cvt->isLifetimeGuaranteed())->getResult(0);
1133+
};
1134+
1135+
SILValue originalNoEscape =
1136+
createConvertEscapeToNoEscape(NormalDifferentiableFunctionTypeComponent::Original);
1137+
SILValue convertedJVP = createConvertEscapeToNoEscape(
1138+
NormalDifferentiableFunctionTypeComponent::JVP);
1139+
SILValue convertedVJP = createConvertEscapeToNoEscape(
1140+
NormalDifferentiableFunctionTypeComponent::VJP);
1141+
1142+
Optional<std::pair<SILValue, SILValue>> derivativeFunctions;
1143+
if (convertedJVP && convertedVJP)
1144+
derivativeFunctions = std::make_pair(convertedJVP, convertedVJP);
1145+
1146+
auto *newDFI = Builder.createDifferentiableFunction(
1147+
DFI->getLoc(), DFI->getParameterIndices(), DFI->getResultIndices(),
1148+
originalNoEscape, derivativeFunctions);
1149+
assert(newDFI->getType() == Cvt->getType() &&
1150+
"New `@differentiable` function instruction should have same type "
1151+
"as the old `convert_escape_to_no_escape` instruction");
1152+
return newDFI;
1153+
}
1154+
1155+
return nullptr;
10781156
}
10791157

10801158
SILInstruction *
@@ -1207,6 +1285,54 @@ SILCombiner::visitConvertFunctionInst(ConvertFunctionInst *cfi) {
12071285
return std::move(folder).optimizeWithSetValue(subCFI->getConverted());
12081286
}
12091287

1288+
// Push conversion instructions inside `differentiable_function`. This
1289+
// unblocks more optimizations.
1290+
//
1291+
// Before:
1292+
// %x = differentiable_function(%orig, %jvp, %vjp)
1293+
// %y = convert_function %x
1294+
//
1295+
// After:
1296+
// %orig' = convert_function %orig
1297+
// %jvp' = convert_function %jvp
1298+
// %vjp' = convert_function %vjp
1299+
// %y = differentiable_function(%orig', %jvp', %vjp')
1300+
if (auto *DFI = dyn_cast<DifferentiableFunctionInst>(cfi->getConverted())) {
1301+
auto createConvertFunctionOfComponent =
1302+
[&](NormalDifferentiableFunctionTypeComponent extractee) {
1303+
if (!DFI->hasExtractee(extractee))
1304+
return SILValue();
1305+
1306+
auto operand = DFI->getExtractee(extractee);
1307+
auto convertInstType =
1308+
cfi->getType().castTo<SILFunctionType>();
1309+
auto convertedComponentFnType =
1310+
convertInstType->getDifferentiableComponentType(
1311+
extractee, Builder.getModule());
1312+
auto convertedComponentType =
1313+
SILType::getPrimitiveObjectType(convertedComponentFnType);
1314+
return Builder.createConvertFunction(
1315+
operand.getLoc(), operand, convertedComponentType,
1316+
cfi->withoutActuallyEscaping())->getResult(0);
1317+
};
1318+
SILValue convertedOriginal = createConvertFunctionOfComponent(
1319+
NormalDifferentiableFunctionTypeComponent::Original);
1320+
SILValue convertedJVP = createConvertFunctionOfComponent(
1321+
NormalDifferentiableFunctionTypeComponent::JVP);
1322+
SILValue convertedVJP = createConvertFunctionOfComponent(
1323+
NormalDifferentiableFunctionTypeComponent::VJP);
1324+
Optional<std::pair<SILValue, SILValue>> derivativeFunctions;
1325+
if (convertedJVP && convertedVJP)
1326+
derivativeFunctions = std::make_pair(convertedJVP, convertedVJP);
1327+
auto *newDFI = Builder.createDifferentiableFunction(
1328+
DFI->getLoc(), DFI->getParameterIndices(), DFI->getResultIndices(),
1329+
convertedOriginal, derivativeFunctions);
1330+
assert(newDFI->getType() == cfi->getType() &&
1331+
"New `@differentiable` function instruction should have same type "
1332+
"as the old `convert_function` instruction");
1333+
return newDFI;
1334+
}
1335+
12101336
// Replace a convert_function that only has refcounting uses with its
12111337
// operand.
12121338
tryEliminateOnlyOwnershipUsedForwardingInst(cfi, getInstModCallbacks());

0 commit comments

Comments
 (0)