Skip to content

Commit 4bea441

Browse files
authored
Merge pull request swiftlang#39416 from slavapestov/autodiff-requirement-machine-workaround
AutoDiff: Workaround for performing generic signature queries on the wrong signature
2 parents 015d487 + bbb7196 commit 4bea441

File tree

51 files changed

+123
-85
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+123
-85
lines changed

lib/SIL/IR/SILFunctionType.cpp

Lines changed: 56 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "swift/AST/Module.h"
2727
#include "swift/AST/ModuleLoader.h"
2828
#include "swift/AST/ProtocolConformance.h"
29+
#include "swift/AST/TypeCheckRequests.h"
2930
#include "swift/ClangImporter/ClangImporter.h"
3031
#include "swift/SIL/SILModule.h"
3132
#include "swift/SIL/SILType.h"
@@ -360,6 +361,41 @@ getSemanticResults(SILFunctionType *functionType, IndexSubset *parameterIndices,
360361
IndexSubset::get(C, parameterIndices->getCapacity(), inoutParamIndices);
361362
}
362363

364+
static CanGenericSignature buildDifferentiableGenericSignature(CanGenericSignature sig,
365+
CanType tanType) {
366+
if (!sig)
367+
return sig;
368+
369+
llvm::DenseSet<CanType> types;
370+
371+
auto &ctx = tanType->getASTContext();
372+
373+
(void) tanType.findIf([&](Type t) -> bool {
374+
if (auto *dmt = t->getAs<DependentMemberType>()) {
375+
if (dmt->getName() == ctx.Id_TangentVector)
376+
types.insert(dmt->getBase()->getCanonicalType());
377+
}
378+
379+
return false;
380+
});
381+
382+
SmallVector<Requirement, 2> reqs;
383+
auto *proto = ctx.getProtocol(KnownProtocolKind::Differentiable);
384+
assert(proto != nullptr);
385+
386+
for (auto type : types) {
387+
if (!sig->requiresProtocol(type, proto)) {
388+
reqs.push_back(Requirement(RequirementKind::Conformance, type,
389+
proto->getDeclaredInterfaceType()));
390+
}
391+
}
392+
393+
return evaluateOrDefault(
394+
ctx.evaluator,
395+
AbstractGenericSignatureRequest{sig.getPointer(), {}, reqs},
396+
GenericSignature()).getCanonicalSignature();
397+
}
398+
363399
/// Returns the differential type for the given original function type,
364400
/// parameter indices, and result index.
365401
static CanSILFunctionType getAutoDiffDifferentialType(
@@ -371,10 +407,11 @@ static CanSILFunctionType getAutoDiffDifferentialType(
371407
auto getTangentParameterConvention =
372408
[&](CanType tanType,
373409
ParameterConvention origParamConv) -> ParameterConvention {
374-
tanType =
375-
tanType->getCanonicalType(originalFnTy->getSubstGenericSignature());
376-
AbstractionPattern pattern(originalFnTy->getSubstGenericSignature(),
377-
tanType);
410+
auto sig = buildDifferentiableGenericSignature(
411+
originalFnTy->getSubstGenericSignature(), tanType);
412+
413+
tanType = tanType->getCanonicalType(sig);
414+
AbstractionPattern pattern(sig, tanType);
378415
auto &tl =
379416
TC.getTypeLowering(pattern, tanType, TypeExpansionContext::minimal());
380417
// When the tangent type is address only, we must ensure that the tangent
@@ -398,10 +435,11 @@ static CanSILFunctionType getAutoDiffDifferentialType(
398435
auto getTangentResultConvention =
399436
[&](CanType tanType,
400437
ResultConvention origResConv) -> ResultConvention {
401-
tanType =
402-
tanType->getCanonicalType(originalFnTy->getSubstGenericSignature());
403-
AbstractionPattern pattern(originalFnTy->getSubstGenericSignature(),
404-
tanType);
438+
auto sig = buildDifferentiableGenericSignature(
439+
originalFnTy->getSubstGenericSignature(), tanType);
440+
441+
tanType = tanType->getCanonicalType(sig);
442+
AbstractionPattern pattern(sig, tanType);
405443
auto &tl =
406444
TC.getTypeLowering(pattern, tanType, TypeExpansionContext::minimal());
407445
// When the tangent type is address only, we must ensure that the tangent
@@ -530,10 +568,11 @@ static CanSILFunctionType getAutoDiffPullbackType(
530568
auto getTangentParameterConventionForOriginalResult =
531569
[&](CanType tanType,
532570
ResultConvention origResConv) -> ParameterConvention {
533-
tanType =
534-
tanType->getCanonicalType(originalFnTy->getSubstGenericSignature());
535-
AbstractionPattern pattern(originalFnTy->getSubstGenericSignature(),
536-
tanType);
571+
auto sig = buildDifferentiableGenericSignature(
572+
originalFnTy->getSubstGenericSignature(), tanType);
573+
574+
tanType = tanType->getCanonicalType(sig);
575+
AbstractionPattern pattern(sig, tanType);
537576
auto &tl =
538577
TC.getTypeLowering(pattern, tanType, TypeExpansionContext::minimal());
539578
ParameterConvention conv;
@@ -560,10 +599,11 @@ static CanSILFunctionType getAutoDiffPullbackType(
560599
auto getTangentResultConventionForOriginalParameter =
561600
[&](CanType tanType,
562601
ParameterConvention origParamConv) -> ResultConvention {
563-
tanType =
564-
tanType->getCanonicalType(originalFnTy->getSubstGenericSignature());
565-
AbstractionPattern pattern(originalFnTy->getSubstGenericSignature(),
566-
tanType);
602+
auto sig = buildDifferentiableGenericSignature(
603+
originalFnTy->getSubstGenericSignature(), tanType);
604+
605+
tanType = tanType->getCanonicalType(sig);
606+
AbstractionPattern pattern(sig, tanType);
567607
auto &tl =
568608
TC.getTypeLowering(pattern, tanType, TypeExpansionContext::minimal());
569609
ResultConvention conv;

stdlib/private/DifferentiationUnittest/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,5 @@ add_swift_target_library(swiftDifferentiationUnittest ${SWIFT_STDLIB_LIBRARY_BUI
33
GYB_SOURCES DifferentiationUnittest.swift.gyb
44

55
SWIFT_MODULE_DEPENDS _Differentiation StdlibUnittest
6-
SWIFT_COMPILE_FLAGS -Xfrontend -requirement-machine=off
76
INSTALL_IN_COMPONENT stdlib-experimental
87
DARWIN_INSTALL_NAME_DIR "${SWIFT_DARWIN_STDLIB_PRIVATE_INSTALL_NAME_DIR}")

stdlib/public/Differentiation/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,5 @@ add_swift_target_library(swift_Differentiation ${SWIFT_STDLIB_LIBRARY_BUILD_TYPE
4141
SWIFT_COMPILE_FLAGS
4242
${SWIFT_STANDARD_LIBRARY_SWIFT_FLAGS}
4343
-parse-stdlib
44-
-Xfrontend -requirement-machine=off
4544
LINK_FLAGS "${SWIFT_RUNTIME_SWIFT_LINK_FLAGS}"
4645
INSTALL_IN_COMPONENT stdlib)

test/AutoDiff/IRGen/differentiable_function_type.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-frontend -emit-ir -g %s -requirement-machine=off
1+
// RUN: %target-swift-frontend -emit-ir -g %s
22

33
import _Differentiation
44

test/AutoDiff/IRGen/loadable_by_address.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
// RUN: %target-swift-frontend -c -Xllvm -sil-verify-after-pass=loadable-address %s -requirement-machine=off
2-
// RUN: %target-swift-frontend -emit-sil %s -requirement-machine=off | %FileCheck %s -check-prefix=CHECK-SIL
3-
// RUN: %target-swift-frontend -c -Xllvm -sil-print-after=loadable-address %s -requirement-machine=off 2>&1 | %FileCheck %s -check-prefix=CHECK-LBA-SIL
4-
// RUN: %target-run-simple-swift(-Xfrontend -requirement-machine=off)
1+
// RUN: %target-swift-frontend -c -Xllvm -sil-verify-after-pass=loadable-address %s
2+
// RUN: %target-swift-frontend -emit-sil %s | %FileCheck %s -check-prefix=CHECK-SIL
3+
// RUN: %target-swift-frontend -c -Xllvm -sil-print-after=loadable-address %s 2>&1 | %FileCheck %s -check-prefix=CHECK-LBA-SIL
4+
// RUN: %target-run-simple-swift
55
// REQUIRES: executable_test
66

77
// `isLargeLoadableType` depends on the ABI and differs between architectures.

test/AutoDiff/IRGen/loadable_by_address_cross_module.swift

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
11
// First, check that LBA actually modifies the function, so that this test is useful.
22

3-
// RUN: %target-swift-frontend -emit-sil %S/Inputs/loadable_by_address_cross_module.swift -requirement-machine=off | %FileCheck %s -check-prefix=CHECK-MODULE-PRE-LBA
4-
// RUN: %target-swift-frontend -c -Xllvm -sil-print-after=loadable-address %S/Inputs/loadable_by_address_cross_module.swift -requirement-machine=off 2>&1 | %FileCheck %s -check-prefix=CHECK-MODULE-POST-LBA
3+
// RUN: %target-swift-frontend -emit-sil %S/Inputs/loadable_by_address_cross_module.swift | %FileCheck %s -check-prefix=CHECK-MODULE-PRE-LBA
4+
// RUN: %target-swift-frontend -c -Xllvm -sil-print-after=loadable-address %S/Inputs/loadable_by_address_cross_module.swift 2>&1 | %FileCheck %s -check-prefix=CHECK-MODULE-POST-LBA
55

66
// CHECK-MODULE-PRE-LBA: sil {{.*}}LBAModifiedFunction{{.*}} $@convention(method) <T> (Float, LargeLoadableType<T>) -> Float
77
// CHECK-MODULE-POST-LBA: sil {{.*}}LBAModifiedFunction{{.*}} $@convention(method) <T> (Float, @in_constant LargeLoadableType<T>) -> Float
88

99
// Compile the module.
1010

1111
// RUN: %empty-directory(%t)
12-
// RUN: %target-build-swift-dylib(%t/%target-library-name(external)) %S/Inputs/loadable_by_address_cross_module.swift -emit-module -emit-module-path %t/external.swiftmodule -module-name external -Xfrontend -requirement-machine=off
12+
// RUN: %target-build-swift-dylib(%t/%target-library-name(external)) %S/Inputs/loadable_by_address_cross_module.swift -emit-module -emit-module-path %t/external.swiftmodule -module-name external
1313

1414
// Next, check that differentiability_witness_functions in the client get
1515
// correctly modified by LBA.
1616

17-
// RUN: %target-swift-frontend -emit-sil -I%t %s -requirement-machine=off
18-
// RUN: %target-swift-frontend -emit-sil -I%t %s -requirement-machine=off | %FileCheck %s -check-prefix=CHECK-CLIENT-PRE-LBA
19-
// RUN: %target-swift-frontend -c -I%t %s -Xllvm -sil-print-after=loadable-address -requirement-machine=off 2>&1 | %FileCheck %s -check-prefix=CHECK-CLIENT-POST-LBA
17+
// RUN: %target-swift-frontend -emit-sil -I%t %s
18+
// RUN: %target-swift-frontend -emit-sil -I%t %s | %FileCheck %s -check-prefix=CHECK-CLIENT-PRE-LBA
19+
// RUN: %target-swift-frontend -c -I%t %s -Xllvm -sil-print-after=loadable-address 2>&1 | %FileCheck %s -check-prefix=CHECK-CLIENT-POST-LBA
2020

2121
// CHECK-CLIENT-PRE-LBA: differentiability_witness_function [jvp] [reverse] [parameters 0 1] [results 0] <T> @${{.*}}LBAModifiedFunction{{.*}} : $@convention(method) <τ_0_0> (Float, LargeLoadableType<τ_0_0>) -> Float
2222
// CHECK-CLIENT-PRE-LBA: differentiability_witness_function [vjp] [reverse] [parameters 0 1] [results 0] <T> @${{.*}}LBAModifiedFunction{{.*}} : $@convention(method) <τ_0_0> (Float, LargeLoadableType<τ_0_0>) -> Float
@@ -26,7 +26,7 @@
2626

2727
// Finally, execute the test.
2828

29-
// RUN: %target-build-swift -I%t -L%t %s -o %t/a.out %target-rpath(%t) -L%t -lexternal -Xfrontend -requirement-machine=off
29+
// RUN: %target-build-swift -I%t -L%t %s -o %t/a.out %target-rpath(%t) -L%t -lexternal
3030
// RUN: %target-codesign %t/a.out
3131
// RUN: %target-codesign %t/%target-library-name(external)
3232
// RUN: %target-run %t/a.out %t/%target-library-name(external)

test/AutoDiff/SIL/Parse/sildeclref.sil

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-sil-opt %s -module-name=sildeclref_parse -requirement-machine off | %target-sil-opt -module-name=sildeclref_parse -requirement-machine off | %FileCheck %s
1+
// RUN: %target-sil-opt %s -module-name=sildeclref_parse -requirement-machine=off | %target-sil-opt -module-name=sildeclref_parse -requirement-machine=off | %FileCheck %s
22
// Parse AutoDiff derivative SILDeclRefs via `witness_method` and `class_method` instructions.
33

44
import Swift

test/AutoDiff/SILGen/autodiff_builtins.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-frontend -parse-stdlib -emit-silgen %s -requirement-machine=off | %FileCheck %s
1+
// RUN: %target-swift-frontend -parse-stdlib -emit-silgen %s | %FileCheck %s
22

33
import _Differentiation
44
import Swift

test/AutoDiff/SILGen/reabstraction.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-frontend -emit-silgen %s -requirement-machine=off | %FileCheck %s
1+
// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s
22

33
import _Differentiation
44

test/AutoDiff/SILOptimizer/generics.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-emit-sil -verify %s -requirement-machine=off | %FileCheck %s -check-prefix=CHECK-SIL
1+
// RUN: %target-swift-emit-sil -verify %s | %FileCheck %s -check-prefix=CHECK-SIL
22

33
import _Differentiation
44

0 commit comments

Comments
 (0)