Skip to content

Commit 2723ca4

Browse files
Merge pull request #4455 from swiftwasm/main
[pull] swiftwasm from main
2 parents 041be0b + d52dddc commit 2723ca4

File tree

6 files changed

+314
-177
lines changed

6 files changed

+314
-177
lines changed

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,10 @@ class PullbackCloner::Implementation final
549549
if (auto adjProj = getAdjointProjection(origBB, originalValue))
550550
return (bufferMap[{origBB, originalValue}] = adjProj);
551551

552+
LLVM_DEBUG(getADDebugStream() << "Creating new adjoint buffer for"
553+
<< originalValue
554+
<< "in bb" << origBB->getDebugID() << '\n');
555+
552556
auto bufType = getRemappedTangentType(originalValue->getType());
553557
// Set insertion point for local allocation builder: before the last local
554558
// allocation, or at the start of the pullback function's entry if no local
@@ -583,6 +587,12 @@ class PullbackCloner::Implementation final
583587
assert(originalValue->getFunction() == &getOriginal());
584588
assert(rhsAddress->getFunction() == &getPullback());
585589
auto adjointBuffer = getAdjointBuffer(origBB, originalValue);
590+
591+
LLVM_DEBUG(getADDebugStream() << "Adding"
592+
<< rhsAddress << "to adjoint of "
593+
<< originalValue
594+
<< "in bb" << origBB->getDebugID() << '\n');
595+
586596
builder.emitInPlaceAdd(loc, adjointBuffer, rhsAddress);
587597
}
588598

@@ -2340,7 +2350,9 @@ SILBasicBlock *PullbackCloner::Implementation::buildPullbackSuccessor(
23402350
for (auto activeValue : predBBActiveValues) {
23412351
LLVM_DEBUG(getADDebugStream()
23422352
<< "Propagating adjoint of active value " << activeValue
2343-
<< " to predecessors' pullback blocks\n");
2353+
<< "from bb" << origBB->getDebugID()
2354+
<< " to predecessors' (bb" << origPredBB->getDebugID()
2355+
<< ") pullback blocks\n");
23442356
switch (getTangentValueCategory(activeValue)) {
23452357
case SILValueCategory::Object: {
23462358
auto activeValueAdj = getAdjointValue(origBB, activeValue);
@@ -2511,14 +2523,13 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) {
25112523
case SILValueCategory::Address: {
25122524
auto bbArgAdjBuf = getAdjointBuffer(bb, bbArg);
25132525
for (auto pair : incomingValues) {
2514-
auto *predBB = std::get<0>(pair);
25152526
auto incomingValue = std::get<1>(pair);
25162527
// Handle `switch_enum` on `Optional`.
25172528
auto termInst = bbArg->getSingleTerminator();
25182529
if (isSwitchEnumInstOnOptional(termInst))
25192530
accumulateAdjointForOptional(bb, incomingValue, bbArgAdjBuf);
25202531
else
2521-
addToAdjointBuffer(predBB, incomingValue, bbArgAdjBuf, pbLoc);
2532+
addToAdjointBuffer(bb, incomingValue, bbArgAdjBuf, pbLoc);
25222533
}
25232534
break;
25242535
}

lib/Sema/CSDiagnostics.cpp

Lines changed: 36 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -3970,8 +3970,11 @@ bool InvalidMemberRefOnExistential::diagnoseAsError() {
39703970

39713971
// If the base expression is a reference to a function or subscript
39723972
// parameter, offer a fixit that replaces the existential parameter type with
3973-
// its generic equivalent, e.g. func foo(p: any P) → func foo<T: P>(p: T).
3974-
// FIXME: Add an option to use 'some' vs. an explicit generic parameter.
3973+
// its generic equivalent, e.g. func foo(p: any P) → func foo(p: some P).
3974+
// Replacing 'any' with 'some' allows the code to compile without further
3975+
// changes, such as naming an explicit type parameter, and is future-proofed
3976+
// for same-type requirements on primary associated types instead of needing
3977+
// a where clause.
39753978

39763979
if (!PD || !PD->getDeclContext()->getAsDecl())
39773980
return true;
@@ -4020,70 +4023,42 @@ bool InvalidMemberRefOnExistential::diagnoseAsError() {
40204023
if (PD->isInOut())
40214024
return true;
40224025

4023-
constexpr StringRef GPNamePlaceholder = "<#generic parameter name#>";
4024-
SourceRange TyReplacementRange;
4025-
SourceRange RemoveAnyRange;
4026-
SourceLoc GPDeclLoc;
4027-
std::string GPDeclStr;
4028-
{
4029-
llvm::raw_string_ostream OS(GPDeclStr);
4030-
auto *const GC = PD->getDeclContext()->getAsDecl()->getAsGenericContext();
4031-
if (GC->getParsedGenericParams()) {
4032-
GPDeclLoc = GC->getParsedGenericParams()->getRAngleLoc();
4033-
OS << ", ";
4034-
} else {
4035-
GPDeclLoc =
4036-
isa<AbstractFunctionDecl>(GC)
4037-
? cast<AbstractFunctionDecl>(GC)->getParameters()->getLParenLoc()
4038-
: cast<SubscriptDecl>(GC)->getIndices()->getLParenLoc();
4039-
OS << "<";
4040-
}
4041-
OS << GPNamePlaceholder << ": ";
4042-
4043-
auto *TR = PD->getTypeRepr()->getWithoutParens();
4044-
if (auto *STR = dyn_cast<SpecifierTypeRepr>(TR)) {
4045-
TR = STR->getBase()->getWithoutParens();
4046-
}
4047-
if (auto *ETR = dyn_cast<ExistentialTypeRepr>(TR)) {
4048-
TR = ETR->getConstraint();
4049-
RemoveAnyRange = SourceRange(ETR->getAnyLoc(), TR->getStartLoc());
4050-
TR = TR->getWithoutParens();
4051-
}
4052-
if (auto *MTR = dyn_cast<MetatypeTypeRepr>(TR)) {
4053-
TR = MTR->getBase();
4054-
4055-
// (P & Q).Type -> T.Type
4056-
// (P).Type -> (T).Type
4057-
// ((P & Q)).Type -> ((T)).Type
4058-
if (auto *TTR = dyn_cast<TupleTypeRepr>(TR)) {
4059-
assert(TTR->isParenType());
4060-
if (!isa<CompositionTypeRepr>(TTR->getElementType(0))) {
4061-
TR = TR->getWithoutParens();
4062-
}
4063-
}
4064-
}
4065-
TyReplacementRange = TR->getSourceRange();
4066-
4067-
// Strip any remaining parentheses and print the conformance constraint.
4068-
TR->getWithoutParens()->print(OS);
4069-
4070-
if (!GC->getParsedGenericParams()) {
4071-
OS << ">";
4072-
}
4026+
auto *typeRepr = PD->getTypeRepr()->getWithoutParens();
4027+
if (auto *STR = dyn_cast<SpecifierTypeRepr>(typeRepr)) {
4028+
typeRepr = STR->getBase()->getWithoutParens();
40734029
}
40744030

4075-
// First, replace the constraint type with the generic parameter type
4076-
// placeholder.
4077-
Diag.fixItReplace(TyReplacementRange, GPNamePlaceholder);
4031+
SourceRange anyRange;
4032+
TypeRepr *constraintRepr = typeRepr;
4033+
if (auto *existentialRepr = dyn_cast<ExistentialTypeRepr>(typeRepr)) {
4034+
constraintRepr = existentialRepr->getConstraint()->getWithoutParens();
4035+
auto anyStart = existentialRepr->getAnyLoc();
4036+
auto anyEnd = existentialRepr->getConstraint()->getStartLoc();
4037+
anyRange = SourceRange(anyStart, anyEnd);
4038+
}
40784039

4079-
// Remove 'any' if needed, using a character-based removal to pick up
4080-
// whitespaces between it and its constraint repr.
4081-
if (RemoveAnyRange.isValid()) {
4082-
Diag.fixItRemoveChars(RemoveAnyRange.Start, RemoveAnyRange.End);
4040+
bool needsParens = false;
4041+
while (auto *metatype = dyn_cast<MetatypeTypeRepr>(constraintRepr)) {
4042+
// The generic equivalent of 'any P.Type' is '(some P).Type'
4043+
constraintRepr = metatype->getBase()->getWithoutParens();
4044+
if (isa<SimpleIdentTypeRepr>(constraintRepr))
4045+
needsParens = !isa<TupleTypeRepr>(metatype->getBase());
40834046
}
40844047

4085-
// Finally, insert the generic parameter declaration.
4086-
Diag.fixItInsert(GPDeclLoc, GPDeclStr);
4048+
std::string fix;
4049+
llvm::raw_string_ostream OS(fix);
4050+
if (needsParens)
4051+
OS << "(";
4052+
OS << "some ";
4053+
constraintRepr->print(OS);
4054+
if (needsParens)
4055+
OS << ")";
4056+
4057+
// When removing 'any', use a character-based removal to pick up
4058+
// whitespaces between it and its constraint repr.
4059+
Diag
4060+
.fixItReplace(constraintRepr->getSourceRange(), fix)
4061+
.fixItRemoveChars(anyRange.Start, anyRange.End);
40874062

40884063
return true;
40894064
}
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
// RUN: %target-run-simple-swift
2+
// REQUIRES: executable_test
3+
4+
import _Differentiation
5+
import StdlibUnittest
6+
7+
var PullbackTests = TestSuite("Pullback")
8+
9+
extension Dictionary: Differentiable where Value: Differentiable {
10+
public typealias TangentVector = [Key: Value.TangentVector]
11+
public mutating func move(by direction: TangentVector) {
12+
for (componentKey, componentDirection) in direction {
13+
func fatalMissingComponent() -> Value {
14+
fatalError("missing component \(componentKey) in moved Dictionary")
15+
}
16+
self[componentKey, default: fatalMissingComponent()].move(by: componentDirection)
17+
}
18+
}
19+
20+
public var zeroTangentVectorInitializer: () -> TangentVector {
21+
let listOfKeys = self.keys // capturing only what's needed, not the entire self, in order to not waste memory
22+
func initializer() -> Self.TangentVector {
23+
return listOfKeys.reduce(into: [Key: Value.TangentVector]()) { $0[$1] = Value.TangentVector.zero }
24+
}
25+
return initializer
26+
}
27+
}
28+
29+
extension Dictionary: AdditiveArithmetic where Value: AdditiveArithmetic {
30+
public static func + (_ lhs: Self, _ rhs: Self) -> Self {
31+
return lhs.merging(rhs, uniquingKeysWith: +)
32+
}
33+
34+
public static func - (_ lhs: Self, _ rhs: Self) -> Self {
35+
return lhs.merging(rhs.mapValues { .zero - $0 }, uniquingKeysWith: +)
36+
}
37+
38+
public static var zero: Self { [:] }
39+
}
40+
41+
extension Dictionary where Value: Differentiable {
42+
// get
43+
@usableFromInline
44+
@derivative(of: subscript(_:))
45+
func vjpSubscriptGet(key: Key) -> (value: Value?, pullback: (Optional<Value>.TangentVector) -> Dictionary<Key, Value>.TangentVector) {
46+
// When adding two dictionaries, nil values are equivalent to zeroes, so there is no need to manually zero-out
47+
// every key's value. Instead, it is faster to create a dictionary with the single non-zero entry.
48+
return (self[key], { v in
49+
if let value = v.value {
50+
return [key: value]
51+
}
52+
else {
53+
return .zero
54+
}
55+
})
56+
}
57+
}
58+
59+
public extension Dictionary where Value: Differentiable {
60+
@differentiable(reverse)
61+
mutating func set(_ key: Key, to newValue: Value) {
62+
self[key] = newValue
63+
}
64+
65+
@derivative(of: set)
66+
mutating func vjpUpdated(_ key: Key, to newValue: Value) -> (value: Void, pullback: (inout TangentVector) -> (Value.TangentVector)) {
67+
self.set(key, to: newValue)
68+
69+
let forwardCount = self.count
70+
let forwardKeys = self.keys // may be heavy to capture all of these, not sure how to do without them though
71+
72+
return ((), { v in
73+
// manual zero tangent initialization
74+
if v.count < forwardCount {
75+
v = Self.TangentVector()
76+
forwardKeys.forEach { v[$0] = .zero }
77+
}
78+
79+
if let dElement = v[key] {
80+
v[key] = .zero
81+
return dElement
82+
}
83+
else { // should this fail?
84+
v[key] = .zero
85+
return .zero
86+
}
87+
})
88+
}
89+
}
90+
91+
92+
PullbackTests.test("ConcreteType") {
93+
func getD(from newValues: [String: Double], at key: String) -> Double? {
94+
if newValues.keys.contains(key) {
95+
return newValues[key]
96+
}
97+
return nil
98+
}
99+
100+
@differentiable(reverse)
101+
func testFunctionD(newValues: [String: Double]) -> Double {
102+
return getD(from: newValues, at: "s1")!
103+
}
104+
105+
expectEqual(pullback(at: ["s1": 1.0], of: testFunctionD)(2), ["s1" : 2.0])
106+
}
107+
108+
PullbackTests.test("GenericType") {
109+
func getG<DataType>(from newValues: [String: DataType], at key: String) -> DataType?
110+
where DataType: Differentiable {
111+
if newValues.keys.contains(key) {
112+
return newValues[key]
113+
}
114+
return nil
115+
}
116+
117+
@differentiable(reverse)
118+
func testFunctionG(newValues: [String: Double]) -> Double {
119+
return getG(from: newValues, at: "s1")!
120+
}
121+
122+
expectEqual(pullback(at: ["s1": 1.0], of: testFunctionG)(2), ["s1" : 2.0])
123+
}
124+
125+
runAllTests()

0 commit comments

Comments
 (0)