|
| 1 | +//===--- AdjointValue.h - Helper class for differentiation ----*- C++ -*---===// |
| 2 | +// |
| 3 | +// This source file is part of the Swift.org open source project |
| 4 | +// |
| 5 | +// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors |
| 6 | +// Licensed under Apache License v2.0 with Runtime Library Exception |
| 7 | +// |
| 8 | +// See https://swift.org/LICENSE.txt for license information |
| 9 | +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors |
| 10 | +// |
| 11 | +//===----------------------------------------------------------------------===// |
| 12 | +// |
| 13 | +// AdjointValue - a symbolic representation for adjoint values enabling |
| 14 | +// efficient differentiation by avoiding zero materialization. |
| 15 | +// |
| 16 | +//===----------------------------------------------------------------------===// |
| 17 | + |
| 18 | +#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADJOINTVALUE_H |
| 19 | +#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADJOINTVALUE_H |
| 20 | + |
| 21 | +#include "swift/SIL/SILValue.h" |
| 22 | +#include "llvm/ADT/ArrayRef.h" |
| 23 | + |
| 24 | +namespace swift { |
| 25 | +namespace autodiff { |
| 26 | + |
| 27 | +enum AdjointValueKind { |
| 28 | + /// An empty adjoint, i.e. zero. This case exists due to its special |
| 29 | + /// mathematical properties: `0 + x = x`. This is a guaranteed optimization |
| 30 | + /// when we combine a zero adjoint with another (e.g. differentiating a |
| 31 | + /// fanout). |
| 32 | + Zero, |
| 33 | + |
| 34 | + /// An aggregate of adjoint values: a struct or tuple. |
| 35 | + Aggregate, |
| 36 | + |
| 37 | + /// A concrete SIL value. |
| 38 | + Concrete, |
| 39 | +}; |
| 40 | + |
| 41 | +class AdjointValue; |
| 42 | + |
| 43 | +class AdjointValueBase { |
| 44 | + friend class AdjointValue; |
| 45 | + |
| 46 | + /// The kind of this adjoint value. |
| 47 | + AdjointValueKind kind; |
| 48 | + |
| 49 | + /// The type of this value as if it were materialized as a SIL value. |
| 50 | + SILType type; |
| 51 | + |
| 52 | + /// The underlying value. |
| 53 | + union Value { |
| 54 | + llvm::ArrayRef<AdjointValue> aggregate; |
| 55 | + SILValue concrete; |
| 56 | + Value(llvm::ArrayRef<AdjointValue> v) : aggregate(v) {} |
| 57 | + Value(SILValue v) : concrete(v) {} |
| 58 | + Value() {} |
| 59 | + } value; |
| 60 | + |
| 61 | + explicit AdjointValueBase(SILType type, |
| 62 | + llvm::ArrayRef<AdjointValue> aggregate) |
| 63 | + : kind(AdjointValueKind::Aggregate), type(type), value(aggregate) {} |
| 64 | + |
| 65 | + explicit AdjointValueBase(SILValue v) |
| 66 | + : kind(AdjointValueKind::Concrete), type(v->getType()), value(v) {} |
| 67 | + |
| 68 | + explicit AdjointValueBase(SILType type) |
| 69 | + : kind(AdjointValueKind::Zero), type(type) {} |
| 70 | +}; |
| 71 | + |
| 72 | +/// A symbolic adjoint value that is capable of representing zero value 0 and |
| 73 | +/// 1, in addition to a materialized SILValue. This is expected to be passed |
| 74 | +/// around by value in most cases, as it's two words long. |
| 75 | +class AdjointValue final { |
| 76 | + |
| 77 | +private: |
| 78 | + /// The kind of this adjoint value. |
| 79 | + AdjointValueBase *base; |
| 80 | + /*implicit*/ AdjointValue(AdjointValueBase *base = nullptr) : base(base) {} |
| 81 | + |
| 82 | +public: |
| 83 | + AdjointValueBase *operator->() const { return base; } |
| 84 | + AdjointValueBase &operator*() const { return *base; } |
| 85 | + |
| 86 | + static AdjointValue createConcrete(llvm::BumpPtrAllocator &allocator, |
| 87 | + SILValue value) { |
| 88 | + return new (allocator.Allocate<AdjointValueBase>()) AdjointValueBase(value); |
| 89 | + } |
| 90 | + |
| 91 | + static AdjointValue createZero(llvm::BumpPtrAllocator &allocator, |
| 92 | + SILType type) { |
| 93 | + return new (allocator.Allocate<AdjointValueBase>()) AdjointValueBase(type); |
| 94 | + } |
| 95 | + |
| 96 | + static AdjointValue createAggregate(llvm::BumpPtrAllocator &allocator, |
| 97 | + SILType type, |
| 98 | + llvm::ArrayRef<AdjointValue> aggregate) { |
| 99 | + return new (allocator.Allocate<AdjointValueBase>()) |
| 100 | + AdjointValueBase(type, aggregate); |
| 101 | + } |
| 102 | + |
| 103 | + AdjointValueKind getKind() const { return base->kind; } |
| 104 | + SILType getType() const { return base->type; } |
| 105 | + CanType getSwiftType() const { return getType().getASTType(); } |
| 106 | + |
| 107 | + NominalTypeDecl *getAnyNominal() const { |
| 108 | + return getSwiftType()->getAnyNominal(); |
| 109 | + } |
| 110 | + |
| 111 | + bool isZero() const { return getKind() == AdjointValueKind::Zero; } |
| 112 | + bool isAggregate() const { return getKind() == AdjointValueKind::Aggregate; } |
| 113 | + bool isConcrete() const { return getKind() == AdjointValueKind::Concrete; } |
| 114 | + |
| 115 | + unsigned getNumAggregateElements() const { |
| 116 | + assert(isAggregate()); |
| 117 | + return base->value.aggregate.size(); |
| 118 | + } |
| 119 | + |
| 120 | + AdjointValue getAggregateElement(unsigned i) const { |
| 121 | + assert(isAggregate()); |
| 122 | + return base->value.aggregate[i]; |
| 123 | + } |
| 124 | + |
| 125 | + llvm::ArrayRef<AdjointValue> getAggregateElements() const { |
| 126 | + return base->value.aggregate; |
| 127 | + } |
| 128 | + |
| 129 | + SILValue getConcreteValue() const { |
| 130 | + assert(isConcrete()); |
| 131 | + return base->value.concrete; |
| 132 | + } |
| 133 | + |
| 134 | + void print(llvm::raw_ostream &s) const { |
| 135 | + switch (getKind()) { |
| 136 | + case AdjointValueKind::Zero: |
| 137 | + s << "Zero"; |
| 138 | + break; |
| 139 | + case AdjointValueKind::Aggregate: |
| 140 | + s << "Aggregate<"; |
| 141 | + if (auto *decl = |
| 142 | + getType().getASTType()->getStructOrBoundGenericStruct()) { |
| 143 | + s << "Struct>("; |
| 144 | + interleave( |
| 145 | + llvm::zip(decl->getStoredProperties(), base->value.aggregate), |
| 146 | + [&s](std::tuple<VarDecl *, const AdjointValue &> elt) { |
| 147 | + s << std::get<0>(elt)->getName() << ": "; |
| 148 | + std::get<1>(elt).print(s); |
| 149 | + }, |
| 150 | + [&s] { s << ", "; }); |
| 151 | + } else if (auto tupleType = getType().getAs<TupleType>()) { |
| 152 | + s << "Tuple>("; |
| 153 | + interleave( |
| 154 | + base->value.aggregate, |
| 155 | + [&s](const AdjointValue &elt) { elt.print(s); }, |
| 156 | + [&s] { s << ", "; }); |
| 157 | + } else { |
| 158 | + llvm_unreachable("Invalid aggregate"); |
| 159 | + } |
| 160 | + s << ')'; |
| 161 | + break; |
| 162 | + case AdjointValueKind::Concrete: |
| 163 | + s << "Concrete(" << base->value.concrete << ')'; |
| 164 | + break; |
| 165 | + } |
| 166 | + } |
| 167 | +}; |
| 168 | + |
| 169 | +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, |
| 170 | + const AdjointValue &adjVal) { |
| 171 | + adjVal.print(os); |
| 172 | + return os; |
| 173 | +} |
| 174 | + |
| 175 | +} // end namespace autodiff |
| 176 | +} // end namespace swift |
| 177 | + |
| 178 | +#endif // SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADJOINTVALUE_H |
0 commit comments