Skip to content

Commit 5e20112

Browse files
committed
Revert "Implement several peephole optimizations to unblock further optimizations of autodiff code (#60520)"
This reverts commit 2f5492f.
1 parent e549392 commit 5e20112

File tree

10 files changed

+87
-467
lines changed

10 files changed

+87
-467
lines changed

include/swift/AST/Types.h

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4683,17 +4683,6 @@ class SILFunctionType final
46834683
/// differentiability from all parameters.
46844684
CanSILFunctionType getWithoutDifferentiability();
46854685

4686-
/// Given that `this` is a `@differentiable` function type, returns the type
4687-
/// of the given `@differentiable` function type component.
4688-
CanSILFunctionType getDifferentiableComponentType(
4689-
NormalDifferentiableFunctionTypeComponent component, SILModule &module);
4690-
4691-
/// Given that `this` is a `@differentiable(linear)` function type, returns
4692-
/// the type of the given `@differentiable(linear)` function type component.
4693-
CanSILFunctionType
4694-
getLinearComponentType(LinearDifferentiableFunctionTypeComponent component,
4695-
SILModule &module);
4696-
46974686
/// Returns the type of the derivative function for the given parameter
46984687
/// indices, result indices, derivative function kind, derivative function
46994688
/// generic signature (optional), and other auxiliary parameters.

include/swift/SIL/SILInstruction.h

Lines changed: 0 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -9628,34 +9628,6 @@ class DifferentiableFunctionInst final
96289628
}
96299629
llvm_unreachable("invalid derivative kind");
96309630
}
9631-
9632-
9633-
/// Returns true iff the operand corresponding to the given extractee kind
9634-
/// exists.
9635-
bool hasExtractee(NormalDifferentiableFunctionTypeComponent extractee) const {
9636-
switch (extractee) {
9637-
case NormalDifferentiableFunctionTypeComponent::Original:
9638-
return true;
9639-
case NormalDifferentiableFunctionTypeComponent::JVP:
9640-
case NormalDifferentiableFunctionTypeComponent::VJP:
9641-
return hasDerivativeFunctions();
9642-
}
9643-
llvm_unreachable("invalid extractee kind");
9644-
}
9645-
9646-
/// Returns the operand corresponding to the given extractee kind.
9647-
SILValue
9648-
getExtractee(NormalDifferentiableFunctionTypeComponent extractee) const {
9649-
switch (extractee) {
9650-
case NormalDifferentiableFunctionTypeComponent::Original:
9651-
return getOriginalFunction();
9652-
case NormalDifferentiableFunctionTypeComponent::JVP:
9653-
return getJVPFunction();
9654-
case NormalDifferentiableFunctionTypeComponent::VJP:
9655-
return getVJPFunction();
9656-
}
9657-
llvm_unreachable("invalid extractee kind");
9658-
}
96599631
};
96609632

96619633
/// LinearFunctionInst - given a function, its derivative and transpose functions,
@@ -9696,31 +9668,6 @@ class LinearFunctionInst final
96969668
assert(HasTransposeFunction);
96979669
return getOperand(1);
96989670
}
9699-
9700-
9701-
/// Returns true iff the operand corresponding to the given extractee kind
9702-
/// exists.
9703-
bool hasExtractee(LinearDifferentiableFunctionTypeComponent extractee) const {
9704-
switch (extractee) {
9705-
case LinearDifferentiableFunctionTypeComponent::Original:
9706-
return true;
9707-
case LinearDifferentiableFunctionTypeComponent::Transpose:
9708-
return hasTransposeFunction();
9709-
}
9710-
llvm_unreachable("invalid extractee kind");
9711-
}
9712-
9713-
/// Returns the operand corresponding to the given extractee kind.
9714-
SILValue
9715-
getExtractee(LinearDifferentiableFunctionTypeComponent extractee) const {
9716-
switch (extractee) {
9717-
case LinearDifferentiableFunctionTypeComponent::Original:
9718-
return getOriginalFunction();
9719-
case LinearDifferentiableFunctionTypeComponent::Transpose:
9720-
return getTransposeFunction();
9721-
}
9722-
llvm_unreachable("invalid extractee kind");
9723-
}
97249671
};
97259672

97269673
/// DifferentiableFunctionExtractInst - extracts either the original or

lib/SIL/IR/SILFunctionType.cpp

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -263,35 +263,6 @@ IndexSubset *SILFunctionType::getDifferentiabilityResultIndices() {
263263
return IndexSubset::get(getASTContext(), numSemanticResults, resultIndices);
264264
}
265265

266-
CanSILFunctionType SILFunctionType::getDifferentiableComponentType(
267-
NormalDifferentiableFunctionTypeComponent component, SILModule &module) {
268-
assert(getDifferentiabilityKind() == DifferentiabilityKind::Reverse &&
269-
"Must be a `@differentiable(reverse)` function");
270-
auto originalFnTy = getWithoutDifferentiability();
271-
if (auto derivativeKind = component.getAsDerivativeFunctionKind()) {
272-
return originalFnTy->getAutoDiffDerivativeFunctionType(
273-
getDifferentiabilityParameterIndices(),
274-
getDifferentiabilityResultIndices(), *derivativeKind, module.Types,
275-
LookUpConformanceInModule(module.getSwiftModule()));
276-
}
277-
return originalFnTy;
278-
}
279-
280-
CanSILFunctionType SILFunctionType::getLinearComponentType(
281-
LinearDifferentiableFunctionTypeComponent component, SILModule &module) {
282-
assert(getDifferentiabilityKind() == DifferentiabilityKind::Linear &&
283-
"Must be a `@differentiable(linear)` function");
284-
auto originalFnTy = getWithoutDifferentiability();
285-
switch (component) {
286-
case LinearDifferentiableFunctionTypeComponent::Original:
287-
return originalFnTy;
288-
case LinearDifferentiableFunctionTypeComponent::Transpose:
289-
return originalFnTy->getAutoDiffTransposeFunctionType(
290-
getDifferentiabilityParameterIndices(), module.Types,
291-
LookUpConformanceInModule(module.getSwiftModule()));
292-
}
293-
}
294-
295266
CanSILFunctionType
296267
SILFunctionType::getWithDifferentiability(DifferentiabilityKind kind,
297268
IndexSubset *parameterIndices,

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,12 @@ using llvm::SmallSet;
6262
static llvm::cl::opt<bool> EnableExperimentalLinearMapTransposition(
6363
"enable-experimental-linear-map-transposition", llvm::cl::init(false));
6464

65+
/// This flag is used to disable `differentiable_function_extract` instruction
66+
/// folding for SIL testing purposes.
67+
static llvm::cl::opt<bool> SkipFoldingDifferentiableFunctionExtraction(
68+
"differentiation-skip-folding-differentiable-function-extraction",
69+
llvm::cl::init(true));
70+
6571
//===----------------------------------------------------------------------===//
6672
// Helpers
6773
//===----------------------------------------------------------------------===//
@@ -121,6 +127,17 @@ class DifferentiationTransformer {
121127
/// Process the given `linear_function` instruction, filling in the missing
122128
/// transpose function if necessary.
123129
bool processLinearFunctionInst(LinearFunctionInst *lfi);
130+
131+
/// Fold `differentiable_function_extract` users of the given
132+
/// `differentiable_function` instruction, directly replacing them with
133+
/// `differentiable_function` instruction operands. If the
134+
/// `differentiable_function` instruction has no remaining uses, delete the
135+
/// instruction itself after folding.
136+
///
137+
/// Folding can be disabled by the
138+
/// `SkipFoldingDifferentiableFunctionExtraction` flag for SIL testing
139+
/// purposes.
140+
void foldDifferentiableFunctionExtraction(DifferentiableFunctionInst *source);
124141
};
125142

126143
} // end anonymous namespace
@@ -1222,6 +1239,51 @@ SILValue DifferentiationTransformer::promoteToLinearFunction(
12221239
return newLinearFn;
12231240
}
12241241

1242+
/// Fold `differentiable_function_extract` users of the given
1243+
/// `differentiable_function` instruction, directly replacing them with
1244+
/// `differentiable_function` instruction operands. If the
1245+
/// `differentiable_function` instruction has no remaining uses, delete the
1246+
/// instruction itself after folding.
1247+
///
1248+
/// Folding can be disabled by the `SkipFoldingDifferentiableFunctionExtraction`
1249+
/// flag for SIL testing purposes.
1250+
// FIXME: This function is not correctly detecting the foldable pattern and
1251+
// needs to be rewritten.
1252+
void DifferentiationTransformer::foldDifferentiableFunctionExtraction(
1253+
DifferentiableFunctionInst *source) {
1254+
// Iterate through all `differentiable_function` instruction uses.
1255+
for (auto use : source->getUses()) {
1256+
auto *dfei = dyn_cast<DifferentiableFunctionExtractInst>(use->getUser());
1257+
// If user is not an `differentiable_function_extract` instruction, set flag
1258+
// to false.
1259+
if (!dfei)
1260+
continue;
1261+
// Fold original function extractors.
1262+
if (dfei->getExtractee() ==
1263+
NormalDifferentiableFunctionTypeComponent::Original) {
1264+
auto originalFnValue = source->getOriginalFunction();
1265+
dfei->replaceAllUsesWith(originalFnValue);
1266+
dfei->eraseFromParent();
1267+
continue;
1268+
}
1269+
// Fold derivative function extractors.
1270+
auto derivativeFnValue =
1271+
source->getDerivativeFunction(dfei->getDerivativeFunctionKind());
1272+
dfei->replaceAllUsesWith(derivativeFnValue);
1273+
dfei->eraseFromParent();
1274+
}
1275+
// If the `differentiable_function` instruction has no remaining uses, erase
1276+
// it.
1277+
if (isInstructionTriviallyDead(source)) {
1278+
SILBuilder builder(source);
1279+
builder.emitDestroyAddrAndFold(source->getLoc(), source->getJVPFunction());
1280+
builder.emitDestroyAddrAndFold(source->getLoc(), source->getVJPFunction());
1281+
source->eraseFromParent();
1282+
}
1283+
// Mark `source` as processed so that it won't be reprocessed after deletion.
1284+
context.markDifferentiableFunctionInstAsProcessed(source);
1285+
}
1286+
12251287
bool DifferentiationTransformer::processDifferentiableFunctionInst(
12261288
DifferentiableFunctionInst *dfi) {
12271289
PrettyStackTraceSILNode dfiTrace("canonicalizing `differentiable_function`",
@@ -1250,6 +1312,14 @@ bool DifferentiationTransformer::processDifferentiableFunctionInst(
12501312
// Destroy the original operand.
12511313
builder.emitDestroyValueOperation(loc, dfi->getOriginalFunction());
12521314
dfi->eraseFromParent();
1315+
// If the promoted `@differentiable` function-typed value is an
1316+
// `differentiable_function` instruction, fold
1317+
// `differentiable_function_extract` instructions. If
1318+
// `differentiable_function_extract` folding is disabled, return.
1319+
if (!SkipFoldingDifferentiableFunctionExtraction)
1320+
if (auto *newDFI =
1321+
dyn_cast<DifferentiableFunctionInst>(differentiableFnValue))
1322+
foldDifferentiableFunctionExtraction(newDFI);
12531323
transform.invalidateAnalysis(parent,
12541324
SILAnalysis::InvalidationKind::FunctionBody);
12551325
return false;

lib/SILOptimizer/SILCombiner/SILCombiner.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -294,8 +294,8 @@ class SILCombiner :
294294
SILInstruction *visitConvertFunctionInst(ConvertFunctionInst *CFI);
295295
SILInstruction *
296296
visitConvertEscapeToNoEscapeInst(ConvertEscapeToNoEscapeInst *Cvt);
297-
SILInstruction *
298-
visitDifferentiableFunctionExtractInst(DifferentiableFunctionExtractInst *DFEI);
297+
298+
SILInstruction *legacyVisitGlobalValueInst(GlobalValueInst *globalValue);
299299

300300
#define PASS(ID, TAG, DESCRIPTION)
301301
#define SWIFT_FUNCTION_PASS(ID, TAG, DESCRIPTION)

lib/SILOptimizer/SILCombiner/SILCombinerCastVisitors.cpp

Lines changed: 7 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,92 +1065,16 @@ visitCheckedCastAddrBranchInst(CheckedCastAddrBranchInst *CCABI) {
10651065

10661066
SILInstruction *SILCombiner::visitConvertEscapeToNoEscapeInst(
10671067
ConvertEscapeToNoEscapeInst *Cvt) {
1068-
// Rewrite conversion of `convert_function` of `thin_to_thick_function` as
1069-
// conversion of `thin_to_thick_function` of `convert_function`.
1070-
//
1071-
// (convert_escape_to_noescape (convert_function (thin_to_thick_function x)))
1072-
// =>
1073-
// (convert_escape_to_noescape (thin_to_thick_function (convert_function x)))
1074-
//
1075-
// This unblocks the `thin_to_thick_function` peephole optimization below.
1076-
if (auto *CFI = dyn_cast<ConvertFunctionInst>(Cvt->getConverted())) {
1077-
if (auto *TTTFI = dyn_cast<ThinToThickFunctionInst>(CFI->getConverted())) {
1078-
if (TTTFI->getSingleUse()) {
1079-
auto convertedThickType = CFI->getType().castTo<SILFunctionType>();
1080-
auto convertedThinType = convertedThickType->getWithRepresentation(
1081-
SILFunctionTypeRepresentation::Thin);
1082-
auto *newCFI = Builder.createConvertFunction(
1083-
CFI->getLoc(), TTTFI->getConverted(),
1084-
SILType::getPrimitiveObjectType(convertedThinType),
1085-
CFI->withoutActuallyEscaping());
1086-
auto *newTTTFI = Builder.createThinToThickFunction(
1087-
TTTFI->getLoc(), newCFI, CFI->getType());
1088-
replaceInstUsesWith(*CFI, newTTTFI);
1089-
}
1090-
}
1091-
}
1092-
1093-
// Rewrite conversion of `thin_to_thick_function` as `thin_to_thick_function`
1094-
// with a noescape function type.
1095-
//
1096-
// (convert_escape_to_noescape (thin_to_thick_function x)) =>
1097-
// (thin_to_thick_function [noescape] x)
1098-
if (auto *OrigThinToThick = dyn_cast<ThinToThickFunctionInst>(Cvt->getConverted())) {
1099-
auto origFunType = OrigThinToThick->getType().getAs<SILFunctionType>();
1100-
auto NewTy = origFunType->getWithExtInfo(origFunType->getExtInfo().withNoEscape(true));
1068+
auto *OrigThinToThick =
1069+
dyn_cast<ThinToThickFunctionInst>(Cvt->getConverted());
1070+
if (!OrigThinToThick)
1071+
return nullptr;
1072+
auto origFunType = OrigThinToThick->getType().getAs<SILFunctionType>();
1073+
auto NewTy = origFunType->getWithExtInfo(origFunType->getExtInfo().withNoEscape(true));
11011074

1102-
return Builder.createThinToThickFunction(
1075+
return Builder.createThinToThickFunction(
11031076
OrigThinToThick->getLoc(), OrigThinToThick->getOperand(),
11041077
SILType::getPrimitiveObjectType(NewTy));
1105-
}
1106-
1107-
// Push conversion instructions inside `differentiable_function`. This
1108-
// unblocks more optimizations.
1109-
//
1110-
// Before:
1111-
// %x = differentiable_function(%orig, %jvp, %vjp)
1112-
// %y = convert_escape_to_noescape %x
1113-
//
1114-
// After:
1115-
// %orig' = convert_escape_to_noescape %orig
1116-
// %jvp' = convert_escape_to_noescape %jvp
1117-
// %vjp' = convert_escape_to_noescape %vjp
1118-
// %y = differentiable_function(%orig', %jvp', %vjp')
1119-
if (auto *DFI = dyn_cast<DifferentiableFunctionInst>(Cvt->getConverted())) {
1120-
auto createConvertEscapeToNoEscape = [&](NormalDifferentiableFunctionTypeComponent extractee) {
1121-
if (!DFI->hasExtractee(extractee))
1122-
return SILValue();
1123-
1124-
auto operand = DFI->getExtractee(extractee);
1125-
auto fnType = operand->getType().castTo<SILFunctionType>();
1126-
auto noEscapeFnType =
1127-
fnType->getWithExtInfo(fnType->getExtInfo().withNoEscape());
1128-
auto noEscapeType = SILType::getPrimitiveObjectType(noEscapeFnType);
1129-
return Builder.createConvertEscapeToNoEscape(
1130-
operand.getLoc(), operand, noEscapeType, Cvt->isLifetimeGuaranteed())->getResult(0);
1131-
};
1132-
1133-
SILValue originalNoEscape =
1134-
createConvertEscapeToNoEscape(NormalDifferentiableFunctionTypeComponent::Original);
1135-
SILValue convertedJVP = createConvertEscapeToNoEscape(
1136-
NormalDifferentiableFunctionTypeComponent::JVP);
1137-
SILValue convertedVJP = createConvertEscapeToNoEscape(
1138-
NormalDifferentiableFunctionTypeComponent::VJP);
1139-
1140-
Optional<std::pair<SILValue, SILValue>> derivativeFunctions;
1141-
if (convertedJVP && convertedVJP)
1142-
derivativeFunctions = std::make_pair(convertedJVP, convertedVJP);
1143-
1144-
auto *newDFI = Builder.createDifferentiableFunction(
1145-
DFI->getLoc(), DFI->getParameterIndices(), DFI->getResultIndices(),
1146-
originalNoEscape, derivativeFunctions);
1147-
assert(newDFI->getType() == Cvt->getType() &&
1148-
"New `@differentiable` function instruction should have same type "
1149-
"as the old `convert_escape_to_no_escape` instruction");
1150-
return newDFI;
1151-
}
1152-
1153-
return nullptr;
11541078
}
11551079

11561080
SILInstruction *
@@ -1283,54 +1207,6 @@ SILCombiner::visitConvertFunctionInst(ConvertFunctionInst *cfi) {
12831207
return std::move(folder).optimizeWithSetValue(subCFI->getConverted());
12841208
}
12851209

1286-
// Push conversion instructions inside `differentiable_function`. This
1287-
// unblocks more optimizations.
1288-
//
1289-
// Before:
1290-
// %x = differentiable_function(%orig, %jvp, %vjp)
1291-
// %y = convert_function %x
1292-
//
1293-
// After:
1294-
// %orig' = convert_function %orig
1295-
// %jvp' = convert_function %jvp
1296-
// %vjp' = convert_function %vjp
1297-
// %y = differentiable_function(%orig', %jvp', %vjp')
1298-
if (auto *DFI = dyn_cast<DifferentiableFunctionInst>(cfi->getConverted())) {
1299-
auto createConvertFunctionOfComponent =
1300-
[&](NormalDifferentiableFunctionTypeComponent extractee) {
1301-
if (!DFI->hasExtractee(extractee))
1302-
return SILValue();
1303-
1304-
auto operand = DFI->getExtractee(extractee);
1305-
auto convertInstType =
1306-
cfi->getType().castTo<SILFunctionType>();
1307-
auto convertedComponentFnType =
1308-
convertInstType->getDifferentiableComponentType(
1309-
extractee, Builder.getModule());
1310-
auto convertedComponentType =
1311-
SILType::getPrimitiveObjectType(convertedComponentFnType);
1312-
return Builder.createConvertFunction(
1313-
operand.getLoc(), operand, convertedComponentType,
1314-
cfi->withoutActuallyEscaping())->getResult(0);
1315-
};
1316-
SILValue convertedOriginal = createConvertFunctionOfComponent(
1317-
NormalDifferentiableFunctionTypeComponent::Original);
1318-
SILValue convertedJVP = createConvertFunctionOfComponent(
1319-
NormalDifferentiableFunctionTypeComponent::JVP);
1320-
SILValue convertedVJP = createConvertFunctionOfComponent(
1321-
NormalDifferentiableFunctionTypeComponent::VJP);
1322-
Optional<std::pair<SILValue, SILValue>> derivativeFunctions;
1323-
if (convertedJVP && convertedVJP)
1324-
derivativeFunctions = std::make_pair(convertedJVP, convertedVJP);
1325-
auto *newDFI = Builder.createDifferentiableFunction(
1326-
DFI->getLoc(), DFI->getParameterIndices(), DFI->getResultIndices(),
1327-
convertedOriginal, derivativeFunctions);
1328-
assert(newDFI->getType() == cfi->getType() &&
1329-
"New `@differentiable` function instruction should have same type "
1330-
"as the old `convert_function` instruction");
1331-
return newDFI;
1332-
}
1333-
13341210
// Replace a convert_function that only has refcounting uses with its
13351211
// operand.
13361212
tryEliminateOnlyOwnershipUsedForwardingInst(cfi, getInstModCallbacks());

0 commit comments

Comments
 (0)