Skip to content

Commit 0c1d4b5

Browse files
committed
[AutoDiff] Enable cross-file derivative registration.
Lift temporary cross-file derivative registration restriction. `@derivative` attribute type-checking simplications coming soon: TF-1099. Original function and derivative function must have same access level, with one exception: public original functions may have internal `@usableFromInline` derivatives.
1 parent bb0aa1c commit 0c1d4b5

File tree

2 files changed

+1
-11
lines changed

2 files changed

+1
-11
lines changed

lib/Sema/TypeCheckAttr.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4423,14 +4423,6 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
44234423
return true;
44244424
}
44254425

4426-
// Reject different-file derivative registration.
4427-
// TODO(TF-1021): Lift same-file derivative registration restriction.
4428-
if (originalAFD->getParentSourceFile() != derivative->getParentSourceFile()) {
4429-
diags.diagnose(attr->getLocation(),
4430-
diag::derivative_attr_not_in_same_file_as_original);
4431-
return true;
4432-
}
4433-
44344426
// Reject duplicate `@derivative` attributes.
44354427
auto &derivativeAttrs = Ctx.DerivativeAttrs[std::make_tuple(
44364428
originalAFD, resolvedDiffParamIndices, kind)];

test/AutoDiff/Sema/derivative_attr_type_checking.swift

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -743,11 +743,9 @@ extension InoutParameters {
743743
}
744744
}
745745

746-
// Test cross-file derivative registration. Currently unsupported.
747-
// TODO(TF-1021): Lift this restriction.
746+
// Test cross-file derivative registration.
748747

749748
extension FloatingPoint where Self: Differentiable {
750-
// expected-error @+1 {{derivative not in the same file as the original function}}
751749
@derivative(of: rounded)
752750
func vjpRounded() -> (
753751
value: Self,

0 commit comments

Comments
 (0)