Skip to content

Commit 77fd034

Browse files
author
marcrasi
authored
Merge pull request swiftlang#30711 from rxwei/differential-operators
[AutoDiff upstream] Add differential operators and some utilities.
2 parents 4faa6dc + 013a66b commit 77fd034

File tree

12 files changed

+652
-40
lines changed

12 files changed

+652
-40
lines changed

include/swift/Demangling/TypeDecoder.h

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -181,33 +181,53 @@ enum class ImplFunctionRepresentation {
181181
Closure
182182
};
183183

184+
enum class ImplFunctionDifferentiabilityKind {
185+
NonDifferentiable,
186+
Normal,
187+
Linear
188+
};
189+
184190
class ImplFunctionTypeFlags {
185191
unsigned Rep : 3;
186192
unsigned Pseudogeneric : 1;
187193
unsigned Escaping : 1;
194+
unsigned DifferentiabilityKind : 2;
188195

189196
public:
190-
ImplFunctionTypeFlags() : Rep(0), Pseudogeneric(0), Escaping(0) {}
197+
ImplFunctionTypeFlags()
198+
: Rep(0), Pseudogeneric(0), Escaping(0), DifferentiabilityKind(0) {}
191199

192-
ImplFunctionTypeFlags(ImplFunctionRepresentation rep,
193-
bool pseudogeneric, bool noescape)
194-
: Rep(unsigned(rep)), Pseudogeneric(pseudogeneric), Escaping(noescape) {}
200+
ImplFunctionTypeFlags(ImplFunctionRepresentation rep, bool pseudogeneric,
201+
bool noescape,
202+
ImplFunctionDifferentiabilityKind diffKind)
203+
: Rep(unsigned(rep)), Pseudogeneric(pseudogeneric), Escaping(noescape),
204+
DifferentiabilityKind(unsigned(diffKind)) {}
195205

196206
ImplFunctionTypeFlags
197207
withRepresentation(ImplFunctionRepresentation rep) const {
198-
return ImplFunctionTypeFlags(rep, Pseudogeneric, Escaping);
208+
return ImplFunctionTypeFlags(
209+
rep, Pseudogeneric, Escaping,
210+
ImplFunctionDifferentiabilityKind(DifferentiabilityKind));
199211
}
200212

201213
ImplFunctionTypeFlags
202214
withEscaping() const {
203-
return ImplFunctionTypeFlags(ImplFunctionRepresentation(Rep),
204-
Pseudogeneric, true);
215+
return ImplFunctionTypeFlags(
216+
ImplFunctionRepresentation(Rep), Pseudogeneric, true,
217+
ImplFunctionDifferentiabilityKind(DifferentiabilityKind));
205218
}
206219

207220
ImplFunctionTypeFlags
208221
withPseudogeneric() const {
209-
return ImplFunctionTypeFlags(ImplFunctionRepresentation(Rep),
210-
true, Escaping);
222+
return ImplFunctionTypeFlags(
223+
ImplFunctionRepresentation(Rep), true, Escaping,
224+
ImplFunctionDifferentiabilityKind(DifferentiabilityKind));
225+
}
226+
227+
ImplFunctionTypeFlags
228+
withDifferentiabilityKind(ImplFunctionDifferentiabilityKind diffKind) const {
229+
return ImplFunctionTypeFlags(ImplFunctionRepresentation(Rep), Pseudogeneric,
230+
Escaping, diffKind);
211231
}
212232

213233
ImplFunctionRepresentation getRepresentation() const {
@@ -217,6 +237,10 @@ class ImplFunctionTypeFlags {
217237
bool isEscaping() const { return Escaping; }
218238

219239
bool isPseudogeneric() const { return Pseudogeneric; }
240+
241+
ImplFunctionDifferentiabilityKind getDifferentiabilityKind() const {
242+
return ImplFunctionDifferentiabilityKind(DifferentiabilityKind);
243+
}
220244
};
221245

222246
#if SWIFT_OBJC_INTEROP
@@ -582,6 +606,14 @@ class TypeDecoder {
582606
flags =
583607
flags.withRepresentation(ImplFunctionRepresentation::Block);
584608
}
609+
} else if (child->getKind() == NodeKind::ImplDifferentiable) {
610+
flags = flags.withDifferentiabilityKind(
611+
ImplFunctionDifferentiabilityKind::Normal);
612+
} else if (child->getKind() == NodeKind::ImplLinear) {
613+
flags = flags.withDifferentiabilityKind(
614+
ImplFunctionDifferentiabilityKind::Linear);
615+
} else if (child->getKind() == NodeKind::ImplEscaping) {
616+
flags = flags.withEscaping();
585617
} else if (child->getKind() == NodeKind::ImplEscaping) {
586618
flags = flags.withEscaping();
587619
} else if (child->getKind() == NodeKind::ImplParameter) {

lib/AST/ASTDemangler.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -500,11 +500,23 @@ Type ASTBuilder::createImplFunctionType(
500500
break;
501501
}
502502

503+
DifferentiabilityKind diffKind;
504+
switch (flags.getDifferentiabilityKind()) {
505+
case ImplFunctionDifferentiabilityKind::NonDifferentiable:
506+
diffKind = DifferentiabilityKind::NonDifferentiable;
507+
break;
508+
case ImplFunctionDifferentiabilityKind::Normal:
509+
diffKind = DifferentiabilityKind::Normal;
510+
break;
511+
case ImplFunctionDifferentiabilityKind::Linear:
512+
diffKind = DifferentiabilityKind::Linear;
513+
break;
514+
}
515+
503516
// TODO: [store-sil-clang-function-type]
504-
auto einfo = SILFunctionType::ExtInfo(
505-
representation, flags.isPseudogeneric(), !flags.isEscaping(),
506-
DifferentiabilityKind::NonDifferentiable,
507-
/*clangFunctionType*/ nullptr);
517+
auto einfo = SILFunctionType::ExtInfo(representation, flags.isPseudogeneric(),
518+
!flags.isEscaping(), diffKind,
519+
/*clangFunctionType*/ nullptr);
508520

509521
llvm::SmallVector<SILParameterInfo, 8> funcParams;
510522
llvm::SmallVector<SILYieldInfo, 8> funcYields;

lib/SILGen/SILGenBuiltin.cpp

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1106,14 +1106,9 @@ static ManagedValue emitBuiltinAutoDiffApplyTransposeFunction(
11061106
origFnArgVals.push_back(arg.getValue());
11071107

11081108
// Get the transpose function.
1109-
// TODO(TF-1142): Create a linear_function_extract instead of an undef.
1110-
auto fnTy = origFnVal->getType().castTo<SILFunctionType>();
1111-
auto transposeFnType =
1112-
fnTy->getWithoutDifferentiability()->getAutoDiffTransposeFunctionType(
1113-
fnTy->getDifferentiabilityParameterIndices(), SGF.SGM.M.Types,
1114-
LookUpConformanceInModule(SGF.SGM.M.getSwiftModule()));
1115-
SILValue transposeFn =
1116-
SILUndef::get(SILType::getPrimitiveObjectType(transposeFnType), SGF.F);
1109+
SILValue transposeFn = SGF.B.createLinearFunctionExtract(
1110+
loc, LinearDifferentiableFunctionTypeComponent::Transpose, origFnVal);
1111+
auto transposeFnType = transposeFn->getType().castTo<SILFunctionType>();
11171112
auto transposeFnUnsubstType =
11181113
transposeFnType->getUnsubstitutedType(SGF.getModule());
11191114
if (transposeFnType != transposeFnUnsubstType) {
@@ -1204,19 +1199,16 @@ static ManagedValue emitBuiltinLinearFunction(
12041199
assert(args.size() == 2);
12051200
auto origFn = args.front();
12061201
auto origType = origFn.getType().castTo<SILFunctionType>();
1207-
// TODO(TF-1142): Create a linear_function instead of an undef.
1208-
auto linearFnTy = origType->getWithDifferentiability(
1209-
DifferentiabilityKind::Linear,
1202+
auto linearFn = SGF.B.createLinearFunction(
1203+
loc,
12101204
IndexSubset::getDefault(
1211-
SGF.getASTContext(), origType->getNumParameters(),
1212-
/*includeAll*/ true));
1213-
SILValue linearFn = SILUndef::get(
1214-
SILType::getPrimitiveObjectType(linearFnTy), SGF.F);
1205+
SGF.getASTContext(),
1206+
origType->getNumParameters(),
1207+
/*includeAll*/ true),
1208+
origFn.forward(SGF), args[1].forward(SGF));
12151209
return SGF.emitManagedRValueWithCleanup(linearFn);
12161210
}
12171211

1218-
1219-
12201212
/// Emit SIL for the named builtin: globalStringTablePointer. Unlike the default
12211213
/// ownership convention for named builtins, which is to take (non-trivial)
12221214
/// arguments as Owned, this builtin accepts owned as well as guaranteed

lib/SILOptimizer/Transforms/SemanticARCOpts.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -776,6 +776,10 @@ struct SemanticARCOptVisitor
776776
FORWARDING_INST(OpenExistentialBoxValue)
777777
FORWARDING_INST(MarkDependence)
778778
FORWARDING_INST(InitExistentialRef)
779+
FORWARDING_INST(DifferentiableFunction)
780+
FORWARDING_INST(LinearFunction)
781+
FORWARDING_INST(DifferentiableFunctionExtract)
782+
FORWARDING_INST(LinearFunctionExtract)
779783
#undef FORWARDING_INST
780784

781785
#define FORWARDING_TERM(NAME) \

stdlib/public/Differentiation/CMakeLists.txt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@
1212

1313
add_swift_target_library(swift_Differentiation ${SWIFT_STDLIB_LIBRARY_BUILD_TYPES} IS_STDLIB
1414
Differentiable.swift
15+
DifferentialOperators.swift
16+
DifferentiationUtilities.swift
1517

16-
SWIFT_COMPILE_FLAGS ${SWIFT_STANDARD_LIBRARY_SWIFT_FLAGS}
18+
SWIFT_COMPILE_FLAGS
19+
${SWIFT_STANDARD_LIBRARY_SWIFT_FLAGS}
20+
-parse-stdlib
21+
-Xfrontend -enable-experimental-differentiable-programming
1722
LINK_FLAGS "${SWIFT_RUNTIME_SWIFT_LINK_FLAGS}"
1823
INSTALL_IN_COMPONENT stdlib)

stdlib/public/Differentiation/Differentiable.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
//
2121
//===----------------------------------------------------------------------===//
2222

23+
import Swift
24+
2325
/// A type that mathematically represents a differentiable manifold whose
2426
/// tangent spaces are finite-dimensional.
2527
public protocol Differentiable {

0 commit comments

Comments
 (0)