@@ -195,6 +195,7 @@ struct AutoDiffConfig {
195
195
IndexSubset *resultIndices;
196
196
GenericSignature derivativeGenericSignature;
197
197
198
+ /* implicit*/ AutoDiffConfig() = default ;
198
199
/* implicit*/ AutoDiffConfig(
199
200
IndexSubset *parameterIndices, IndexSubset *resultIndices,
200
201
GenericSignature derivativeGenericSignature = GenericSignature())
@@ -545,10 +546,20 @@ struct TangentPropertyInfo {
545
546
546
547
void simple_display (llvm::raw_ostream &OS, TangentPropertyInfo info);
547
548
548
- // / The key type used for uniquing `SILDifferentiabilityWitness` in
549
- // / `SILModule`: original function name, parameter indices, result indices, and
550
- // / derivative generic signature.
551
- using SILDifferentiabilityWitnessKey = std::pair<StringRef, AutoDiffConfig>;
549
+ // / The key type used for uniquing `SILDifferentiabilityWitness` in `SILModule`.
550
+ struct SILDifferentiabilityWitnessKey {
551
+ StringRef originalFunctionName;
552
+ DifferentiabilityKind kind;
553
+ AutoDiffConfig config;
554
+
555
+ void print (llvm::raw_ostream &s = llvm::outs()) const ;
556
+ };
557
+
558
+ inline llvm::raw_ostream &operator <<(
559
+ llvm::raw_ostream &s, const SILDifferentiabilityWitnessKey &key) {
560
+ key.print (s);
561
+ return s;
562
+ }
552
563
553
564
// / Returns `true` iff differentiable programming is enabled.
554
565
bool isDifferentiableProgrammingEnabled (SourceFile &SF);
@@ -676,6 +687,9 @@ getAutoDiffFunctionKind(AutoDiffDerivativeFunctionKind kind);
676
687
677
688
AutoDiffFunctionKind getAutoDiffFunctionKind (AutoDiffLinearMapKind kind);
678
689
690
+ MangledDifferentiabilityKind
691
+ getMangledDifferentiabilityKind (DifferentiabilityKind kind);
692
+
679
693
} // end namespace autodiff
680
694
} // end namespace swift
681
695
@@ -688,6 +702,8 @@ using swift::GenericSignature;
688
702
using swift::IndexSubset;
689
703
using swift::SILAutoDiffDerivativeFunctionKey;
690
704
using swift::SILFunctionType;
705
+ using swift::DifferentiabilityKind;
706
+ using swift::SILDifferentiabilityWitnessKey;
691
707
692
708
template <typename T> struct DenseMapInfo ;
693
709
@@ -760,8 +776,8 @@ template <> struct DenseMapInfo<AutoDiffDerivativeFunctionKind> {
760
776
};
761
777
762
778
template <> struct DenseMapInfo <SILAutoDiffDerivativeFunctionKey> {
763
- static bool isEqual (const SILAutoDiffDerivativeFunctionKey lhs,
764
- const SILAutoDiffDerivativeFunctionKey rhs) {
779
+ static bool isEqual (const SILAutoDiffDerivativeFunctionKey & lhs,
780
+ const SILAutoDiffDerivativeFunctionKey & rhs) {
765
781
return lhs.originalType == rhs.originalType &&
766
782
lhs.parameterIndices == rhs.parameterIndices &&
767
783
lhs.resultIndices == rhs.resultIndices &&
@@ -803,6 +819,36 @@ template <> struct DenseMapInfo<SILAutoDiffDerivativeFunctionKey> {
803
819
}
804
820
};
805
821
822
+ template <> struct DenseMapInfo <SILDifferentiabilityWitnessKey> {
823
+ static bool isEqual (const SILDifferentiabilityWitnessKey &lhs,
824
+ const SILDifferentiabilityWitnessKey &rhs) {
825
+ return DenseMapInfo<StringRef>::isEqual (
826
+ lhs.originalFunctionName , rhs.originalFunctionName ) &&
827
+ DenseMapInfo<unsigned >::isEqual (
828
+ (unsigned )lhs.kind , (unsigned )rhs.kind ) &&
829
+ DenseMapInfo<AutoDiffConfig>::isEqual (lhs.config , rhs.config );
830
+ }
831
+
832
+ static inline SILDifferentiabilityWitnessKey getEmptyKey () {
833
+ return {DenseMapInfo<StringRef>::getEmptyKey (),
834
+ (DifferentiabilityKind)DenseMapInfo<unsigned >::getEmptyKey (),
835
+ DenseMapInfo<AutoDiffConfig>::getEmptyKey ()};
836
+ }
837
+
838
+ static inline SILDifferentiabilityWitnessKey getTombstoneKey () {
839
+ return {DenseMapInfo<StringRef>::getTombstoneKey (),
840
+ (DifferentiabilityKind)DenseMapInfo<unsigned >::getTombstoneKey (),
841
+ DenseMapInfo<AutoDiffConfig>::getTombstoneKey ()};
842
+ }
843
+
844
+ static unsigned getHashValue (const SILDifferentiabilityWitnessKey &val) {
845
+ return hash_combine (
846
+ DenseMapInfo<StringRef>::getHashValue (val.originalFunctionName ),
847
+ DenseMapInfo<unsigned >::getHashValue ((unsigned )val.kind ),
848
+ DenseMapInfo<AutoDiffConfig>::getHashValue (val.config ));
849
+ }
850
+ };
851
+
806
852
} // end namespace llvm
807
853
808
854
#endif // SWIFT_AST_AUTODIFF_H
0 commit comments