Skip to content

Commit 47201cb

Browse files
authored
Merge pull request #58713 from jckarter/sil-combiner-apply-result-conversions
SILCombine: Handle result conversions for apply (convert_function) peephole
2 parents 0924a06 + 2a668b9 commit 47201cb

File tree

5 files changed

+226
-65
lines changed

5 files changed

+226
-65
lines changed

lib/IRGen/LoadableByAddress.cpp

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1894,8 +1894,8 @@ static void allocateAndSetAll(StructLoweringState &pass,
18941894
}
18951895
}
18961896

1897-
static void castTupleInstr(SingleValueInstruction *instr, IRGenModule &Mod,
1898-
LargeSILTypeMapper &Mapper) {
1897+
static void retypeTupleInstr(SingleValueInstruction *instr, IRGenModule &Mod,
1898+
LargeSILTypeMapper &Mapper) {
18991899
SILType currSILType = instr->getType();
19001900
auto funcType = getInnerFunctionType(currSILType);
19011901
assert(funcType && "Expected a function Type");
@@ -1910,25 +1910,31 @@ static void castTupleInstr(SingleValueInstruction *instr, IRGenModule &Mod,
19101910

19111911
auto II = instr->getIterator();
19121912
++II;
1913-
SILBuilderWithScope castBuilder(II);
1914-
SingleValueInstruction *castInstr = nullptr;
1913+
SILBuilderWithScope Builder(II);
1914+
SingleValueInstruction *newInstr = nullptr;
19151915
switch (instr->getKind()) {
19161916
// Add cast to the new sil function type:
19171917
case SILInstructionKind::TupleExtractInst: {
1918-
castInstr = castBuilder.createUncheckedReinterpretCast(
1919-
instr->getLoc(), instr, newSILType.getObjectType());
1918+
auto extractInst = cast<TupleExtractInst>(instr);
1919+
newInstr = Builder.createTupleExtract(
1920+
extractInst->getLoc(), extractInst->getOperand(),
1921+
extractInst->getFieldIndex(),
1922+
newSILType.getObjectType());
19201923
break;
19211924
}
19221925
case SILInstructionKind::TupleElementAddrInst: {
1923-
castInstr = castBuilder.createUncheckedAddrCast(
1924-
instr->getLoc(), instr, newSILType.getAddressType());
1926+
auto elementAddrInst = cast<TupleElementAddrInst>(instr);
1927+
newInstr = Builder.createTupleElementAddr(
1928+
elementAddrInst->getLoc(), elementAddrInst->getOperand(),
1929+
elementAddrInst->getFieldIndex(),
1930+
newSILType.getAddressType());
19251931
break;
19261932
}
19271933
default:
19281934
llvm_unreachable("Unexpected instruction inside tupleInstsToMod");
19291935
}
1930-
instr->replaceAllUsesWith(castInstr);
1931-
castInstr->setOperand(0, instr);
1936+
instr->replaceAllUsesWith(newInstr);
1937+
instr->eraseFromParent();
19321938
}
19331939

19341940
static SILValue createCopyOfEnum(StructLoweringState &pass,
@@ -2090,7 +2096,7 @@ static void rewriteFunction(StructLoweringState &pass,
20902096
}
20912097

20922098
for (SingleValueInstruction *instr : pass.tupleInstsToMod) {
2093-
castTupleInstr(instr, pass.Mod, pass.Mapper);
2099+
retypeTupleInstr(instr, pass.Mod, pass.Mapper);
20942100
}
20952101

20962102
while (!pass.allocStackInstsToMod.empty()) {

lib/SIL/IR/SILBuilder.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,18 @@ SILBuilder::createUncheckedReinterpretCast(SILLocation Loc, SILValue Op,
154154
if (SILType::canRefCast(Op->getType(), Ty, getModule()))
155155
return createUncheckedRefCast(Loc, Op, Ty);
156156

157+
// If the source and destination types are functions with the same
158+
// kind of representation, then do a function conversion.
159+
if (Op->getType().isObject() && Ty.isObject()) {
160+
if (auto OpFnTy = Op->getType().getAs<SILFunctionType>()) {
161+
if (auto DestFnTy = Ty.getAs<SILFunctionType>()) {
162+
if (OpFnTy->getRepresentation() == DestFnTy->getRepresentation()) {
163+
return createConvertFunction(Loc, Op, Ty, /*withoutActuallyEscaping*/ false);
164+
}
165+
}
166+
}
167+
}
168+
157169
// The destination type is nontrivial, and may be smaller than the source
158170
// type, so RC identity cannot be assumed.
159171
return insert(UncheckedBitwiseCastInst::create(
@@ -175,6 +187,18 @@ SILBuilder::createUncheckedBitCast(SILLocation Loc, SILValue Op, SILType Ty) {
175187
if (SILType::canRefCast(Op->getType(), Ty, getModule()))
176188
return createUncheckedRefCast(Loc, Op, Ty);
177189

190+
// If the source and destination types are functions with the same
191+
// kind of representation, then do a function conversion.
192+
if (Op->getType().isObject() && Ty.isObject()) {
193+
if (auto OpFnTy = Op->getType().getAs<SILFunctionType>()) {
194+
if (auto DestFnTy = Ty.getAs<SILFunctionType>()) {
195+
if (OpFnTy->getRepresentation() == DestFnTy->getRepresentation()) {
196+
return createConvertFunction(Loc, Op, Ty, /*withoutActuallyEscaping*/ false);
197+
}
198+
}
199+
}
200+
}
201+
178202
// The destination type is nontrivial, and may be smaller than the source
179203
// type, so RC identity cannot be assumed.
180204
return createUncheckedValueCast(Loc, Op, Ty);

lib/SILOptimizer/SILCombiner/SILCombinerApplyVisitors.cpp

Lines changed: 66 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -146,42 +146,20 @@ SILCombiner::optimizeApplyOfConvertFunctionInst(FullApplySite AI,
146146
if (SubstCalleeTy->hasArchetype() || ConvertCalleeTy->hasArchetype())
147147
return nullptr;
148148

149-
// Indirect results are not currently handled.
150-
if (AI.hasIndirectSILResults())
151-
return nullptr;
152-
153-
// Bail if the result type of the converted callee is different from the callee's
154-
// result type of the apply instruction.
155-
if (SubstCalleeTy->getAllResultsSubstType(
156-
AI.getModule(), AI.getFunction()->getTypeExpansionContext()) !=
157-
ConvertCalleeTy->getAllResultsSubstType(
158-
AI.getModule(), AI.getFunction()->getTypeExpansionContext())) {
159-
return nullptr;
160-
}
161-
162149
// Ok, we can now perform our transformation. Grab AI's operands and the
163150
// relevant types from the ConvertFunction function type and AI.
164151
Builder.setCurrentDebugScope(AI.getDebugScope());
165-
OperandValueArrayRef Ops = AI.getArgumentsWithoutIndirectResults();
152+
OperandValueArrayRef Ops = AI.getArguments();
166153
SILFunctionConventions substConventions(SubstCalleeTy, FRI->getModule());
167154
SILFunctionConventions convertConventions(ConvertCalleeTy, FRI->getModule());
168155
auto context = AI.getFunction()->getTypeExpansionContext();
169-
auto oldOpTypes = substConventions.getParameterSILTypes(context);
170-
auto newOpTypes = convertConventions.getParameterSILTypes(context);
171-
172-
assert(Ops.size() == SubstCalleeTy->getNumParameters()
173-
&& "Ops and op types must have same size.");
174-
assert(Ops.size() == ConvertCalleeTy->getNumParameters()
175-
&& "Ops and op types must have same size.");
156+
auto oldOpRetTypes = substConventions.getIndirectSILResultTypes(context);
157+
auto newOpRetTypes = convertConventions.getIndirectSILResultTypes(context);
158+
auto oldOpParamTypes = substConventions.getParameterSILTypes(context);
159+
auto newOpParamTypes = convertConventions.getParameterSILTypes(context);
176160

177161
llvm::SmallVector<SILValue, 8> Args;
178-
auto newOpI = newOpTypes.begin();
179-
auto oldOpI = oldOpTypes.begin();
180-
for (unsigned i = 0, e = Ops.size(); i != e; ++i, ++newOpI, ++oldOpI) {
181-
SILValue Op = Ops[i];
182-
SILType OldOpType = *oldOpI;
183-
SILType NewOpType = *newOpI;
184-
162+
auto convertOp = [&](SILValue Op, SILType OldOpType, SILType NewOpType) {
185163
// Convert function takes refs to refs, address to addresses, and leaves
186164
// other types alone.
187165
if (OldOpType.isAddress()) {
@@ -190,17 +168,68 @@ SILCombiner::optimizeApplyOfConvertFunctionInst(FullApplySite AI,
190168
Args.push_back(UAC);
191169
} else if (OldOpType.getASTType() != NewOpType.getASTType()) {
192170
auto URC =
193-
Builder.createUncheckedReinterpretCast(AI.getLoc(), Op, NewOpType);
171+
Builder.createUncheckedBitCast(AI.getLoc(), Op, NewOpType);
194172
Args.push_back(URC);
195173
} else {
196174
Args.push_back(Op);
197175
}
176+
};
177+
178+
unsigned OpI = 0;
179+
180+
auto newRetI = newOpRetTypes.begin();
181+
auto oldRetI = oldOpRetTypes.begin();
182+
183+
for (auto e = newOpRetTypes.end(); newRetI != e;
184+
++OpI, ++newRetI, ++oldRetI) {
185+
convertOp(Ops[OpI], *oldRetI, *newRetI);
186+
}
187+
188+
auto newParamI = newOpParamTypes.begin();
189+
auto oldParamI = oldOpParamTypes.begin();
190+
for (auto e = newOpParamTypes.end(); newParamI != e;
191+
++OpI, ++newParamI, ++oldParamI) {
192+
convertOp(Ops[OpI], *oldParamI, *newParamI);
198193
}
199194

195+
// Convert the direct results if they changed.
196+
auto oldResultTy = SubstCalleeTy
197+
->getDirectFormalResultsType(AI.getModule(),
198+
AI.getFunction()->getTypeExpansionContext());
199+
auto newResultTy = ConvertCalleeTy
200+
->getDirectFormalResultsType(AI.getModule(),
201+
AI.getFunction()->getTypeExpansionContext());
202+
200203
// Create the new apply inst.
201204
if (auto *TAI = dyn_cast<TryApplyInst>(AI)) {
205+
// If the results need to change, create a new landing block to do that
206+
// conversion.
207+
auto normalBB = TAI->getNormalBB();
208+
if (oldResultTy != newResultTy) {
209+
normalBB = AI.getFunction()->createBasicBlockBefore(TAI->getNormalBB());
210+
Builder.setInsertionPoint(normalBB);
211+
SmallVector<SILValue, 4> branchArgs;
212+
213+
auto oldOpResultTypes = substConventions.getDirectSILResultTypes(context);
214+
auto newOpResultTypes = convertConventions.getDirectSILResultTypes(context);
215+
216+
auto oldRetI = oldOpResultTypes.begin();
217+
auto newRetI = newOpResultTypes.begin();
218+
auto origArgs = TAI->getNormalBB()->getArguments();
219+
auto origArgI = origArgs.begin();
220+
for (auto e = newOpResultTypes.end(); newRetI != e;
221+
++oldRetI, ++newRetI, ++origArgI) {
222+
auto arg = normalBB->createPhiArgument(*newRetI, (*origArgI)->getOwnershipKind());
223+
auto converted = Builder.createUncheckedBitCast(AI.getLoc(),
224+
arg, *oldRetI);
225+
branchArgs.push_back(converted);
226+
}
227+
228+
Builder.createBranch(AI.getLoc(), TAI->getNormalBB(), branchArgs);
229+
}
230+
202231
return Builder.createTryApply(AI.getLoc(), FRI, SubstitutionMap(), Args,
203-
TAI->getNormalBB(), TAI->getErrorBB(),
232+
normalBB, TAI->getErrorBB(),
204233
TAI->getApplyOptions());
205234
}
206235

@@ -213,12 +242,13 @@ SILCombiner::optimizeApplyOfConvertFunctionInst(FullApplySite AI,
213242
Options |= ApplyFlags::DoesNotThrow;
214243
ApplyInst *NAI = Builder.createApply(AI.getLoc(), FRI, SubstitutionMap(),
215244
Args, Options);
216-
assert(FullApplySite(NAI).getSubstCalleeType()->getAllResultsSubstType(
217-
AI.getModule(), AI.getFunction()->getTypeExpansionContext()) ==
218-
AI.getSubstCalleeType()->getAllResultsSubstType(
219-
AI.getModule(), AI.getFunction()->getTypeExpansionContext()) &&
220-
"Function types should be the same");
221-
return NAI;
245+
SILInstruction *result = NAI;
246+
247+
if (oldResultTy != newResultTy) {
248+
result = Builder.createUncheckedBitCast(AI.getLoc(), NAI, oldResultTy);
249+
}
250+
251+
return result;
222252
}
223253

224254
/// Try to optimize a keypath application with an apply instruction.

test/SILOptimizer/sil_combine.sil

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1420,14 +1420,15 @@ bb0(%0 : $MyNSObj):
14201420
return %2 : $@callee_owned () -> @owned MyNSObj
14211421
}
14221422

1423-
// Check that convert_function is not eliminated if the result type of the converted function is different from the apply result type.
1424-
// CHECK-LABEL: sil {{.*}} @do_not_peephole_convert_function : $@convention(thin) (@in AnotherClass) -> @out @callee_owned (@in ()) -> @out AnotherClass {
1425-
// CHECK: [[CF:%[0-9]+]] = convert_function
1426-
// CHECK: [[APPLY:%[0-9]+]] = apply
1427-
// CHECK: [[FUN:%[0-9]+]] = function_ref
1428-
// CHECK: [[CF:%[0-9]+]] = partial_apply [[FUN]]([[APPLY]])
1429-
// CHECK: // end sil function 'do_not_peephole_convert_function'
1430-
sil shared [transparent] [reabstraction_thunk] @do_not_peephole_convert_function : $@convention(thin) (@in AnotherClass) -> @out @callee_owned (@in ()) -> @out AnotherClass {
1423+
// Check that convert_function is eliminated if the result type of the converted function is different from the apply result type.
1424+
// CHECK-LABEL: sil {{.*}} @peephole_convert_function_result_change : $@convention(thin) (@in AnotherClass) -> @out @callee_owned (@in ()) -> @out AnotherClass {
1425+
// CHECK: [[F1:%[0-9]+]] = function_ref @curry_thunk_for_MyNSObj_self
1426+
// CHECK: [[APPLY:%[0-9]+]] = apply [[F1]]
1427+
// CHECK: [[CONV_RESULT:%[0-9]+]] = convert_function [[APPLY]]
1428+
// CHECK: [[FUN:%[0-9]+]] = function_ref @reabstraction_thunk1
1429+
// CHECK: [[CF:%[0-9]+]] = partial_apply [[FUN]]([[CONV_RESULT]])
1430+
// CHECK: // end sil function 'peephole_convert_function_result_change'
1431+
sil shared [transparent] [reabstraction_thunk] @peephole_convert_function_result_change : $@convention(thin) (@in AnotherClass) -> @out @callee_owned (@in ()) -> @out AnotherClass {
14311432
bb0(%0 : $*@callee_owned (@in ()) -> @out AnotherClass, %1 : $*AnotherClass):
14321433
// function_ref @nonobjc curry thunk of MyNSObj.self()
14331434
%2 = function_ref @curry_thunk_for_MyNSObj_self : $@convention(thin) (@owned MyNSObj) -> @owned @callee_owned () -> @owned MyNSObj
@@ -1440,7 +1441,48 @@ bb0(%0 : $*@callee_owned (@in ()) -> @out AnotherClass, %1 : $*AnotherClass):
14401441
store %8 to %0 : $*@callee_owned (@in ()) -> @out AnotherClass
14411442
%10 = tuple ()
14421443
return %10 : $()
1443-
} // end sil function 'do_not_peephole_convert_function'
1444+
} // end sil function 'peephole_convert_function_result_change'
1445+
1446+
sil @convertible_result_with_error : $@convention(thin) () -> (@owned AnotherClass, @error Error)
1447+
1448+
// TODO
1449+
sil @peephole_convert_function_result_change_with_error : $@convention(thin) () -> () {
1450+
entry:
1451+
%f = function_ref @convertible_result_with_error : $@convention(thin) () -> (@owned AnotherClass, @error Error)
1452+
%c = convert_function %f : $@convention(thin) () -> (@owned AnotherClass, @error Error) to $@convention(thin) () -> (@owned MyNSObj, @error Error)
1453+
try_apply %c() : $@convention(thin) () -> (@owned MyNSObj, @error Error), normal success, error failure
1454+
1455+
success(%r : $MyNSObj):
1456+
strong_release %r : $MyNSObj
1457+
br exit
1458+
1459+
failure(%e : $Error):
1460+
strong_release %e : $Error
1461+
br exit
1462+
1463+
exit:
1464+
return undef : $()
1465+
}
1466+
1467+
sil @convertible_indirect_result : $@convention(thin) (@guaranteed MyNSObj) -> @out AnotherClass
1468+
1469+
// CHECK-LABEL: sil @peephole_convert_function_indirect_result :
1470+
// CHECK: bb0([[A:%.*]] : $AnotherClass):
1471+
// CHECK: [[F:%.*]] = function_ref @convertible_indirect_result :
1472+
// CHECK: [[B:%.*]] = alloc_stack $MyNSObj
1473+
// CHECK: [[B_CONV:%.*]] = unchecked_addr_cast [[B]]
1474+
// CHECK: [[A_CONV:%.*]] = upcast [[A]]
1475+
// CHECK: apply [[F]]([[B_CONV]], [[A_CONV]])
1476+
sil @peephole_convert_function_indirect_result : $@convention(thin) (@guaranteed AnotherClass) -> @owned MyNSObj {
1477+
entry(%a : $AnotherClass):
1478+
%f = function_ref @convertible_indirect_result : $@convention(thin) (@guaranteed MyNSObj) -> @out AnotherClass
1479+
%c = convert_function %f : $@convention(thin) (@guaranteed MyNSObj) -> @out AnotherClass to $@convention(thin) (@guaranteed AnotherClass) -> @out MyNSObj
1480+
%b = alloc_stack $*MyNSObj
1481+
apply %c(%b, %a) : $@convention(thin) (@guaranteed AnotherClass) -> @out MyNSObj
1482+
%r = load %b : $*MyNSObj
1483+
dealloc_stack %b : $*MyNSObj
1484+
return %r : $MyNSObj
1485+
}
14441486

14451487
// CHECK-LABEL: sil @upcast_formation : $@convention(thin) (@inout E, E, @inout B) -> B {
14461488
// CHECK: bb0

0 commit comments

Comments
 (0)