Skip to content

Commit 9205758

Browse files
authored
[AutoDiff upstream] Add SILFunctionType utilities. (swiftlang#29674)
Upstream AutoDiff-related `SILFunctionType` utilities: - `SILFunctionType::getDifferentiabilityParameterIndices` - `SILFunctionType::getWithDifferentiability` - `SILFunctionType::getWithoutDifferentiability` Resolves TF-1125.
1 parent 609c84b commit 9205758

File tree

2 files changed

+75
-1
lines changed

2 files changed

+75
-1
lines changed

include/swift/AST/Types.h

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3087,6 +3087,14 @@ class AnyFunctionType : public TypeBase {
30873087
ExtInfo withClangFunctionType(const clang::Type *type) const {
30883088
return ExtInfo(Bits, Uncommon(type));
30893089
}
3090+
LLVM_NODISCARD
3091+
ExtInfo
3092+
withDifferentiabilityKind(DifferentiabilityKind differentiability) const {
3093+
return ExtInfo(
3094+
(Bits & ~DifferentiabilityMask) |
3095+
((unsigned)differentiability << DifferentiabilityMaskOffset),
3096+
Other);
3097+
}
30903098

30913099
std::pair<unsigned, const void *> getFuncAttrKey() const {
30923100
return std::make_pair(Bits, Other.ClangFunctionType);
@@ -3746,7 +3754,7 @@ class SILParameterInfo {
37463754

37473755
/// Return a version of this parameter info with the type replaced.
37483756
SILParameterInfo getWithInterfaceType(CanType type) const {
3749-
return SILParameterInfo(type, getConvention());
3757+
return SILParameterInfo(type, getConvention(), getDifferentiability());
37503758
}
37513759

37523760
/// Transform this SILParameterInfo by applying the user-provided
@@ -4435,6 +4443,22 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
44354443

44364444
const clang::FunctionType *getClangFunctionType() const;
44374445

4446+
/// Given that `this` is a `@differentiable` or `@differentiable(linear)`
4447+
/// function type, returns an `IndexSubset` corresponding to the
4448+
/// differentiability/linearity parameters (e.g. all parameters except the
4449+
/// `@noDerivative` ones).
4450+
IndexSubset *getDifferentiabilityParameterIndices();
4451+
4452+
/// Returns the `@differentiable` or `@differentiable(linear)` function type
4453+
/// for the given differentiability kind and parameter indices representing
4454+
/// differentiability/linearity parameters.
4455+
CanSILFunctionType getWithDifferentiability(DifferentiabilityKind kind,
4456+
IndexSubset *parameterIndices);
4457+
4458+
/// Returns the SIL function type stripping differentiability kind and
4459+
/// differentiability from all parameters.
4460+
CanSILFunctionType getWithoutDifferentiability();
4461+
44384462
/// Returns the type of the derivative function for the given parameter
44394463
/// indices, result index, derivative function kind, derivative function
44404464
/// generic signature (optional), and other auxiliary parameters.

lib/SIL/SILFunctionType.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,56 @@ SILFunctionType::getWitnessMethodClass(SILModule &M) const {
191191
return nullptr;
192192
}
193193

194+
IndexSubset *
195+
SILFunctionType::getDifferentiabilityParameterIndices() {
196+
assert(isDifferentiable() && "Must be a differentiable function");
197+
SmallVector<unsigned, 8> result;
198+
for (auto valueAndIndex : enumerate(getParameters()))
199+
if (valueAndIndex.value().getDifferentiability() !=
200+
SILParameterDifferentiability::NotDifferentiable)
201+
result.push_back(valueAndIndex.index());
202+
return IndexSubset::get(getASTContext(), getNumParameters(), result);
203+
}
204+
205+
CanSILFunctionType
206+
SILFunctionType::getWithDifferentiability(DifferentiabilityKind kind,
207+
IndexSubset *parameterIndices) {
208+
assert(kind != DifferentiabilityKind::NonDifferentiable &&
209+
"Differentiability kind must be normal or linear");
210+
SmallVector<SILParameterInfo, 8> newParameters;
211+
for (auto paramAndIndex : enumerate(getParameters())) {
212+
auto &param = paramAndIndex.value();
213+
unsigned index = paramAndIndex.index();
214+
newParameters.push_back(param.getWithDifferentiability(
215+
index < parameterIndices->getCapacity() &&
216+
parameterIndices->contains(index)
217+
? SILParameterDifferentiability::DifferentiableOrNotApplicable
218+
: SILParameterDifferentiability::NotDifferentiable));
219+
}
220+
auto newExtInfo = getExtInfo().withDifferentiabilityKind(kind);
221+
return get(getSubstGenericSignature(), newExtInfo, getCoroutineKind(),
222+
getCalleeConvention(), newParameters, getYields(), getResults(),
223+
getOptionalErrorResult(), getSubstitutions(),
224+
isGenericSignatureImplied(), getASTContext(),
225+
getWitnessMethodConformanceOrInvalid());
226+
}
227+
228+
CanSILFunctionType SILFunctionType::getWithoutDifferentiability() {
229+
if (!isDifferentiable())
230+
return CanSILFunctionType(this);
231+
auto nondiffExtInfo = getExtInfo().withDifferentiabilityKind(
232+
DifferentiabilityKind::NonDifferentiable);
233+
SmallVector<SILParameterInfo, 8> newParams;
234+
for (auto &param : getParameters())
235+
newParams.push_back(param.getWithDifferentiability(
236+
SILParameterDifferentiability::DifferentiableOrNotApplicable));
237+
return SILFunctionType::get(getSubstGenericSignature(), nondiffExtInfo,
238+
getCoroutineKind(), getCalleeConvention(),
239+
newParams, getYields(), getResults(),
240+
getOptionalErrorResult(), getSubstitutions(),
241+
isGenericSignatureImplied(), getASTContext());
242+
}
243+
194244
CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
195245
IndexSubset *parameterIndices, unsigned resultIndex,
196246
AutoDiffDerivativeFunctionKind kind, TypeConverter &TC,

0 commit comments

Comments
 (0)