Skip to content

Commit 8081482

Browse files
committed
[AutoDiff upstream] Add common SIL differentiation utilities.
1 parent bb6d4eb commit 8081482

File tree

5 files changed

+326
-1
lines changed

5 files changed

+326
-1
lines changed

include/swift/AST/SourceFile.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,8 @@ class SourceFile final : public FileUnit {
394394
void cacheVisibleDecls(SmallVectorImpl<ValueDecl *> &&globals) const;
395395
const SmallVectorImpl<ValueDecl *> &getCachedVisibleDecls() const;
396396

397+
void addVisibleDecl(ValueDecl *decl);
398+
397399
virtual void lookupValue(DeclName name, NLKind lookupKind,
398400
SmallVectorImpl<ValueDecl*> &result) const override;
399401

include/swift/SIL/ApplySite.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,17 @@ class FullApplySite : public ApplySite {
510510
return getArguments().slice(getNumIndirectSILResults());
511511
}
512512

513+
InoutArgumentRange getInoutArguments() const {
514+
switch (getKind()) {
515+
case FullApplySiteKind::ApplyInst:
516+
return cast<ApplyInst>(getInstruction())->getInoutArguments();
517+
case FullApplySiteKind::TryApplyInst:
518+
return cast<TryApplyInst>(getInstruction())->getInoutArguments();
519+
case FullApplySiteKind::BeginApplyInst:
520+
return cast<BeginApplyInst>(getInstruction())->getInoutArguments();
521+
}
522+
}
523+
513524
/// Returns true if \p op is the callee operand of this apply site
514525
/// and not an argument operand.
515526
bool isCalleeOperand(const Operand &op) const {

include/swift/SILOptimizer/Utils/Differentiation/Common.h

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "swift/SIL/SILFunction.h"
2222
#include "swift/SIL/SILModule.h"
2323
#include "swift/SIL/TypeSubstCloner.h"
24+
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
2425

2526
namespace swift {
2627

@@ -34,12 +35,77 @@ namespace autodiff {
3435
/// This is being used to print short debug messages within the AD pass.
3536
raw_ostream &getADDebugStream();
3637

38+
/// Returns true if this is an full apply site whose callee has
39+
/// `array.uninitialized_intrinsic` semantics.
40+
bool isArrayLiteralIntrinsic(FullApplySite applySite);
41+
42+
/// If the given value `v` corresponds to an `ApplyInst` with
43+
/// `array.uninitialized_intrinsic` semantics, returns the corresponding
44+
/// `ApplyInst`. Otherwise, returns `nullptr`.
45+
ApplyInst *getAllocateUninitializedArrayIntrinsic(SILValue v);
46+
47+
/// Given an element address from an `array.uninitialized_intrinsic` `apply`
48+
/// instruction, returns the `apply` instruction. The element address is either
49+
/// a `pointer_to_address` or `index_addr` instruction to the `RawPointer`
50+
/// result of the instrinsic:
51+
///
52+
/// %result = apply %array.uninitialized_intrinsic : $(Array<T>, RawPointer)
53+
/// (%array, %ptr) = destructure_tuple %result
54+
/// %elt0 = pointer_to_address %ptr to $*T // element address
55+
/// %index_1 = integer_literal $Builtin.Word, 1
56+
/// %elt1 = index_addr %elt0, %index_1 // element address
57+
/// ...
58+
ApplyInst *getAllocateUninitializedArrayIntrinsicElementAddress(SILValue v);
59+
60+
/// Given a value, finds its single `destructure_tuple` user if the value is
61+
/// tuple-typed and such a user exists.
62+
DestructureTupleInst *getSingleDestructureTupleUser(SILValue value);
63+
64+
/// Given a full apply site, apply the given callback to each of its
65+
/// "direct results".
66+
///
67+
/// - `apply`
68+
/// Special case because `apply` returns a single (possibly tuple-typed) result
69+
/// instead of multiple results. If the `apply` has a single
70+
/// `destructure_tuple` user, treat the `destructure_tuple` results as the
71+
/// `apply` direct results.
72+
///
73+
/// - `begin_apply`
74+
/// Apply callback to each `begin_apply` direct result.
75+
///
76+
/// - `try_apply`
77+
/// Apply callback to each `try_apply` successor basic block argument.
78+
void forEachApplyDirectResult(
79+
FullApplySite applySite, llvm::function_ref<void(SILValue)> resultCallback);
80+
81+
/// Given a function, gathers all of its formal results (both direct and
82+
/// indirect) in an order defined by its result type. Note that "formal results"
83+
/// refer to result values in the body of the function, not at call sites.
84+
void collectAllFormalResultsInTypeOrder(SILFunction &function,
85+
SmallVectorImpl<SILValue> &results);
86+
87+
/// Given a function, gathers all of its direct results in an order defined by
88+
/// its result type. Note that "formal results" refer to result values in the
89+
/// body of the function, not at call sites.
90+
void collectAllDirectResultsInTypeOrder(SILFunction &function,
91+
SmallVectorImpl<SILValue> &results);
92+
3793
/// Given a function call site, gathers all of its actual results (both direct
3894
/// and indirect) in an order defined by its result type.
3995
void collectAllActualResultsInTypeOrder(
4096
ApplyInst *ai, ArrayRef<SILValue> extractedDirectResults,
4197
SmallVectorImpl<SILValue> &results);
4298

99+
/// For an `apply` instruction with active results, compute:
100+
/// - The results of the `apply` instruction, in type order.
101+
/// - The set of minimal parameter and result indices for differentiating the
102+
/// `apply` instruction.
103+
void collectMinimalIndicesForFunctionCall(
104+
ApplyInst *ai, SILAutoDiffIndices parentIndices,
105+
const DifferentiableActivityInfo &activityInfo,
106+
SmallVectorImpl<SILValue> &results, SmallVectorImpl<unsigned> &paramIndices,
107+
SmallVectorImpl<unsigned> &resultIndices);
108+
43109
/// Returns the underlying instruction for the given SILValue, if it exists,
44110
/// peering through function conversion instructions.
45111
template <class Inst> Inst *peerThroughFunctionConversions(SILValue value) {
@@ -58,6 +124,10 @@ template <class Inst> Inst *peerThroughFunctionConversions(SILValue value) {
58124
return nullptr;
59125
}
60126

127+
//===----------------------------------------------------------------------===//
128+
// Code emission utilities
129+
//===----------------------------------------------------------------------===//
130+
61131
/// Given a range of elements, joins these into a single value. If there's
62132
/// exactly one element, returns that element. Otherwise, creates a tuple using
63133
/// a `tuple` instruction.
@@ -156,6 +226,59 @@ inline void createEntryArguments(SILFunction *f) {
156226
}
157227
}
158228

229+
/// Helper class for visiting basic blocks in post-order post-dominance order,
230+
/// based on a worklist algorithm.
231+
class PostOrderPostDominanceOrder {
232+
SmallVector<DominanceInfoNode *, 16> buffer;
233+
PostOrderFunctionInfo *postOrderInfo;
234+
size_t srcIdx = 0;
235+
236+
public:
237+
/// Constructor.
238+
/// \p root The root of the post-dominator tree.
239+
/// \p postOrderInfo The post-order info of the function.
240+
/// \p capacity Should be the number of basic blocks in the dominator tree to
241+
/// reduce memory allocation.
242+
PostOrderPostDominanceOrder(DominanceInfoNode *root,
243+
PostOrderFunctionInfo *postOrderInfo,
244+
int capacity = 0)
245+
: postOrderInfo(postOrderInfo) {
246+
buffer.reserve(capacity);
247+
buffer.push_back(root);
248+
}
249+
250+
/// Get the next block from the worklist.
251+
DominanceInfoNode *getNext() {
252+
if (srcIdx == buffer.size())
253+
return nullptr;
254+
return buffer[srcIdx++];
255+
}
256+
257+
/// Pushes the dominator children of a block onto the worklist in post-order.
258+
void pushChildren(DominanceInfoNode *node) {
259+
pushChildrenIf(node, [](SILBasicBlock *) { return true; });
260+
}
261+
262+
/// Conditionally pushes the dominator children of a block onto the worklist
263+
/// in post-order.
264+
template <typename Pred>
265+
void pushChildrenIf(DominanceInfoNode *node, Pred pred) {
266+
SmallVector<DominanceInfoNode *, 4> children;
267+
for (auto *child : *node)
268+
children.push_back(child);
269+
llvm::sort(children.begin(), children.end(),
270+
[&](DominanceInfoNode *n1, DominanceInfoNode *n2) {
271+
return postOrderInfo->getPONumber(n1->getBlock()) <
272+
postOrderInfo->getPONumber(n2->getBlock());
273+
});
274+
for (auto *child : children) {
275+
SILBasicBlock *childBB = child->getBlock();
276+
if (pred(childBB))
277+
buffer.push_back(child);
278+
}
279+
}
280+
};
281+
159282
/// Cloner that remaps types using the target function's generic environment.
160283
class BasicTypeSubstCloner final
161284
: public TypeSubstCloner<BasicTypeSubstCloner, SILOptFunctionBuilder> {

lib/AST/Module.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2232,6 +2232,11 @@ SourceFile::getCachedVisibleDecls() const {
22322232
return getCache().AllVisibleValues;
22332233
}
22342234

2235+
void SourceFile::addVisibleDecl(ValueDecl *decl) {
2236+
Decls->push_back(decl);
2237+
getCache().AllVisibleValues.push_back(decl);
2238+
}
2239+
22352240
static void performAutoImport(
22362241
SourceFile &SF,
22372242
SourceFile::ImplicitModuleImportKind implicitModuleImportKind) {

lib/SILOptimizer/Utils/Differentiation/Common.cpp

Lines changed: 185 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,126 @@ namespace autodiff {
2424
raw_ostream &getADDebugStream() { return llvm::dbgs() << "[AD] "; }
2525

2626
//===----------------------------------------------------------------------===//
27-
// Code emission utilities
27+
// Helpers
2828
//===----------------------------------------------------------------------===//
2929

30+
bool isArrayLiteralIntrinsic(FullApplySite applySite) {
31+
return doesApplyCalleeHaveSemantics(applySite.getCalleeOrigin(),
32+
"array.uninitialized_intrinsic");
33+
}
34+
35+
ApplyInst *getAllocateUninitializedArrayIntrinsic(SILValue v) {
36+
if (auto *ai = dyn_cast<ApplyInst>(v))
37+
if (isArrayLiteralIntrinsic(ai))
38+
return ai;
39+
return nullptr;
40+
}
41+
42+
ApplyInst *getAllocateUninitializedArrayIntrinsicElementAddress(SILValue v) {
43+
// Find the `pointer_to_address` result, peering through `index_addr`.
44+
auto *ptai = dyn_cast<PointerToAddressInst>(v);
45+
if (auto *iai = dyn_cast<IndexAddrInst>(v))
46+
ptai = dyn_cast<PointerToAddressInst>(iai->getOperand(0));
47+
if (!ptai)
48+
return nullptr;
49+
// Return the `array.uninitialized_intrinsic` application, if it exists.
50+
if (auto *dti = dyn_cast<DestructureTupleInst>(
51+
ptai->getOperand()->getDefiningInstruction())) {
52+
if (auto *ai = getAllocateUninitializedArrayIntrinsic(dti->getOperand()))
53+
return ai;
54+
}
55+
return nullptr;
56+
}
57+
58+
DestructureTupleInst *getSingleDestructureTupleUser(SILValue value) {
59+
bool foundDestructureTupleUser = false;
60+
if (!value->getType().is<TupleType>())
61+
return nullptr;
62+
DestructureTupleInst *result = nullptr;
63+
for (auto *use : value->getUses()) {
64+
if (auto *dti = dyn_cast<DestructureTupleInst>(use->getUser())) {
65+
assert(!foundDestructureTupleUser &&
66+
"There should only be one `destructure_tuple` user of a tuple");
67+
foundDestructureTupleUser = true;
68+
result = dti;
69+
}
70+
}
71+
return result;
72+
}
73+
74+
void forEachApplyDirectResult(
75+
FullApplySite applySite,
76+
llvm::function_ref<void(SILValue)> resultCallback) {
77+
switch (applySite.getKind()) {
78+
case FullApplySiteKind::ApplyInst: {
79+
auto *ai = cast<ApplyInst>(applySite.getInstruction());
80+
if (!ai->getType().is<TupleType>()) {
81+
resultCallback(ai);
82+
return;
83+
}
84+
if (auto *dti = getSingleDestructureTupleUser(ai))
85+
for (auto directResult : dti->getResults())
86+
resultCallback(directResult);
87+
break;
88+
}
89+
case FullApplySiteKind::BeginApplyInst: {
90+
auto *bai = cast<BeginApplyInst>(applySite.getInstruction());
91+
for (auto directResult : bai->getResults())
92+
resultCallback(directResult);
93+
break;
94+
}
95+
case FullApplySiteKind::TryApplyInst: {
96+
auto *tai = cast<TryApplyInst>(applySite.getInstruction());
97+
for (auto *succBB : tai->getSuccessorBlocks())
98+
for (auto *arg : succBB->getArguments())
99+
resultCallback(arg);
100+
break;
101+
}
102+
}
103+
}
104+
105+
void collectAllFormalResultsInTypeOrder(SILFunction &function,
106+
SmallVectorImpl<SILValue> &results) {
107+
SILFunctionConventions convs(function.getLoweredFunctionType(),
108+
function.getModule());
109+
auto indResults = function.getIndirectResults();
110+
auto *retInst = cast<ReturnInst>(function.findReturnBB()->getTerminator());
111+
auto retVal = retInst->getOperand();
112+
SmallVector<SILValue, 8> dirResults;
113+
if (auto *tupleInst =
114+
dyn_cast_or_null<TupleInst>(retVal->getDefiningInstruction()))
115+
dirResults.append(tupleInst->getElements().begin(),
116+
tupleInst->getElements().end());
117+
else
118+
dirResults.push_back(retVal);
119+
unsigned indResIdx = 0, dirResIdx = 0;
120+
for (auto &resInfo : convs.getResults())
121+
results.push_back(resInfo.isFormalDirect() ? dirResults[dirResIdx++]
122+
: indResults[indResIdx++]);
123+
// Treat `inout` parameters as semantic results.
124+
// Append `inout` parameters after formal results.
125+
for (auto i : range(convs.getNumParameters())) {
126+
auto paramInfo = convs.getParameters()[i];
127+
if (!paramInfo.isIndirectMutating())
128+
continue;
129+
auto *argument = function.getArgumentsWithoutIndirectResults()[i];
130+
results.push_back(argument);
131+
}
132+
}
133+
134+
void collectAllDirectResultsInTypeOrder(SILFunction &function,
135+
SmallVectorImpl<SILValue> &results) {
136+
SILFunctionConventions convs(function.getLoweredFunctionType(),
137+
function.getModule());
138+
auto *retInst = cast<ReturnInst>(function.findReturnBB()->getTerminator());
139+
auto retVal = retInst->getOperand();
140+
if (auto *tupleInst = dyn_cast<TupleInst>(retVal))
141+
results.append(tupleInst->getElements().begin(),
142+
tupleInst->getElements().end());
143+
else
144+
results.push_back(retVal);
145+
}
146+
30147
void collectAllActualResultsInTypeOrder(
31148
ApplyInst *ai, ArrayRef<SILValue> extractedDirectResults,
32149
SmallVectorImpl<SILValue> &results) {
@@ -39,6 +156,73 @@ void collectAllActualResultsInTypeOrder(
39156
}
40157
}
41158

159+
void collectMinimalIndicesForFunctionCall(
160+
ApplyInst *ai, SILAutoDiffIndices parentIndices,
161+
const DifferentiableActivityInfo &activityInfo,
162+
SmallVectorImpl<SILValue> &results, SmallVectorImpl<unsigned> &paramIndices,
163+
SmallVectorImpl<unsigned> &resultIndices) {
164+
auto calleeFnTy = ai->getSubstCalleeType();
165+
auto calleeConvs = ai->getSubstCalleeConv();
166+
// Parameter indices are indices (in the callee type signature) of parameter
167+
// arguments that are varied or are arguments.
168+
// Record all parameter indices in type order.
169+
unsigned currentParamIdx = 0;
170+
for (auto applyArg : ai->getArgumentsWithoutIndirectResults()) {
171+
if (activityInfo.isActive(applyArg, parentIndices))
172+
paramIndices.push_back(currentParamIdx);
173+
++currentParamIdx;
174+
}
175+
// Result indices are indices (in the callee type signature) of results that
176+
// are useful.
177+
SmallVector<SILValue, 8> directResults;
178+
forEachApplyDirectResult(ai, [&](SILValue directResult) {
179+
directResults.push_back(directResult);
180+
});
181+
auto indirectResults = ai->getIndirectSILResults();
182+
// Record all results and result indices in type order.
183+
results.reserve(calleeFnTy->getNumResults());
184+
unsigned dirResIdx = 0;
185+
unsigned indResIdx = calleeConvs.getSILArgIndexOfFirstIndirectResult();
186+
for (auto &resAndIdx : enumerate(calleeConvs.getResults())) {
187+
auto &res = resAndIdx.value();
188+
unsigned idx = resAndIdx.index();
189+
if (res.isFormalDirect()) {
190+
results.push_back(directResults[dirResIdx]);
191+
if (auto dirRes = directResults[dirResIdx])
192+
if (dirRes && activityInfo.isActive(dirRes, parentIndices))
193+
resultIndices.push_back(idx);
194+
++dirResIdx;
195+
} else {
196+
results.push_back(indirectResults[indResIdx]);
197+
if (activityInfo.isActive(indirectResults[indResIdx], parentIndices))
198+
resultIndices.push_back(idx);
199+
++indResIdx;
200+
}
201+
}
202+
// Record all `inout` parameters as results.
203+
auto inoutParamResultIndex = calleeFnTy->getNumResults();
204+
for (auto &paramAndIdx : enumerate(calleeConvs.getParameters())) {
205+
auto &param = paramAndIdx.value();
206+
if (!param.isIndirectMutating())
207+
continue;
208+
unsigned idx = paramAndIdx.index();
209+
auto inoutArg = ai->getArgument(idx);
210+
results.push_back(inoutArg);
211+
resultIndices.push_back(inoutParamResultIndex++);
212+
}
213+
// Make sure the function call has active results.
214+
auto numResults = calleeFnTy->getNumResults() +
215+
calleeFnTy->getNumIndirectMutatingParameters();
216+
assert(results.size() == numResults);
217+
assert(llvm::any_of(results, [&](SILValue result) {
218+
return activityInfo.isActive(result, parentIndices);
219+
}));
220+
}
221+
222+
//===----------------------------------------------------------------------===//
223+
// Code emission utilities
224+
//===----------------------------------------------------------------------===//
225+
42226
SILValue joinElements(ArrayRef<SILValue> elements, SILBuilder &builder,
43227
SILLocation loc) {
44228
if (elements.size() == 1)

0 commit comments

Comments
 (0)