Skip to content

Commit a751b07

Browse files
authored
Fixes inject_enum_addr handling: (swiftlang#75459)
- Ensure it really accumulates the adjoint buffer - Handle Optional.none case when there is no value to propagate to Fixes swiftlang#75280
1 parent 030f2b8 commit a751b07

File tree

3 files changed

+51
-24
lines changed

3 files changed

+51
-24
lines changed

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
//
1616
//===----------------------------------------------------------------------===//
1717

18-
#include "swift/Basic/STLExtras.h"
1918
#define DEBUG_TYPE "differentiation"
2019

2120
#include "swift/SILOptimizer/Differentiation/PullbackCloner.h"
@@ -31,6 +30,7 @@
3130
#include "swift/AST/PropertyWrappers.h"
3231
#include "swift/AST/TypeCheckRequests.h"
3332
#include "swift/Basic/Assertions.h"
33+
#include "swift/Basic/STLExtras.h"
3434
#include "swift/SIL/ApplySite.h"
3535
#include "swift/SIL/InstructionUtils.h"
3636
#include "swift/SIL/Projection.h"
@@ -131,6 +131,10 @@ class PullbackCloner::Implementation final
131131
/// Stack buffers allocated for storing local adjoint values.
132132
SmallVector<AllocStackInst *, 64> functionLocalAllocations;
133133

134+
/// Copies created to deal with destructive enum operations
135+
/// (unchecked_take_enum_addr)
136+
llvm::SmallDenseMap<InitEnumDataAddrInst*, SILValue> enumDataAdjCopies;
137+
134138
/// A set used to remember local allocations that were destroyed.
135139
llvm::SmallDenseSet<SILValue> destroyedLocalAllocations;
136140

@@ -1858,7 +1862,7 @@ class PullbackCloner::Implementation final
18581862
/// Handle a sequence of `init_enum_data_addr` and `inject_enum_addr`
18591863
/// instructions.
18601864
///
1861-
/// Original: y = init_enum_data_addr x
1865+
/// Original: x = init_enum_data_addr y : $*Enum, #Enum.Case
18621866
/// inject_enum_addr y
18631867
///
18641868
/// Adjoint: adj[x] += unchecked_take_enum_data_addr adj[y]
@@ -1879,6 +1883,10 @@ class PullbackCloner::Implementation final
18791883
return;
18801884
}
18811885

1886+
// No associated value => no adjoint to propagate
1887+
if (!inject->getElement()->hasAssociatedValues())
1888+
return;
1889+
18821890
InitEnumDataAddrInst *origData = nullptr;
18831891
for (auto use : origEnum->getUses()) {
18841892
if (auto *init = dyn_cast<InitEnumDataAddrInst>(use->getUser())) {
@@ -1900,9 +1908,9 @@ class PullbackCloner::Implementation final
19001908
}
19011909
}
19021910

1903-
SILValue adjStruct = getAdjointBuffer(bb, origEnum);
1911+
SILValue adjDest = getAdjointBuffer(bb, origEnum);
19041912
StructDecl *adjStructDecl =
1905-
adjStruct->getType().getStructOrBoundGenericStruct();
1913+
adjDest->getType().getStructOrBoundGenericStruct();
19061914

19071915
VarDecl *adjOptVar = nullptr;
19081916
if (adjStructDecl) {
@@ -1922,35 +1930,35 @@ class PullbackCloner::Implementation final
19221930

19231931
SILLocation loc = origData->getLoc();
19241932
StructElementAddrInst *adjOpt =
1925-
builder.createStructElementAddr(loc, adjStruct, adjOptVar);
1933+
builder.createStructElementAddr(loc, adjDest, adjOptVar);
19261934

19271935
// unchecked_take_enum_data_addr is destructive, so copy
19281936
// Optional<T.TangentVector> to a new alloca.
19291937
AllocStackInst *adjOptCopy =
19301938
createFunctionLocalAllocation(adjOpt->getType(), loc);
19311939
builder.createCopyAddr(loc, adjOpt, adjOptCopy, IsNotTake,
19321940
IsInitialization);
1941+
// The Optional copy is invalidated, do not attempt to destroy it at the end
1942+
// of the pullback. The value returned from unchecked_take_enum_data_addr is
1943+
// destroyed in visitInitEnumDataAddrInst.
1944+
auto [_, inserted] = enumDataAdjCopies.try_emplace(origData, adjOptCopy);
1945+
assert(inserted && "expected single buffer");
19331946

19341947
EnumElementDecl *someElemDecl = getASTContext().getOptionalSomeDecl();
19351948
UncheckedTakeEnumDataAddrInst *adjData =
19361949
builder.createUncheckedTakeEnumDataAddr(loc, adjOptCopy, someElemDecl);
19371950

1938-
setAdjointBuffer(bb, origData, adjData);
1939-
1940-
// The Optional copy is invalidated, do not attempt to destroy it at the end
1941-
// of the pullback. The value returned from unchecked_take_enum_data_addr is
1942-
// destroyed in visitInitEnumDataAddrInst.
1943-
destroyedLocalAllocations.insert(adjOptCopy);
1951+
addToAdjointBuffer(bb, origData, adjData, loc);
19441952
}
19451953

19461954
/// Handle `init_enum_data_addr` instruction.
19471955
/// Destroy the value returned from `unchecked_take_enum_data_addr`.
19481956
void visitInitEnumDataAddrInst(InitEnumDataAddrInst *init) {
1949-
auto bufIt = bufferMap.find({init->getParent(), SILValue(init)});
1950-
if (bufIt == bufferMap.end())
1951-
return;
1952-
SILValue adjData = bufIt->second;
1953-
builder.emitDestroyAddr(init->getLoc(), adjData);
1957+
SILValue adjOptCopy = enumDataAdjCopies.at(init);
1958+
1959+
builder.emitDestroyAddr(init->getLoc(), adjOptCopy);
1960+
destroyedLocalAllocations.insert(adjOptCopy);
1961+
enumDataAdjCopies.erase(init);
19541962
}
19551963

19561964
/// Handle `unchecked_ref_cast` instruction.
@@ -2567,6 +2575,12 @@ bool PullbackCloner::Implementation::run() {
25672575
}
25682576
}
25692577
}
2578+
// Ensure all enum adjoint copeis have been cleaned up
2579+
for (const auto &enumData : enumDataAdjCopies) {
2580+
leakFound = true;
2581+
getADDebugStream() << "Found leaked temporary:\n" << enumData.second;
2582+
}
2583+
25702584
// Ensure all local allocations have been cleaned up.
25712585
for (auto localAlloc : functionLocalAllocations) {
25722586
if (!destroyedLocalAllocations.count(localAlloc)) {

test/AutoDiff/SILOptimizer/optional_pullback.swift

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@ import _Differentiation
77
// CHECK-SAME: (@in_guaranteed Optional<τ_0_0>.TangentVector) -> @out τ_0_0.TangentVector
88
//
99
// CHECK: bb0(%[[RET_TAN:.+]] : $*τ_0_0.TangentVector, %[[OPT_TAN:.+]] : $*Optional<τ_0_0>.TangentVector):
10-
// CHECK: %[[RET_TAN_BUF:.+]] = alloc_stack $τ_0_0.TangentVector
10+
// CHECK: %[[RET_TAN_BUF:.+]] = alloc_stack $τ_0_0.TangentVector, let, name "derivative of 'x'
1111

1212
// CHECK: %[[ZERO1:.+]] = witness_method $τ_0_0.TangentVector, #AdditiveArithmetic.zero!getter
1313
// CHECK: apply %[[ZERO1]]<τ_0_0.TangentVector>(%[[RET_TAN_BUF]], %{{.*}})
14+
// CHECK: %[[ADJ_IN_BB:.+]] = alloc_stack $τ_0_0.TangentVector
1415
//
1516
// CHECK: %[[TAN_VAL_COPY:.+]] = alloc_stack $Optional<τ_0_0.TangentVector>
1617
// CHECK: %[[TAN_BUF:.+]] = alloc_stack $Optional<τ_0_0>.TangentVector
@@ -21,13 +22,12 @@ import _Differentiation
2122
//
2223
// CHECK: %[[TAN_DATA:.+]] = unchecked_take_enum_data_addr %[[TAN_VAL_COPY]] : $*Optional<τ_0_0.TangentVector>, #Optional.some!enumelt
2324
// CHECK: %[[PLUS_EQUAL:.+]] = witness_method $τ_0_0.TangentVector, #AdditiveArithmetic."+="
24-
// CHECK: apply %[[PLUS_EQUAL]]<τ_0_0.TangentVector>(%[[RET_TAN_BUF]], %[[TAN_DATA]], %{{.*}})
25-
//
26-
// CHECK: destroy_addr %[[TAN_DATA]] : $*τ_0_0.TangentVector
27-
// CHECK: %[[ZERO2:.+]] = witness_method $τ_0_0.TangentVector, #AdditiveArithmetic.zero!getter
28-
// CHECK: apply %[[ZERO2]]<τ_0_0.TangentVector>(%[[TAN_DATA]], %{{.*}})
29-
// CHECK: destroy_addr %[[TAN_DATA]] : $*τ_0_0.TangentVector
30-
//
25+
// CHECK: apply %[[PLUS_EQUAL]]<τ_0_0.TangentVector>(%[[ADJ_IN_BB]], %[[TAN_DATA]], %{{.*}})
26+
27+
// CHECK: %[[PLUS_EQUAL:.+]] = witness_method $τ_0_0.TangentVector, #AdditiveArithmetic."+="
28+
// CHECK: apply %[[PLUS_EQUAL]]<τ_0_0.TangentVector>(%[[RET_TAN_BUF]], %[[ADJ_IN_BB]], %{{.*}})
29+
// CHECK: destroy_addr %[[ADJ_IN_BB]] : $*τ_0_0.TangentVector
30+
3131
// CHECK: copy_addr [take] %[[RET_TAN_BUF:.+]] to [init] %[[RET_TAN:.+]]
3232
// CHECK: destroy_addr %[[TAN_BUF]] : $*Optional<τ_0_0>.TangentVector
3333
// CHECK: dealloc_stack %[[TAN_BUF]] : $*Optional<τ_0_0>.TangentVector
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// RUN: %target-swift-frontend -emit-sil -verify %s
2+
3+
// https://github.com/swiftlang/swift/issues/75280
4+
// Ensure we accumulate adjoints properly for inject_enum_addr instructions and
5+
// handle `nil` case (no adjoint value to propagate)
6+
7+
8+
import _Differentiation
9+
@differentiable(reverse) func a<F, A>(_ f: Optional<F>, c: @differentiable(reverse) (F) -> A) -> Optional<A> where F: Differentiable, A: Differentiable
10+
{
11+
guard let f else {return nil}
12+
return c(f)
13+
}

0 commit comments

Comments
 (0)