Skip to content

Commit 9e28e0a

Browse files
committed
[AutoDiff upstream] Add AdjointValue.
Add `AdjointValue`: a symbolic representation for adjoint values enabling efficient differentiation by avoiding zero materialization.
1 parent 55ac2c0 commit 9e28e0a

File tree

1 file changed

+178
-0
lines changed

1 file changed

+178
-0
lines changed
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
//===--- AdjointValue.h - Helper class for differentiation ----*- C++ -*---===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
//
13+
// AdjointValue - a symbolic representation for adjoint values enabling
14+
// efficient differentiation by avoiding zero materialization.
15+
//
16+
//===----------------------------------------------------------------------===//
17+
18+
#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADJOINTVALUE_H
19+
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADJOINTVALUE_H
20+
21+
#include "swift/SIL/SILValue.h"
22+
#include "llvm/ADT/ArrayRef.h"
23+
24+
namespace swift {
25+
namespace autodiff {
26+
27+
enum AdjointValueKind {
28+
/// An empty adjoint, i.e. zero. This case exists due to its special
29+
/// mathematical properties: `0 + x = x`. This is a guaranteed optimization
30+
/// when we combine a zero adjoint with another (e.g. differentiating a
31+
/// fanout).
32+
Zero,
33+
34+
/// An aggregate of adjoint values: a struct or tuple.
35+
Aggregate,
36+
37+
/// A concrete SIL value.
38+
Concrete,
39+
};
40+
41+
class AdjointValue;
42+
43+
class AdjointValueBase {
44+
friend class AdjointValue;
45+
46+
/// The kind of this adjoint value.
47+
AdjointValueKind kind;
48+
49+
/// The type of this value as if it were materialized as a SIL value.
50+
SILType type;
51+
52+
/// The underlying value.
53+
union Value {
54+
llvm::ArrayRef<AdjointValue> aggregate;
55+
SILValue concrete;
56+
Value(llvm::ArrayRef<AdjointValue> v) : aggregate(v) {}
57+
Value(SILValue v) : concrete(v) {}
58+
Value() {}
59+
} value;
60+
61+
explicit AdjointValueBase(SILType type,
62+
llvm::ArrayRef<AdjointValue> aggregate)
63+
: kind(AdjointValueKind::Aggregate), type(type), value(aggregate) {}
64+
65+
explicit AdjointValueBase(SILValue v)
66+
: kind(AdjointValueKind::Concrete), type(v->getType()), value(v) {}
67+
68+
explicit AdjointValueBase(SILType type)
69+
: kind(AdjointValueKind::Zero), type(type) {}
70+
};
71+
72+
/// A symbolic adjoint value that is capable of representing zero value 0 and
73+
/// 1, in addition to a materialized SILValue. This is expected to be passed
74+
/// around by value in most cases, as it's two words long.
75+
class AdjointValue final {
76+
77+
private:
78+
/// The kind of this adjoint value.
79+
AdjointValueBase *base;
80+
/*implicit*/ AdjointValue(AdjointValueBase *base = nullptr) : base(base) {}
81+
82+
public:
83+
AdjointValueBase *operator->() const { return base; }
84+
AdjointValueBase &operator*() const { return *base; }
85+
86+
static AdjointValue createConcrete(llvm::BumpPtrAllocator &allocator,
87+
SILValue value) {
88+
return new (allocator.Allocate<AdjointValueBase>()) AdjointValueBase(value);
89+
}
90+
91+
static AdjointValue createZero(llvm::BumpPtrAllocator &allocator,
92+
SILType type) {
93+
return new (allocator.Allocate<AdjointValueBase>()) AdjointValueBase(type);
94+
}
95+
96+
static AdjointValue createAggregate(llvm::BumpPtrAllocator &allocator,
97+
SILType type,
98+
llvm::ArrayRef<AdjointValue> aggregate) {
99+
return new (allocator.Allocate<AdjointValueBase>())
100+
AdjointValueBase(type, aggregate);
101+
}
102+
103+
AdjointValueKind getKind() const { return base->kind; }
104+
SILType getType() const { return base->type; }
105+
CanType getSwiftType() const { return getType().getASTType(); }
106+
107+
NominalTypeDecl *getAnyNominal() const {
108+
return getSwiftType()->getAnyNominal();
109+
}
110+
111+
bool isZero() const { return getKind() == AdjointValueKind::Zero; }
112+
bool isAggregate() const { return getKind() == AdjointValueKind::Aggregate; }
113+
bool isConcrete() const { return getKind() == AdjointValueKind::Concrete; }
114+
115+
unsigned getNumAggregateElements() const {
116+
assert(isAggregate());
117+
return base->value.aggregate.size();
118+
}
119+
120+
AdjointValue getAggregateElement(unsigned i) const {
121+
assert(isAggregate());
122+
return base->value.aggregate[i];
123+
}
124+
125+
llvm::ArrayRef<AdjointValue> getAggregateElements() const {
126+
return base->value.aggregate;
127+
}
128+
129+
SILValue getConcreteValue() const {
130+
assert(isConcrete());
131+
return base->value.concrete;
132+
}
133+
134+
void print(llvm::raw_ostream &s) const {
135+
switch (getKind()) {
136+
case AdjointValueKind::Zero:
137+
s << "Zero";
138+
break;
139+
case AdjointValueKind::Aggregate:
140+
s << "Aggregate<";
141+
if (auto *decl =
142+
getType().getASTType()->getStructOrBoundGenericStruct()) {
143+
s << "Struct>(";
144+
interleave(
145+
llvm::zip(decl->getStoredProperties(), base->value.aggregate),
146+
[&s](std::tuple<VarDecl *, const AdjointValue &> elt) {
147+
s << std::get<0>(elt)->getName() << ": ";
148+
std::get<1>(elt).print(s);
149+
},
150+
[&s] { s << ", "; });
151+
} else if (auto tupleType = getType().getAs<TupleType>()) {
152+
s << "Tuple>(";
153+
interleave(
154+
base->value.aggregate,
155+
[&s](const AdjointValue &elt) { elt.print(s); },
156+
[&s] { s << ", "; });
157+
} else {
158+
llvm_unreachable("Invalid aggregate");
159+
}
160+
s << ')';
161+
break;
162+
case AdjointValueKind::Concrete:
163+
s << "Concrete(" << base->value.concrete << ')';
164+
break;
165+
}
166+
}
167+
};
168+
169+
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
170+
const AdjointValue &adjVal) {
171+
adjVal.print(os);
172+
return os;
173+
}
174+
175+
} // end namespace autodiff
176+
} // end namespace swift
177+
178+
#endif // SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADJOINTVALUE_H

0 commit comments

Comments
 (0)