Skip to content

Commit 8d479f1

Browse files
committed
[autodiff] Change getTangentStoredProperty() to use a Projection instead of FieldIndexCacheBase.
This is NFCI. THis is in preparation for making FieldIndexCacheBase a templated subclass.
1 parent 0d728ee commit 8d479f1

File tree

4 files changed

+14
-6
lines changed

4 files changed

+14
-6
lines changed

include/swift/SILOptimizer/Differentiation/Common.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "swift/AST/SemanticAttrs.h"
2323
#include "swift/SIL/SILDifferentiabilityWitness.h"
2424
#include "swift/SIL/SILFunction.h"
25+
#include "swift/SIL/Projection.h"
2526
#include "swift/SIL/SILModule.h"
2627
#include "swift/SIL/TypeSubstCloner.h"
2728
#include "swift/SILOptimizer/Analysis/ArraySemantic.h"
@@ -166,8 +167,11 @@ VarDecl *getTangentStoredProperty(ADContext &context, VarDecl *originalField,
166167
/// Returns the tangent stored property of the original stored property
167168
/// referenced by the given projection instruction with the given base type.
168169
/// On error, emits diagnostic and returns nullptr.
170+
///
171+
/// NOTE: Asserts if \p projectionInst is not one of: struct_extract,
172+
/// struct_element_addr, or ref_element_addr.
169173
VarDecl *getTangentStoredProperty(ADContext &context,
170-
FieldIndexCacheBase *projectionInst,
174+
SingleValueInstruction *projectionInst,
171175
CanType baseType,
172176
DifferentiationInvoker invoker);
173177

lib/SILOptimizer/Differentiation/Common.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -329,14 +329,16 @@ VarDecl *getTangentStoredProperty(ADContext &context, VarDecl *originalField,
329329
}
330330

331331
VarDecl *getTangentStoredProperty(ADContext &context,
332-
FieldIndexCacheBase *projectionInst,
332+
SingleValueInstruction *projectionInst,
333333
CanType baseType,
334334
DifferentiationInvoker invoker) {
335335
assert(isa<StructExtractInst>(projectionInst) ||
336336
isa<StructElementAddrInst>(projectionInst) ||
337337
isa<RefElementAddrInst>(projectionInst));
338+
Projection proj(projectionInst);
338339
auto loc = getValidLocation(projectionInst);
339-
return getTangentStoredProperty(context, projectionInst->getField(), baseType,
340+
auto *field = proj.getVarDecl(projectionInst->getOperand(0)->getType());
341+
return getTangentStoredProperty(context, field, baseType,
340342
loc, invoker);
341343
}
342344

lib/SILOptimizer/Differentiation/JVPCloner.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,7 +1043,7 @@ class JVPCloner::Implementation final
10431043
auto structType =
10441044
remapSILTypeInDifferential(sei->getOperand()->getType()).getASTType();
10451045
auto *tanField =
1046-
getTangentStoredProperty(context, sei, structType, invoker);
1046+
getTangentStoredProperty(context, sei, structType, invoker);
10471047
if (!tanField) {
10481048
errorOccurred = true;
10491049
return;
@@ -1074,7 +1074,7 @@ class JVPCloner::Implementation final
10741074
auto structType =
10751075
remapSILTypeInDifferential(seai->getOperand()->getType()).getASTType();
10761076
auto *tanField =
1077-
getTangentStoredProperty(context, seai, structType, invoker);
1077+
getTangentStoredProperty(context, seai, structType, invoker);
10781078
if (!tanField) {
10791079
errorOccurred = true;
10801080
return;

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1832,7 +1832,9 @@ bool PullbackCloner::Implementation::run() {
18321832
}
18331833
}
18341834
// Diagnose unsupported stored property projections.
1835-
if (auto *inst = dyn_cast<FieldIndexCacheBase>(v)) {
1835+
if (isa<StructExtractInst>(v) || isa<RefElementAddrInst>(v) ||
1836+
isa<StructElementAddrInst>(v)) {
1837+
auto *inst = cast<SingleValueInstruction>(v);
18361838
assert(inst->getNumOperands() == 1);
18371839
auto baseType = remapType(inst->getOperand(0)->getType()).getASTType();
18381840
if (!getTangentStoredProperty(getContext(), inst, baseType,

0 commit comments

Comments
 (0)