Skip to content

Commit 90a114e

Browse files
authored
AddressLowering: Support checked_cast_br (#42487)
1 parent e12cc13 commit 90a114e

File tree

2 files changed

+198
-30
lines changed

2 files changed

+198
-30
lines changed

lib/SILOptimizer/Mandatory/AddressLowering.cpp

Lines changed: 123 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -822,6 +822,11 @@ static Operand *getReusedStorageOperand(SILValue value) {
822822
if (auto *switchEnum = dyn_cast<SwitchEnumInst>(term)) {
823823
return &switchEnum->getAllOperands()[0];
824824
}
825+
if (auto *checkedCastBr = dyn_cast<CheckedCastBranchInst>(term)) {
826+
if (value->getParentBlock() == checkedCastBr->getFailureBB()) {
827+
return &checkedCastBr->getAllOperands()[0];
828+
}
829+
}
825830
}
826831
break;
827832
}
@@ -2561,36 +2566,10 @@ class UseRewriter : SILInstructionVisitor<UseRewriter> {
25612566
}
25622567

25632568
void visitUnconditionalCheckedCastInst(
2564-
UnconditionalCheckedCastInst *uncondCheckedCast) {
2565-
SILValue srcVal = uncondCheckedCast->getOperand();
2566-
assert(srcVal->getType().isAddressOnly(*pass.function));
2567-
SILValue srcAddr = pass.valueStorageMap.getStorage(srcVal).storageAddress;
2569+
UnconditionalCheckedCastInst *uncondCheckedCast);
2570+
2571+
void visitCheckedCastBranchInst(CheckedCastBranchInst *checkedCastBranch);
25682572

2569-
if (uncondCheckedCast->getType().isAddressOnly(*pass.function)) {
2570-
// When cast destination has address only type, use the storage address
2571-
SILValue destAddr = addrMat.materializeAddress(uncondCheckedCast);
2572-
markRewritten(uncondCheckedCast, destAddr);
2573-
builder.createUnconditionalCheckedCastAddr(
2574-
uncondCheckedCast->getLoc(), srcAddr, srcAddr->getType().getASTType(),
2575-
destAddr, destAddr->getType().getASTType());
2576-
return;
2577-
}
2578-
// For loadable cast destination type, create a stack temporary
2579-
SILValue destAddr = builder.createAllocStack(uncondCheckedCast->getLoc(),
2580-
uncondCheckedCast->getType());
2581-
builder.createUnconditionalCheckedCastAddr(
2582-
uncondCheckedCast->getLoc(), srcAddr, srcAddr->getType().getASTType(),
2583-
destAddr, destAddr->getType().getASTType());
2584-
auto nextBuilder =
2585-
pass.getBuilder(uncondCheckedCast->getNextInstruction()->getIterator());
2586-
auto dest = nextBuilder.createLoad(
2587-
uncondCheckedCast->getLoc(), destAddr,
2588-
destAddr->getType().isTrivial(*uncondCheckedCast->getFunction())
2589-
? LoadOwnershipQualifier::Trivial
2590-
: LoadOwnershipQualifier::Copy);
2591-
nextBuilder.createDeallocStack(uncondCheckedCast->getLoc(), destAddr);
2592-
uncondCheckedCast->replaceAllUsesWith(dest);
2593-
}
25942573
void visitUncheckedEnumDataInst(UncheckedEnumDataInst *enumDataInst);
25952574
};
25962575
} // end anonymous namespace
@@ -2832,6 +2811,89 @@ void UseRewriter::visitSwitchEnumInst(SwitchEnumInst * switchEnum) {
28322811
defaultCounter);
28332812
}
28342813

2814+
void UseRewriter::visitCheckedCastBranchInst(
2815+
CheckedCastBranchInst *checkedCastBranch) {
2816+
auto loc = checkedCastBranch->getLoc();
2817+
auto *func = checkedCastBranch->getFunction();
2818+
auto *successBB = checkedCastBranch->getSuccessBB();
2819+
auto *failureBB = checkedCastBranch->getFailureBB();
2820+
auto *oldSuccessVal = successBB->getArgument(0);
2821+
auto *oldFailureVal = failureBB->getArgument(0);
2822+
auto termBuilder = pass.getTermBuilder(checkedCastBranch);
2823+
auto successBuilder = pass.getBuilder(successBB->begin());
2824+
auto failureBuilder = pass.getBuilder(failureBB->begin());
2825+
bool isAddressOnlyTarget = oldSuccessVal->getType().isAddressOnly(*func);
2826+
2827+
auto srcAddr = pass.valueStorageMap.getStorage(use->get()).storageAddress;
2828+
2829+
if (isAddressOnlyTarget) {
2830+
// If target is opaque, use the storage address mapped to success
2831+
// block's argument as the destination for checked_cast_addr_br.
2832+
SILValue destAddr =
2833+
pass.valueStorageMap.getStorage(oldSuccessVal).storageAddress;
2834+
2835+
termBuilder.createCheckedCastAddrBranch(
2836+
loc, CastConsumptionKind::TakeOnSuccess, srcAddr,
2837+
checkedCastBranch->getSourceFormalType(), destAddr,
2838+
checkedCastBranch->getTargetFormalType(), successBB, failureBB,
2839+
checkedCastBranch->getTrueBBCount(),
2840+
checkedCastBranch->getFalseBBCount());
2841+
2842+
// In this case, since both success and failure block's args are opaque,
2843+
// create dummy loads from their storage addresses that will later be
2844+
// rewritten to copy_addr in DefRewriter::visitLoadInst
2845+
auto newSuccessVal = successBuilder.createTrivialLoadOr(
2846+
loc, destAddr, LoadOwnershipQualifier::Take);
2847+
oldSuccessVal->replaceAllUsesWith(newSuccessVal);
2848+
successBB->eraseArgument(0);
2849+
2850+
pass.valueStorageMap.replaceValue(oldSuccessVal, newSuccessVal);
2851+
2852+
auto newFailureVal = failureBuilder.createTrivialLoadOr(
2853+
loc, srcAddr, LoadOwnershipQualifier::Take);
2854+
oldFailureVal->replaceAllUsesWith(newFailureVal);
2855+
failureBB->eraseArgument(0);
2856+
2857+
pass.valueStorageMap.replaceValue(oldFailureVal, newFailureVal);
2858+
markRewritten(newFailureVal, srcAddr);
2859+
} else {
2860+
// If the target is loadable, create a stack temporary to be used as the
2861+
// destination for checked_cast_addr_br.
2862+
SILValue destAddr = termBuilder.createAllocStack(
2863+
loc, checkedCastBranch->getTargetLoweredType());
2864+
2865+
termBuilder.createCheckedCastAddrBranch(
2866+
loc, CastConsumptionKind::TakeOnSuccess, srcAddr,
2867+
checkedCastBranch->getSourceFormalType(), destAddr,
2868+
checkedCastBranch->getTargetFormalType(), successBB, failureBB,
2869+
checkedCastBranch->getTrueBBCount(),
2870+
checkedCastBranch->getFalseBBCount());
2871+
2872+
// Replace the success block arg with loaded value from destAddr, and delete
2873+
// the success block arg.
2874+
auto newSuccessVal = successBuilder.createTrivialLoadOr(
2875+
loc, destAddr, LoadOwnershipQualifier::Take);
2876+
oldSuccessVal->replaceAllUsesWith(newSuccessVal);
2877+
successBB->eraseArgument(0);
2878+
2879+
successBuilder.createDeallocStack(loc, destAddr);
2880+
failureBuilder.createDeallocStack(loc, destAddr);
2881+
2882+
// Since failure block arg is opaque, create dummy load from its storage
2883+
// address. This will be replaced later with copy_addr in
2884+
// DefRewriter::visitLoadInst.
2885+
auto newFailureVal = failureBuilder.createTrivialLoadOr(
2886+
loc, srcAddr, LoadOwnershipQualifier::Take);
2887+
oldFailureVal->replaceAllUsesWith(newFailureVal);
2888+
failureBB->eraseArgument(0);
2889+
2890+
pass.valueStorageMap.replaceValue(oldFailureVal, newFailureVal);
2891+
markRewritten(newFailureVal, srcAddr);
2892+
}
2893+
2894+
pass.deleter.forceDelete(checkedCastBranch);
2895+
}
2896+
28352897
void UseRewriter::visitUncheckedEnumDataInst(
28362898
UncheckedEnumDataInst *enumDataInst) {
28372899
assert(use == getReusedStorageOperand(enumDataInst));
@@ -2852,6 +2914,38 @@ void UseRewriter::visitUncheckedEnumDataInst(
28522914
markRewritten(enumDataInst, enumAddrInst);
28532915
}
28542916

2917+
void UseRewriter::visitUnconditionalCheckedCastInst(
2918+
UnconditionalCheckedCastInst *uncondCheckedCast) {
2919+
SILValue srcVal = uncondCheckedCast->getOperand();
2920+
assert(srcVal->getType().isAddressOnly(*pass.function));
2921+
SILValue srcAddr = pass.valueStorageMap.getStorage(srcVal).storageAddress;
2922+
2923+
if (uncondCheckedCast->getType().isAddressOnly(*pass.function)) {
2924+
// When cast destination has address only type, use the storage address
2925+
SILValue destAddr = addrMat.materializeAddress(uncondCheckedCast);
2926+
markRewritten(uncondCheckedCast, destAddr);
2927+
builder.createUnconditionalCheckedCastAddr(
2928+
uncondCheckedCast->getLoc(), srcAddr, srcAddr->getType().getASTType(),
2929+
destAddr, destAddr->getType().getASTType());
2930+
return;
2931+
}
2932+
// For loadable cast destination type, create a stack temporary
2933+
SILValue destAddr = builder.createAllocStack(uncondCheckedCast->getLoc(),
2934+
uncondCheckedCast->getType());
2935+
builder.createUnconditionalCheckedCastAddr(
2936+
uncondCheckedCast->getLoc(), srcAddr, srcAddr->getType().getASTType(),
2937+
destAddr, destAddr->getType().getASTType());
2938+
auto nextBuilder =
2939+
pass.getBuilder(uncondCheckedCast->getNextInstruction()->getIterator());
2940+
auto dest = nextBuilder.createLoad(
2941+
uncondCheckedCast->getLoc(), destAddr,
2942+
destAddr->getType().isTrivial(*uncondCheckedCast->getFunction())
2943+
? LoadOwnershipQualifier::Trivial
2944+
: LoadOwnershipQualifier::Copy);
2945+
nextBuilder.createDeallocStack(uncondCheckedCast->getLoc(), destAddr);
2946+
uncondCheckedCast->replaceAllUsesWith(dest);
2947+
}
2948+
28552949
//===----------------------------------------------------------------------===//
28562950
// DefRewriter
28572951
//
@@ -2881,7 +2975,6 @@ class DefRewriter : SILInstructionVisitor<DefRewriter> {
28812975
static void rewriteValue(SILValue value, AddressLoweringState &pass) {
28822976
if (auto *inst = value->getDefiningInstruction()) {
28832977
DefRewriter(pass, value, inst->getIterator()).visit(inst);
2884-
28852978
} else {
28862979
// function args are already rewritten.
28872980
auto *blockArg = cast<SILPhiArgument>(value);

test/SILOptimizer/address_lowering.sil

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ sil [ossa] @takeGuaranteedObject : $@convention(thin) (@guaranteed AnyObject) ->
6161
sil [ossa] @takeIndirectClass : $@convention(thin) (@in_guaranteed C) -> ()
6262
sil [ossa] @takeTuple : $@convention(thin) <τ_0_0> (@in_guaranteed (τ_0_0, C)) -> ()
6363

64+
6465
sil [ossa] @takeIn : $@convention(thin) <T> (@in T) -> ()
6566
sil [ossa] @takeInGuaranteed : $@convention(thin) <T> (@in_guaranteed T) -> ()
6667

@@ -1170,6 +1171,80 @@ bb2(%9 : $Error):
11701171
throw %9 : $Error
11711172
}
11721173

1174+
sil [ossa] @use_T : $@convention(thin) <T> (@in T) -> ()
1175+
1176+
// CHECK-LABEL: sil [ossa] @test_checked_cast_br1 : $@convention(method) <T> (@in Any) -> () {
1177+
// CHECK: bb0(%0 : $*Any):
1178+
// CHECK: [[SUCCESS_DST:%.*]] = alloc_stack $T
1179+
// CHECK: checked_cast_addr_br take_on_success Any in %0 : $*Any to T in [[SUCCESS_DST]] : $*T, bb2, bb1
1180+
// CHECK: bb1:
1181+
// CHECK: destroy_addr %0 : $*Any
1182+
// CHECK: br bb3
1183+
// CHECK: bb2:
1184+
// CHECK: [[FUNC:%.*]] = function_ref @use_T : $@convention(thin) <τ_0_0> (@in τ_0_0) -> ()
1185+
// CHECK: apply [[FUNC]]<T>([[SUCCESS_DST]]) : $@convention(thin) <τ_0_0> (@in τ_0_0) -> ()
1186+
// CHECK: br bb3
1187+
// CHECK: bb3:
1188+
// CHECK: [[RES:%.*]] = tuple ()
1189+
// CHECK: dealloc_stack [[SUCCESS_DST]] : $*T
1190+
// CHECK: return [[RES]] : $()
1191+
// CHECK-LABEL: } // end sil function 'test_checked_cast_br1'
1192+
sil [ossa] @test_checked_cast_br1 : $@convention(method) <T> (@in Any) -> () {
1193+
bb0(%0 : @owned $Any):
1194+
checked_cast_br %0 : $Any to T, bb1, bb2
1195+
1196+
bb1(%3 : @owned $T):
1197+
%f = function_ref @use_T : $@convention(thin) <T> (@in T) -> ()
1198+
%call = apply %f<T>(%3) : $@convention(thin) <τ_0_0> (@in τ_0_0) -> ()
1199+
br bb3
1200+
1201+
bb2(%4 : @owned $Any):
1202+
destroy_value %4 : $Any
1203+
br bb3
1204+
1205+
bb3:
1206+
%31 = tuple ()
1207+
return %31 : $()
1208+
}
1209+
1210+
sil [ossa] @use_C : $@convention(thin) (@owned C) -> ()
1211+
1212+
// CHECK-LABEL: sil [ossa] @test_checked_cast_br2 : $@convention(method) (@in Any) -> () {
1213+
// CHECK: bb0(%0 : $*Any):
1214+
// CHECK: [[SUCCESS_DST:%.*]] = alloc_stack $C
1215+
// CHECK: checked_cast_addr_br take_on_success Any in %0 : $*Any to C in [[SUCCESS_DST]] : $*C, bb2, bb1
1216+
// CHECK: bb1:
1217+
// CHECK: dealloc_stack [[SUCCESS_DST]] : $*C
1218+
// CHECK: destroy_addr %0 : $*Any
1219+
// CHECK: br bb3
1220+
// CHECK: bb2:
1221+
// CHECK: [[LD:%.*]] = load [take] [[SUCCESS_DST]] : $*C
1222+
// CHECK: dealloc_stack [[SUCCESS_DST]] : $*C
1223+
// CHECK: [[FUNC:%.*]] = function_ref @use_C : $@convention(thin) (@owned C) -> ()
1224+
// CHECK: apply [[FUNC]]([[LD]]) : $@convention(thin) (@owned C) -> ()
1225+
// CHECK: br bb3
1226+
// CHECK: bb3:
1227+
// CHECK: [[RES:%.*]] = tuple ()
1228+
// CHECK: return [[RES]] : $()
1229+
// CHECK-LABEL: }
1230+
sil [ossa] @test_checked_cast_br2 : $@convention(method) (@in Any) -> () {
1231+
bb0(%0 : @owned $Any):
1232+
checked_cast_br %0 : $Any to C, bb1, bb2
1233+
1234+
bb1(%3 : @owned $C):
1235+
%f = function_ref @use_C : $@convention(thin) (@owned C) -> ()
1236+
%call = apply %f(%3) : $@convention(thin) (@owned C) -> ()
1237+
br bb3
1238+
1239+
bb2(%4 : @owned $Any):
1240+
destroy_value %4 : $Any
1241+
br bb3
1242+
1243+
bb3:
1244+
%31 = tuple ()
1245+
return %31 : $()
1246+
}
1247+
11731248
// CHECK-LABEL: sil hidden [ossa] @test_unchecked_bitwise_cast :
11741249
// CHECK: bb0(%0 : $*U, %1 : $*T, %2 : $@thick U.Type):
11751250
// CHECK: [[STK:%.*]] = alloc_stack $T

0 commit comments

Comments
 (0)