Skip to content

Commit f6b191b

Browse files
committed
AutoDiff: Use getLocalProtocols() instead of getInheritedProtocols()
1 parent c530baa commit f6b191b

File tree

1 file changed

+6
-49
lines changed

1 file changed

+6
-49
lines changed

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 6 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -337,51 +337,6 @@ static ValueDecl *deriveDifferentiable_move(DerivedConformance &derived) {
337337
C.TheEmptyTupleType, {deriveBodyDifferentiable_move, nullptr});
338338
}
339339

340-
/// Pushes all the protocols inherited, directly or transitively, by `decl` to `protos`.
341-
///
342-
/// Precondition: `decl` is a nominal type decl or an extension decl.
343-
void getInheritedProtocols(Decl *decl, SmallPtrSetImpl<ProtocolDecl *> &protos) {
344-
ArrayRef<TypeLoc> inheritedTypeLocs;
345-
if (auto *nominalDecl = dyn_cast<NominalTypeDecl>(decl))
346-
inheritedTypeLocs = nominalDecl->getInherited();
347-
else if (auto *extDecl = dyn_cast<ExtensionDecl>(decl))
348-
inheritedTypeLocs = extDecl->getInherited();
349-
else
350-
llvm_unreachable("conformance is not a nominal or an extension");
351-
352-
std::function<void(Type)> handleInheritedType;
353-
354-
auto handleProto = [&](ProtocolType *proto) -> void {
355-
proto->getDecl()->walkInheritedProtocols([&](ProtocolDecl *p) -> TypeWalker::Action {
356-
protos.insert(p);
357-
return TypeWalker::Action::Continue;
358-
});
359-
};
360-
361-
auto handleProtoComp = [&](ProtocolCompositionType *comp) -> void {
362-
for (auto ty : comp->getMembers())
363-
handleInheritedType(ty);
364-
};
365-
366-
handleInheritedType = [&](Type ty) -> void {
367-
if (auto *proto = ty->getAs<ProtocolType>())
368-
handleProto(proto);
369-
else if (auto *comp = ty->getAs<ProtocolCompositionType>())
370-
handleProtoComp(comp);
371-
};
372-
373-
for (auto loc : inheritedTypeLocs) {
374-
if (loc.getTypeRepr())
375-
handleInheritedType(
376-
TypeResolution::forStructural(cast<DeclContext>(decl), None,
377-
/*unboundTyOpener*/ nullptr,
378-
/*placeholderHandler*/ nullptr)
379-
.resolveType(loc.getTypeRepr()));
380-
else
381-
handleInheritedType(loc.getType());
382-
}
383-
}
384-
385340
/// Return associated `TangentVector` struct for a nominal type, if it exists.
386341
/// If not, synthesize the struct.
387342
static StructDecl *
@@ -409,12 +364,14 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
409364
// Note that, for example, this will always find `AdditiveArithmetic` and `Differentiable` because
410365
// the `Differentiable` protocol itself requires that its `TangentVector` conforms to
411366
// `AdditiveArithmetic` and `Differentiable`.
412-
llvm::SmallPtrSet<ProtocolDecl *, 4> tvDesiredProtos;
413-
llvm::SmallPtrSet<ProtocolDecl *, 4> conformanceInheritedProtos;
414-
getInheritedProtocols(derived.ConformanceDecl, conformanceInheritedProtos);
367+
llvm::SmallSetVector<ProtocolDecl *, 4> tvDesiredProtos;
368+
415369
auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
416370
auto *tvAssocType = diffableProto->getAssociatedType(C.Id_TangentVector);
417-
for (auto proto : conformanceInheritedProtos) {
371+
372+
auto localProtos = cast<IterableDeclContext>(derived.ConformanceDecl)
373+
->getLocalProtocols();
374+
for (auto proto : localProtos) {
418375
for (auto req : proto->getRequirementSignature()) {
419376
if (req.getKind() != RequirementKind::Conformance)
420377
continue;

0 commit comments

Comments
 (0)