@@ -1065,92 +1065,16 @@ visitCheckedCastAddrBranchInst(CheckedCastAddrBranchInst *CCABI) {
1065
1065
1066
1066
SILInstruction *SILCombiner::visitConvertEscapeToNoEscapeInst (
1067
1067
ConvertEscapeToNoEscapeInst *Cvt) {
1068
- // Rewrite conversion of `convert_function` of `thin_to_thick_function` as
1069
- // conversion of `thin_to_thick_function` of `convert_function`.
1070
- //
1071
- // (convert_escape_to_noescape (convert_function (thin_to_thick_function x)))
1072
- // =>
1073
- // (convert_escape_to_noescape (thin_to_thick_function (convert_function x)))
1074
- //
1075
- // This unblocks the `thin_to_thick_function` peephole optimization below.
1076
- if (auto *CFI = dyn_cast<ConvertFunctionInst>(Cvt->getConverted ())) {
1077
- if (auto *TTTFI = dyn_cast<ThinToThickFunctionInst>(CFI->getConverted ())) {
1078
- if (TTTFI->getSingleUse ()) {
1079
- auto convertedThickType = CFI->getType ().castTo <SILFunctionType>();
1080
- auto convertedThinType = convertedThickType->getWithRepresentation (
1081
- SILFunctionTypeRepresentation::Thin);
1082
- auto *newCFI = Builder.createConvertFunction (
1083
- CFI->getLoc (), TTTFI->getConverted (),
1084
- SILType::getPrimitiveObjectType (convertedThinType),
1085
- CFI->withoutActuallyEscaping ());
1086
- auto *newTTTFI = Builder.createThinToThickFunction (
1087
- TTTFI->getLoc (), newCFI, CFI->getType ());
1088
- replaceInstUsesWith (*CFI, newTTTFI);
1089
- }
1090
- }
1091
- }
1092
-
1093
- // Rewrite conversion of `thin_to_thick_function` as `thin_to_thick_function`
1094
- // with a noescape function type.
1095
- //
1096
- // (convert_escape_to_noescape (thin_to_thick_function x)) =>
1097
- // (thin_to_thick_function [noescape] x)
1098
- if (auto *OrigThinToThick = dyn_cast<ThinToThickFunctionInst>(Cvt->getConverted ())) {
1099
- auto origFunType = OrigThinToThick->getType ().getAs <SILFunctionType>();
1100
- auto NewTy = origFunType->getWithExtInfo (origFunType->getExtInfo ().withNoEscape (true ));
1068
+ auto *OrigThinToThick =
1069
+ dyn_cast<ThinToThickFunctionInst>(Cvt->getConverted ());
1070
+ if (!OrigThinToThick)
1071
+ return nullptr ;
1072
+ auto origFunType = OrigThinToThick->getType ().getAs <SILFunctionType>();
1073
+ auto NewTy = origFunType->getWithExtInfo (origFunType->getExtInfo ().withNoEscape (true ));
1101
1074
1102
- return Builder.createThinToThickFunction (
1075
+ return Builder.createThinToThickFunction (
1103
1076
OrigThinToThick->getLoc (), OrigThinToThick->getOperand (),
1104
1077
SILType::getPrimitiveObjectType (NewTy));
1105
- }
1106
-
1107
- // Push conversion instructions inside `differentiable_function`. This
1108
- // unblocks more optimizations.
1109
- //
1110
- // Before:
1111
- // %x = differentiable_function(%orig, %jvp, %vjp)
1112
- // %y = convert_escape_to_noescape %x
1113
- //
1114
- // After:
1115
- // %orig' = convert_escape_to_noescape %orig
1116
- // %jvp' = convert_escape_to_noescape %jvp
1117
- // %vjp' = convert_escape_to_noescape %vjp
1118
- // %y = differentiable_function(%orig', %jvp', %vjp')
1119
- if (auto *DFI = dyn_cast<DifferentiableFunctionInst>(Cvt->getConverted ())) {
1120
- auto createConvertEscapeToNoEscape = [&](NormalDifferentiableFunctionTypeComponent extractee) {
1121
- if (!DFI->hasExtractee (extractee))
1122
- return SILValue ();
1123
-
1124
- auto operand = DFI->getExtractee (extractee);
1125
- auto fnType = operand->getType ().castTo <SILFunctionType>();
1126
- auto noEscapeFnType =
1127
- fnType->getWithExtInfo (fnType->getExtInfo ().withNoEscape ());
1128
- auto noEscapeType = SILType::getPrimitiveObjectType (noEscapeFnType);
1129
- return Builder.createConvertEscapeToNoEscape (
1130
- operand.getLoc (), operand, noEscapeType, Cvt->isLifetimeGuaranteed ())->getResult (0 );
1131
- };
1132
-
1133
- SILValue originalNoEscape =
1134
- createConvertEscapeToNoEscape (NormalDifferentiableFunctionTypeComponent::Original);
1135
- SILValue convertedJVP = createConvertEscapeToNoEscape (
1136
- NormalDifferentiableFunctionTypeComponent::JVP);
1137
- SILValue convertedVJP = createConvertEscapeToNoEscape (
1138
- NormalDifferentiableFunctionTypeComponent::VJP);
1139
-
1140
- Optional<std::pair<SILValue, SILValue>> derivativeFunctions;
1141
- if (convertedJVP && convertedVJP)
1142
- derivativeFunctions = std::make_pair (convertedJVP, convertedVJP);
1143
-
1144
- auto *newDFI = Builder.createDifferentiableFunction (
1145
- DFI->getLoc (), DFI->getParameterIndices (), DFI->getResultIndices (),
1146
- originalNoEscape, derivativeFunctions);
1147
- assert (newDFI->getType () == Cvt->getType () &&
1148
- " New `@differentiable` function instruction should have same type "
1149
- " as the old `convert_escape_to_no_escape` instruction" );
1150
- return newDFI;
1151
- }
1152
-
1153
- return nullptr ;
1154
1078
}
1155
1079
1156
1080
SILInstruction *
@@ -1283,54 +1207,6 @@ SILCombiner::visitConvertFunctionInst(ConvertFunctionInst *cfi) {
1283
1207
return std::move (folder).optimizeWithSetValue (subCFI->getConverted ());
1284
1208
}
1285
1209
1286
- // Push conversion instructions inside `differentiable_function`. This
1287
- // unblocks more optimizations.
1288
- //
1289
- // Before:
1290
- // %x = differentiable_function(%orig, %jvp, %vjp)
1291
- // %y = convert_function %x
1292
- //
1293
- // After:
1294
- // %orig' = convert_function %orig
1295
- // %jvp' = convert_function %jvp
1296
- // %vjp' = convert_function %vjp
1297
- // %y = differentiable_function(%orig', %jvp', %vjp')
1298
- if (auto *DFI = dyn_cast<DifferentiableFunctionInst>(cfi->getConverted ())) {
1299
- auto createConvertFunctionOfComponent =
1300
- [&](NormalDifferentiableFunctionTypeComponent extractee) {
1301
- if (!DFI->hasExtractee (extractee))
1302
- return SILValue ();
1303
-
1304
- auto operand = DFI->getExtractee (extractee);
1305
- auto convertInstType =
1306
- cfi->getType ().castTo <SILFunctionType>();
1307
- auto convertedComponentFnType =
1308
- convertInstType->getDifferentiableComponentType (
1309
- extractee, Builder.getModule ());
1310
- auto convertedComponentType =
1311
- SILType::getPrimitiveObjectType (convertedComponentFnType);
1312
- return Builder.createConvertFunction (
1313
- operand.getLoc (), operand, convertedComponentType,
1314
- cfi->withoutActuallyEscaping ())->getResult (0 );
1315
- };
1316
- SILValue convertedOriginal = createConvertFunctionOfComponent (
1317
- NormalDifferentiableFunctionTypeComponent::Original);
1318
- SILValue convertedJVP = createConvertFunctionOfComponent (
1319
- NormalDifferentiableFunctionTypeComponent::JVP);
1320
- SILValue convertedVJP = createConvertFunctionOfComponent (
1321
- NormalDifferentiableFunctionTypeComponent::VJP);
1322
- Optional<std::pair<SILValue, SILValue>> derivativeFunctions;
1323
- if (convertedJVP && convertedVJP)
1324
- derivativeFunctions = std::make_pair (convertedJVP, convertedVJP);
1325
- auto *newDFI = Builder.createDifferentiableFunction (
1326
- DFI->getLoc (), DFI->getParameterIndices (), DFI->getResultIndices (),
1327
- convertedOriginal, derivativeFunctions);
1328
- assert (newDFI->getType () == cfi->getType () &&
1329
- " New `@differentiable` function instruction should have same type "
1330
- " as the old `convert_function` instruction" );
1331
- return newDFI;
1332
- }
1333
-
1334
1210
// Replace a convert_function that only has refcounting uses with its
1335
1211
// operand.
1336
1212
tryEliminateOnlyOwnershipUsedForwardingInst (cfi, getInstModCallbacks ());
0 commit comments