1919#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADJOINTVALUE_H
2020
2121#include " swift/AST/Decl.h"
22+ #include " swift/SIL/SILDebugVariable.h"
23+ #include " swift/SIL/SILLocation.h"
2224#include " swift/SIL/SILValue.h"
2325#include " llvm/ADT/ArrayRef.h"
2426#include " llvm/Support/Debug.h"
@@ -38,10 +40,18 @@ enum AdjointValueKind {
3840
3941 // / A concrete SIL value.
4042 Concrete,
43+
44+ // / A special adjoint, made up of 2 adjoints -- an aggregate base adjoint and
45+ // / an element adjoint to add to one of its fields. This case exists to avoid
46+ // / eager materialization of a base adjoint upon addition with one of its
47+ // / fields.
48+ AddElement,
4149};
4250
4351class AdjointValue ;
4452
53+ struct AddElementValue ;
54+
4555class AdjointValueBase {
4656 friend class AdjointValue ;
4757
@@ -60,9 +70,13 @@ class AdjointValueBase {
6070 union Value {
6171 unsigned numAggregateElements;
6272 SILValue concrete;
73+ AddElementValue *addElementValue;
74+
6375 Value (unsigned numAggregateElements)
6476 : numAggregateElements (numAggregateElements) {}
6577 Value (SILValue v) : concrete (v) {}
78+ Value (AddElementValue *addElementValue)
79+ : addElementValue (addElementValue) {}
6680 Value () {}
6781 } value;
6882
@@ -86,6 +100,11 @@ class AdjointValueBase {
86100
87101 explicit AdjointValueBase (SILType type, llvm::Optional<DebugInfo> debugInfo)
88102 : kind(AdjointValueKind::Zero), type(type), debugInfo(debugInfo) {}
103+
104+ explicit AdjointValueBase (SILType type, AddElementValue *addElementValue,
105+ llvm::Optional<DebugInfo> debugInfo)
106+ : kind(AdjointValueKind::AddElement), type(type), debugInfo(debugInfo),
107+ value(addElementValue) {}
89108};
90109
91110// / A symbolic adjoint value that wraps a `SILValue`, a zero, or an aggregate
@@ -127,6 +146,14 @@ class AdjointValue final {
127146 return new (buf) AdjointValueBase (type, elements, debugInfo);
128147 }
129148
149+ static AdjointValue
150+ createAddElement (llvm::BumpPtrAllocator &allocator, SILType type,
151+ AddElementValue *addElementValue,
152+ llvm::Optional<DebugInfo> debugInfo = llvm::None) {
153+ auto *buf = allocator.Allocate <AdjointValueBase>();
154+ return new (buf) AdjointValueBase (type, addElementValue, debugInfo);
155+ }
156+
130157 AdjointValueKind getKind () const { return base->kind ; }
131158 SILType getType () const { return base->type ; }
132159 CanType getSwiftType () const { return getType ().getASTType (); }
@@ -140,6 +167,9 @@ class AdjointValue final {
140167 bool isZero () const { return getKind () == AdjointValueKind::Zero; }
141168 bool isAggregate () const { return getKind () == AdjointValueKind::Aggregate; }
142169 bool isConcrete () const { return getKind () == AdjointValueKind::Concrete; }
170+ bool isAddElement () const {
171+ return getKind () == AdjointValueKind::AddElement;
172+ }
143173
144174 unsigned getNumAggregateElements () const {
145175 assert (isAggregate ());
@@ -162,41 +192,77 @@ class AdjointValue final {
162192 return base->value .concrete ;
163193 }
164194
165- void print (llvm::raw_ostream &s) const {
166- switch (getKind ()) {
167- case AdjointValueKind::Zero:
168- s << " Zero[" << getType () << ' ]' ;
169- break ;
170- case AdjointValueKind::Aggregate:
171- s << " Aggregate[" << getType () << " ](" ;
172- if (auto *decl =
173- getType ().getASTType ()->getStructOrBoundGenericStruct ()) {
174- interleave (
175- llvm::zip (decl->getStoredProperties (), getAggregateElements ()),
176- [&s](std::tuple<VarDecl *, const AdjointValue &> elt) {
177- s << std::get<0 >(elt)->getName () << " : " ;
178- std::get<1 >(elt).print (s);
179- },
180- [&s] { s << " , " ; });
181- } else if (getType ().is <TupleType>()) {
182- interleave (
183- getAggregateElements (),
184- [&s](const AdjointValue &elt) { elt.print (s); },
185- [&s] { s << " , " ; });
186- } else {
187- llvm_unreachable (" Invalid aggregate" );
188- }
189- s << ' )' ;
190- break ;
191- case AdjointValueKind::Concrete:
192- s << " Concrete[" << getType () << " ](" << base->value .concrete << ' )' ;
193- break ;
194- }
195+ AddElementValue *getAddElementValue () const {
196+ assert (isAddElement ());
197+ return base->value .addElementValue ;
195198 }
196199
200+ void print (llvm::raw_ostream &s) const ;
201+
197202 SWIFT_DEBUG_DUMP { print (llvm::dbgs ()); };
198203};
199204
205+ // / An abstraction that represents the field locator in
206+ // / an `AddElement` adjoint kind. Depending on the aggregate
207+ // / kind - tuple or struct, of the `baseAdjoint` in an
208+ // / `AddElement` adjoint, the field locator may be an `unsigned int`
209+ // / or a `VarDecl *`.
210+ struct FieldLocator final {
211+ FieldLocator (VarDecl *field) : inner(field) {}
212+ FieldLocator (unsigned int index) : inner(index) {}
213+
214+ friend AddElementValue;
215+
216+ private:
217+ bool isTupleFieldLocator () const {
218+ return std::holds_alternative<unsigned int >(inner);
219+ }
220+
221+ const static constexpr std::true_type TUPLE_FIELD_LOCATOR_TAG =
222+ std::true_type{};
223+ const static constexpr std::false_type STRUCT_FIELD_LOCATOR_TAG =
224+ std::false_type{};
225+
226+ unsigned int getInner (std::true_type) const {
227+ return std::get<unsigned int >(inner);
228+ }
229+
230+ VarDecl *getInner (std::false_type) const {
231+ return std::get<VarDecl *>(inner);
232+ }
233+
234+ std::variant<unsigned int , VarDecl *> inner;
235+ };
236+
237+ // / The underlying value for an `AddElement` adjoint.
238+ struct AddElementValue final {
239+ AdjointValue baseAdjoint;
240+ AdjointValue eltToAdd;
241+ FieldLocator fieldLocator;
242+
243+ AddElementValue (AdjointValue baseAdjoint, AdjointValue eltToAdd,
244+ FieldLocator fieldLocator)
245+ : baseAdjoint(baseAdjoint), eltToAdd(eltToAdd),
246+ fieldLocator (fieldLocator) {
247+ assert (baseAdjoint.getType ().is <TupleType>() ||
248+ baseAdjoint.getType ().getStructOrBoundGenericStruct () != nullptr );
249+ }
250+
251+ bool isTupleAdjoint () const { return fieldLocator.isTupleFieldLocator (); }
252+
253+ bool isStructAdjoint () const { return !isTupleAdjoint (); }
254+
255+ VarDecl *getFieldDecl () const {
256+ assert (isStructAdjoint ());
257+ return this ->fieldLocator .getInner (FieldLocator::STRUCT_FIELD_LOCATOR_TAG);
258+ }
259+
260+ unsigned int getFieldIndex () const {
261+ assert (isTupleAdjoint ());
262+ return this ->fieldLocator .getInner (FieldLocator::TUPLE_FIELD_LOCATOR_TAG);
263+ }
264+ };
265+
200266inline llvm::raw_ostream &operator <<(llvm::raw_ostream &os,
201267 const AdjointValue &adjVal) {
202268 adjVal.print (os);
0 commit comments