15
15
//
16
16
// ===----------------------------------------------------------------------===//
17
17
18
- #include " swift/Basic/STLExtras.h"
19
18
#define DEBUG_TYPE " differentiation"
20
19
21
20
#include " swift/SILOptimizer/Differentiation/PullbackCloner.h"
31
30
#include " swift/AST/PropertyWrappers.h"
32
31
#include " swift/AST/TypeCheckRequests.h"
33
32
#include " swift/Basic/Assertions.h"
33
+ #include " swift/Basic/STLExtras.h"
34
34
#include " swift/SIL/ApplySite.h"
35
35
#include " swift/SIL/InstructionUtils.h"
36
36
#include " swift/SIL/Projection.h"
@@ -131,6 +131,10 @@ class PullbackCloner::Implementation final
131
131
// / Stack buffers allocated for storing local adjoint values.
132
132
SmallVector<AllocStackInst *, 64 > functionLocalAllocations;
133
133
134
+ // / Copies created to deal with destructive enum operations
135
+ // / (unchecked_take_enum_addr)
136
+ llvm::SmallDenseMap<InitEnumDataAddrInst*, SILValue> enumDataAdjCopies;
137
+
134
138
// / A set used to remember local allocations that were destroyed.
135
139
llvm::SmallDenseSet<SILValue> destroyedLocalAllocations;
136
140
@@ -1858,7 +1862,7 @@ class PullbackCloner::Implementation final
1858
1862
// / Handle a sequence of `init_enum_data_addr` and `inject_enum_addr`
1859
1863
// / instructions.
1860
1864
// /
1861
- // / Original: y = init_enum_data_addr x
1865
+ // / Original: x = init_enum_data_addr y : $*Enum, #Enum.Case
1862
1866
// / inject_enum_addr y
1863
1867
// /
1864
1868
// / Adjoint: adj[x] += unchecked_take_enum_data_addr adj[y]
@@ -1879,6 +1883,10 @@ class PullbackCloner::Implementation final
1879
1883
return ;
1880
1884
}
1881
1885
1886
+ // No associated value => no adjoint to propagate
1887
+ if (!inject->getElement ()->hasAssociatedValues ())
1888
+ return ;
1889
+
1882
1890
InitEnumDataAddrInst *origData = nullptr ;
1883
1891
for (auto use : origEnum->getUses ()) {
1884
1892
if (auto *init = dyn_cast<InitEnumDataAddrInst>(use->getUser ())) {
@@ -1900,9 +1908,9 @@ class PullbackCloner::Implementation final
1900
1908
}
1901
1909
}
1902
1910
1903
- SILValue adjStruct = getAdjointBuffer (bb, origEnum);
1911
+ SILValue adjDest = getAdjointBuffer (bb, origEnum);
1904
1912
StructDecl *adjStructDecl =
1905
- adjStruct ->getType ().getStructOrBoundGenericStruct ();
1913
+ adjDest ->getType ().getStructOrBoundGenericStruct ();
1906
1914
1907
1915
VarDecl *adjOptVar = nullptr ;
1908
1916
if (adjStructDecl) {
@@ -1922,35 +1930,35 @@ class PullbackCloner::Implementation final
1922
1930
1923
1931
SILLocation loc = origData->getLoc ();
1924
1932
StructElementAddrInst *adjOpt =
1925
- builder.createStructElementAddr (loc, adjStruct , adjOptVar);
1933
+ builder.createStructElementAddr (loc, adjDest , adjOptVar);
1926
1934
1927
1935
// unchecked_take_enum_data_addr is destructive, so copy
1928
1936
// Optional<T.TangentVector> to a new alloca.
1929
1937
AllocStackInst *adjOptCopy =
1930
1938
createFunctionLocalAllocation (adjOpt->getType (), loc);
1931
1939
builder.createCopyAddr (loc, adjOpt, adjOptCopy, IsNotTake,
1932
1940
IsInitialization);
1941
+ // The Optional copy is invalidated, do not attempt to destroy it at the end
1942
+ // of the pullback. The value returned from unchecked_take_enum_data_addr is
1943
+ // destroyed in visitInitEnumDataAddrInst.
1944
+ auto [_, inserted] = enumDataAdjCopies.try_emplace (origData, adjOptCopy);
1945
+ assert (inserted && " expected single buffer" );
1933
1946
1934
1947
EnumElementDecl *someElemDecl = getASTContext ().getOptionalSomeDecl ();
1935
1948
UncheckedTakeEnumDataAddrInst *adjData =
1936
1949
builder.createUncheckedTakeEnumDataAddr (loc, adjOptCopy, someElemDecl);
1937
1950
1938
- setAdjointBuffer (bb, origData, adjData);
1939
-
1940
- // The Optional copy is invalidated, do not attempt to destroy it at the end
1941
- // of the pullback. The value returned from unchecked_take_enum_data_addr is
1942
- // destroyed in visitInitEnumDataAddrInst.
1943
- destroyedLocalAllocations.insert (adjOptCopy);
1951
+ addToAdjointBuffer (bb, origData, adjData, loc);
1944
1952
}
1945
1953
1946
1954
// / Handle `init_enum_data_addr` instruction.
1947
1955
// / Destroy the value returned from `unchecked_take_enum_data_addr`.
1948
1956
void visitInitEnumDataAddrInst (InitEnumDataAddrInst *init) {
1949
- auto bufIt = bufferMap. find ({ init-> getParent (), SILValue (init)} );
1950
- if (bufIt == bufferMap. end ())
1951
- return ;
1952
- SILValue adjData = bufIt-> second ;
1953
- builder. emitDestroyAddr (init-> getLoc (), adjData );
1957
+ SILValue adjOptCopy = enumDataAdjCopies. at ( init);
1958
+
1959
+ builder. emitDestroyAddr (init-> getLoc (), adjOptCopy) ;
1960
+ destroyedLocalAllocations. insert (adjOptCopy) ;
1961
+ enumDataAdjCopies. erase (init);
1954
1962
}
1955
1963
1956
1964
// / Handle `unchecked_ref_cast` instruction.
@@ -2567,6 +2575,12 @@ bool PullbackCloner::Implementation::run() {
2567
2575
}
2568
2576
}
2569
2577
}
2578
+ // Ensure all enum adjoint copeis have been cleaned up
2579
+ for (const auto &enumData : enumDataAdjCopies) {
2580
+ leakFound = true ;
2581
+ getADDebugStream () << " Found leaked temporary:\n " << enumData.second ;
2582
+ }
2583
+
2570
2584
// Ensure all local allocations have been cleaned up.
2571
2585
for (auto localAlloc : functionLocalAllocations) {
2572
2586
if (!destroyedLocalAllocations.count (localAlloc)) {
0 commit comments