@@ -1065,16 +1065,94 @@ visitCheckedCastAddrBranchInst(CheckedCastAddrBranchInst *CCABI) {
1065
1065
1066
1066
SILInstruction *SILCombiner::visitConvertEscapeToNoEscapeInst (
1067
1067
ConvertEscapeToNoEscapeInst *Cvt) {
1068
- auto *OrigThinToThick =
1069
- dyn_cast<ThinToThickFunctionInst>(Cvt->getConverted ());
1070
- if (!OrigThinToThick)
1071
- return nullptr ;
1072
- auto origFunType = OrigThinToThick->getType ().getAs <SILFunctionType>();
1073
- auto NewTy = origFunType->getWithExtInfo (origFunType->getExtInfo ().withNoEscape (true ));
1068
+ // Rewrite conversion of `convert_function` of `thin_to_thick_function` as
1069
+ // conversion of `thin_to_thick_function` of `convert_function`.
1070
+ //
1071
+ // (convert_escape_to_noescape (convert_function (thin_to_thick_function x)))
1072
+ // =>
1073
+ // (convert_escape_to_noescape (thin_to_thick_function (convert_function x)))
1074
+ //
1075
+ // This unblocks the `thin_to_thick_function` peephole optimization below.
1076
+ if (auto *CFI = dyn_cast<ConvertFunctionInst>(Cvt->getConverted ())) {
1077
+ if (CFI->getSingleUse ()) {
1078
+ if (auto *TTTFI = dyn_cast<ThinToThickFunctionInst>(CFI->getConverted ())) {
1079
+ if (TTTFI->getSingleUse ()) {
1080
+ auto convertedThickType = CFI->getType ().castTo <SILFunctionType>();
1081
+ auto convertedThinType = convertedThickType->getWithRepresentation (
1082
+ SILFunctionTypeRepresentation::Thin);
1083
+ auto *newCFI = Builder.createConvertFunction (
1084
+ CFI->getLoc (), TTTFI->getConverted (),
1085
+ SILType::getPrimitiveObjectType (convertedThinType),
1086
+ CFI->withoutActuallyEscaping ());
1087
+ auto *newTTTFI = Builder.createThinToThickFunction (
1088
+ TTTFI->getLoc (), newCFI, CFI->getType ());
1089
+ replaceInstUsesWith (*CFI, newTTTFI);
1090
+ }
1091
+ }
1092
+ }
1093
+ }
1094
+
1095
+ // Rewrite conversion of `thin_to_thick_function` as `thin_to_thick_function`
1096
+ // with a noescape function type.
1097
+ //
1098
+ // (convert_escape_to_noescape (thin_to_thick_function x)) =>
1099
+ // (thin_to_thick_function [noescape] x)
1100
+ if (auto *OrigThinToThick = dyn_cast<ThinToThickFunctionInst>(Cvt->getConverted ())) {
1101
+ auto origFunType = OrigThinToThick->getType ().getAs <SILFunctionType>();
1102
+ auto NewTy = origFunType->getWithExtInfo (origFunType->getExtInfo ().withNoEscape (true ));
1074
1103
1075
- return Builder.createThinToThickFunction (
1104
+ return Builder.createThinToThickFunction (
1076
1105
OrigThinToThick->getLoc (), OrigThinToThick->getOperand (),
1077
1106
SILType::getPrimitiveObjectType (NewTy));
1107
+ }
1108
+
1109
+ // Push conversion instructions inside `differentiable_function`. This
1110
+ // unblocks more optimizations.
1111
+ //
1112
+ // Before:
1113
+ // %x = differentiable_function(%orig, %jvp, %vjp)
1114
+ // %y = convert_escape_to_noescape %x
1115
+ //
1116
+ // After:
1117
+ // %orig' = convert_escape_to_noescape %orig
1118
+ // %jvp' = convert_escape_to_noescape %jvp
1119
+ // %vjp' = convert_escape_to_noescape %vjp
1120
+ // %y = differentiable_function(%orig', %jvp', %vjp')
1121
+ if (auto *DFI = dyn_cast<DifferentiableFunctionInst>(Cvt->getConverted ())) {
1122
+ auto createConvertEscapeToNoEscape = [&](NormalDifferentiableFunctionTypeComponent extractee) {
1123
+ if (!DFI->hasExtractee (extractee))
1124
+ return SILValue ();
1125
+
1126
+ auto operand = DFI->getExtractee (extractee);
1127
+ auto fnType = operand->getType ().castTo <SILFunctionType>();
1128
+ auto noEscapeFnType =
1129
+ fnType->getWithExtInfo (fnType->getExtInfo ().withNoEscape ());
1130
+ auto noEscapeType = SILType::getPrimitiveObjectType (noEscapeFnType);
1131
+ return Builder.createConvertEscapeToNoEscape (
1132
+ operand.getLoc (), operand, noEscapeType, Cvt->isLifetimeGuaranteed ())->getResult (0 );
1133
+ };
1134
+
1135
+ SILValue originalNoEscape =
1136
+ createConvertEscapeToNoEscape (NormalDifferentiableFunctionTypeComponent::Original);
1137
+ SILValue convertedJVP = createConvertEscapeToNoEscape (
1138
+ NormalDifferentiableFunctionTypeComponent::JVP);
1139
+ SILValue convertedVJP = createConvertEscapeToNoEscape (
1140
+ NormalDifferentiableFunctionTypeComponent::VJP);
1141
+
1142
+ Optional<std::pair<SILValue, SILValue>> derivativeFunctions;
1143
+ if (convertedJVP && convertedVJP)
1144
+ derivativeFunctions = std::make_pair (convertedJVP, convertedVJP);
1145
+
1146
+ auto *newDFI = Builder.createDifferentiableFunction (
1147
+ DFI->getLoc (), DFI->getParameterIndices (), DFI->getResultIndices (),
1148
+ originalNoEscape, derivativeFunctions);
1149
+ assert (newDFI->getType () == Cvt->getType () &&
1150
+ " New `@differentiable` function instruction should have same type "
1151
+ " as the old `convert_escape_to_no_escape` instruction" );
1152
+ return newDFI;
1153
+ }
1154
+
1155
+ return nullptr ;
1078
1156
}
1079
1157
1080
1158
SILInstruction *
@@ -1207,6 +1285,54 @@ SILCombiner::visitConvertFunctionInst(ConvertFunctionInst *cfi) {
1207
1285
return std::move (folder).optimizeWithSetValue (subCFI->getConverted ());
1208
1286
}
1209
1287
1288
+ // Push conversion instructions inside `differentiable_function`. This
1289
+ // unblocks more optimizations.
1290
+ //
1291
+ // Before:
1292
+ // %x = differentiable_function(%orig, %jvp, %vjp)
1293
+ // %y = convert_function %x
1294
+ //
1295
+ // After:
1296
+ // %orig' = convert_function %orig
1297
+ // %jvp' = convert_function %jvp
1298
+ // %vjp' = convert_function %vjp
1299
+ // %y = differentiable_function(%orig', %jvp', %vjp')
1300
+ if (auto *DFI = dyn_cast<DifferentiableFunctionInst>(cfi->getConverted ())) {
1301
+ auto createConvertFunctionOfComponent =
1302
+ [&](NormalDifferentiableFunctionTypeComponent extractee) {
1303
+ if (!DFI->hasExtractee (extractee))
1304
+ return SILValue ();
1305
+
1306
+ auto operand = DFI->getExtractee (extractee);
1307
+ auto convertInstType =
1308
+ cfi->getType ().castTo <SILFunctionType>();
1309
+ auto convertedComponentFnType =
1310
+ convertInstType->getDifferentiableComponentType (
1311
+ extractee, Builder.getModule ());
1312
+ auto convertedComponentType =
1313
+ SILType::getPrimitiveObjectType (convertedComponentFnType);
1314
+ return Builder.createConvertFunction (
1315
+ operand.getLoc (), operand, convertedComponentType,
1316
+ cfi->withoutActuallyEscaping ())->getResult (0 );
1317
+ };
1318
+ SILValue convertedOriginal = createConvertFunctionOfComponent (
1319
+ NormalDifferentiableFunctionTypeComponent::Original);
1320
+ SILValue convertedJVP = createConvertFunctionOfComponent (
1321
+ NormalDifferentiableFunctionTypeComponent::JVP);
1322
+ SILValue convertedVJP = createConvertFunctionOfComponent (
1323
+ NormalDifferentiableFunctionTypeComponent::VJP);
1324
+ Optional<std::pair<SILValue, SILValue>> derivativeFunctions;
1325
+ if (convertedJVP && convertedVJP)
1326
+ derivativeFunctions = std::make_pair (convertedJVP, convertedVJP);
1327
+ auto *newDFI = Builder.createDifferentiableFunction (
1328
+ DFI->getLoc (), DFI->getParameterIndices (), DFI->getResultIndices (),
1329
+ convertedOriginal, derivativeFunctions);
1330
+ assert (newDFI->getType () == cfi->getType () &&
1331
+ " New `@differentiable` function instruction should have same type "
1332
+ " as the old `convert_function` instruction" );
1333
+ return newDFI;
1334
+ }
1335
+
1210
1336
// Replace a convert_function that only has refcounting uses with its
1211
1337
// operand.
1212
1338
tryEliminateOnlyOwnershipUsedForwardingInst (cfi, getInstModCallbacks ());
0 commit comments