|
56 | 56 | //===----------------------------------------------------------------------===//
|
57 | 57 |
|
58 | 58 | #define DEBUG_TYPE "closure-specialization"
|
| 59 | +#include "swift/SILOptimizer/IPO/ClosureSpecializer.h" |
59 | 60 | #include "swift/Basic/Range.h"
|
| 61 | +#include "swift/Demangling/Demangle.h" |
60 | 62 | #include "swift/Demangling/Demangler.h"
|
61 | 63 | #include "swift/SIL/InstructionUtils.h"
|
62 | 64 | #include "swift/SIL/SILCloner.h"
|
@@ -103,6 +105,101 @@ static bool isSupportedClosureKind(const SILInstruction *I) {
|
103 | 105 | return isa<ThinToThickFunctionInst>(I) || isa<PartialApplyInst>(I);
|
104 | 106 | }
|
105 | 107 |
|
| 108 | +static const int SpecializationLevelLimit = 2; |
| 109 | + |
| 110 | +static int getSpecializationLevelRecursive(StringRef funcName, |
| 111 | + Demangler &parent) { |
| 112 | + using namespace Demangle; |
| 113 | + |
| 114 | + Demangler demangler; |
| 115 | + demangler.providePreallocatedMemory(parent); |
| 116 | + |
| 117 | + // Check for this kind of node tree: |
| 118 | + // |
| 119 | + // kind=Global |
| 120 | + // kind=FunctionSignatureSpecialization |
| 121 | + // kind=SpecializationPassID, index=1 |
| 122 | + // kind=FunctionSignatureSpecializationParam |
| 123 | + // kind=FunctionSignatureSpecializationParamKind, index=5 |
| 124 | + // kind=FunctionSignatureSpecializationParamPayload, text="..." |
| 125 | + // |
| 126 | + Node *root = demangler.demangleSymbol(funcName); |
| 127 | + if (!root) |
| 128 | + return 0; |
| 129 | + if (root->getKind() != Node::Kind::Global) |
| 130 | + return 0; |
| 131 | + Node *funcSpec = root->getFirstChild(); |
| 132 | + if (!funcSpec || funcSpec->getNumChildren() < 2) |
| 133 | + return 0; |
| 134 | + if (funcSpec->getKind() != Node::Kind::FunctionSignatureSpecialization) |
| 135 | + return 0; |
| 136 | + |
| 137 | + // Match any function specialization. We check for constant propagation at the |
| 138 | + // parameter level. |
| 139 | + Node *param = funcSpec->getChild(0); |
| 140 | + if (param->getKind() != Node::Kind::SpecializationPassID) |
| 141 | + return SpecializationLevelLimit + 1; // unrecognized format |
| 142 | + |
| 143 | + unsigned maxParamLevel = 0; |
| 144 | + for (unsigned paramIdx = 1; paramIdx < funcSpec->getNumChildren(); |
| 145 | + ++paramIdx) { |
| 146 | + Node *param = funcSpec->getChild(paramIdx); |
| 147 | + if (param->getKind() != Node::Kind::FunctionSignatureSpecializationParam) |
| 148 | + return SpecializationLevelLimit + 1; // unrecognized format |
| 149 | + |
| 150 | + // A parameter is recursive if it has a kind with index and type payload |
| 151 | + if (param->getNumChildren() < 2) |
| 152 | + continue; |
| 153 | + |
| 154 | + Node *kindNd = param->getChild(0); |
| 155 | + if (kindNd->getKind() != |
| 156 | + Node::Kind::FunctionSignatureSpecializationParamKind) { |
| 157 | + return SpecializationLevelLimit + 1; // unrecognized format |
| 158 | + } |
| 159 | + auto kind = FunctionSigSpecializationParamKind(kindNd->getIndex()); |
| 160 | + if (kind != FunctionSigSpecializationParamKind::ConstantPropFunction) |
| 161 | + continue; |
| 162 | + Node *payload = param->getChild(1); |
| 163 | + if (payload->getKind() != |
| 164 | + Node::Kind::FunctionSignatureSpecializationParamPayload) { |
| 165 | + return SpecializationLevelLimit + 1; // unrecognized format |
| 166 | + } |
| 167 | + // Check if the specialized function is a specialization itself. |
| 168 | + unsigned paramLevel = |
| 169 | + 1 + getSpecializationLevelRecursive(payload->getText(), demangler); |
| 170 | + if (paramLevel > maxParamLevel) |
| 171 | + maxParamLevel = paramLevel; |
| 172 | + } |
| 173 | + return maxParamLevel; |
| 174 | +} |
| 175 | + |
| 176 | +//===----------------------------------------------------------------------===// |
| 177 | +// Publicly visible for bridging |
| 178 | +//===----------------------------------------------------------------------===// |
| 179 | + |
| 180 | +int swift::getSpecializationLevel(SILFunction *f) { |
| 181 | + Demangle::StackAllocatedDemangler<1024> demangler; |
| 182 | + return getSpecializationLevelRecursive(f->getName(), demangler); |
| 183 | +} |
| 184 | + |
| 185 | +bool swift::isDifferentiableFuncComponent( |
| 186 | + SILFunction *f, AutoDiffFunctionComponent expectedComponent) { |
| 187 | + Demangle::Context Ctx; |
| 188 | + if (auto *root = Ctx.demangleSymbolAsNode(f->getName())) { |
| 189 | + if (auto *node = |
| 190 | + root->findByKind(Demangle::Node::Kind::AutoDiffFunctionKind, 3)) { |
| 191 | + if (node->hasIndex()) { |
| 192 | + auto component = (char)node->getIndex(); |
| 193 | + if (component == (char)expectedComponent) { |
| 194 | + return true; |
| 195 | + } |
| 196 | + } |
| 197 | + } |
| 198 | + } |
| 199 | + |
| 200 | + return false; |
| 201 | +} |
| 202 | + |
106 | 203 | //===----------------------------------------------------------------------===//
|
107 | 204 | // Closure Spec Cloner Interface
|
108 | 205 | //===----------------------------------------------------------------------===//
|
@@ -1084,82 +1181,6 @@ static bool canSpecializeFullApplySite(FullApplySiteKind kind) {
|
1084 | 1181 | llvm_unreachable("covered switch");
|
1085 | 1182 | }
|
1086 | 1183 |
|
1087 |
| -const int SpecializationLevelLimit = 2; |
1088 |
| - |
1089 |
| -static int getSpecializationLevelRecursive(StringRef funcName, Demangler &parent) { |
1090 |
| - using namespace Demangle; |
1091 |
| - |
1092 |
| - Demangler demangler; |
1093 |
| - demangler.providePreallocatedMemory(parent); |
1094 |
| - |
1095 |
| - // Check for this kind of node tree: |
1096 |
| - // |
1097 |
| - // kind=Global |
1098 |
| - // kind=FunctionSignatureSpecialization |
1099 |
| - // kind=SpecializationPassID, index=1 |
1100 |
| - // kind=FunctionSignatureSpecializationParam |
1101 |
| - // kind=FunctionSignatureSpecializationParamKind, index=5 |
1102 |
| - // kind=FunctionSignatureSpecializationParamPayload, text="..." |
1103 |
| - // |
1104 |
| - Node *root = demangler.demangleSymbol(funcName); |
1105 |
| - if (!root) |
1106 |
| - return 0; |
1107 |
| - if (root->getKind() != Node::Kind::Global) |
1108 |
| - return 0; |
1109 |
| - Node *funcSpec = root->getFirstChild(); |
1110 |
| - if (!funcSpec || funcSpec->getNumChildren() < 2) |
1111 |
| - return 0; |
1112 |
| - if (funcSpec->getKind() != Node::Kind::FunctionSignatureSpecialization) |
1113 |
| - return 0; |
1114 |
| - |
1115 |
| - // Match any function specialization. We check for constant propagation at the |
1116 |
| - // parameter level. |
1117 |
| - Node *param = funcSpec->getChild(0); |
1118 |
| - if (param->getKind() != Node::Kind::SpecializationPassID) |
1119 |
| - return SpecializationLevelLimit + 1; // unrecognized format |
1120 |
| - |
1121 |
| - unsigned maxParamLevel = 0; |
1122 |
| - for (unsigned paramIdx = 1; paramIdx < funcSpec->getNumChildren(); |
1123 |
| - ++paramIdx) { |
1124 |
| - Node *param = funcSpec->getChild(paramIdx); |
1125 |
| - if (param->getKind() != Node::Kind::FunctionSignatureSpecializationParam) |
1126 |
| - return SpecializationLevelLimit + 1; // unrecognized format |
1127 |
| - |
1128 |
| - // A parameter is recursive if it has a kind with index and type payload |
1129 |
| - if (param->getNumChildren() < 2) |
1130 |
| - continue; |
1131 |
| - |
1132 |
| - Node *kindNd = param->getChild(0); |
1133 |
| - if (kindNd->getKind() |
1134 |
| - != Node::Kind::FunctionSignatureSpecializationParamKind) { |
1135 |
| - return SpecializationLevelLimit + 1; // unrecognized format |
1136 |
| - } |
1137 |
| - auto kind = FunctionSigSpecializationParamKind(kindNd->getIndex()); |
1138 |
| - if (kind != FunctionSigSpecializationParamKind::ConstantPropFunction) |
1139 |
| - continue; |
1140 |
| - Node *payload = param->getChild(1); |
1141 |
| - if (payload->getKind() |
1142 |
| - != Node::Kind::FunctionSignatureSpecializationParamPayload) { |
1143 |
| - return SpecializationLevelLimit + 1; // unrecognized format |
1144 |
| - } |
1145 |
| - // Check if the specialized function is a specialization itself. |
1146 |
| - unsigned paramLevel = |
1147 |
| - 1 + getSpecializationLevelRecursive(payload->getText(), demangler); |
1148 |
| - if (paramLevel > maxParamLevel) |
1149 |
| - maxParamLevel = paramLevel; |
1150 |
| - } |
1151 |
| - return maxParamLevel; |
1152 |
| -} |
1153 |
| - |
1154 |
| -/// If \p function is a function-signature specialization for a constant- |
1155 |
| -/// propagated function argument, returns 1. |
1156 |
| -/// If \p function is a specialization of such a specialization, returns 2. |
1157 |
| -/// And so on. |
1158 |
| -static int getSpecializationLevel(SILFunction *f) { |
1159 |
| - Demangle::StackAllocatedDemangler<1024> demangler; |
1160 |
| - return getSpecializationLevelRecursive(f->getName(), demangler); |
1161 |
| -} |
1162 |
| - |
1163 | 1184 | bool SILClosureSpecializerTransform::gatherCallSites(
|
1164 | 1185 | SILFunction *Caller,
|
1165 | 1186 | llvm::SmallVectorImpl<std::unique_ptr<ClosureInfo>> &ClosureCandidates,
|
|
0 commit comments