Skip to content

Commit b4855ae

Browse files
committed
Adopt ValueDecl in autodiff diagnostics
1 parent 025728f commit b4855ae

File tree

5 files changed

+38
-53
lines changed

5 files changed

+38
-53
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3857,16 +3857,6 @@ ERROR(differentiable_attr_duplicate,none,
38573857
"duplicate '@differentiable' attribute with same parameters", ())
38583858
NOTE(differentiable_attr_duplicate_note,none,
38593859
"other attribute declared here", ())
3860-
ERROR(differentiable_attr_function_not_same_type_context,none,
3861-
"%0 is not defined in the current type context", (DeclNameRef))
3862-
ERROR(differentiable_attr_derivative_not_function,none,
3863-
"registered derivative %0 must be a 'func' declaration", (DeclNameRef))
3864-
ERROR(differentiable_attr_class_derivative_not_final,none,
3865-
"class member derivative must be final", ())
3866-
ERROR(differentiable_attr_invalid_access,none,
3867-
"derivative function %0 is required to either be public or "
3868-
"'@usableFromInline' because the original function %1 is public or "
3869-
"'@usableFromInline'", (DeclNameRef, DeclName))
38703860
ERROR(differentiable_attr_protocol_req_where_clause,none,
38713861
"'@differentiable' attribute on protocol requirement cannot specify "
38723862
"'where' clause", ())
@@ -3880,7 +3870,7 @@ ERROR(differentiable_attr_empty_where_clause,none,
38803870
"empty 'where' clause in '@differentiable' attribute", ())
38813871
ERROR(differentiable_attr_where_clause_for_nongeneric_original,none,
38823872
"'where' clause is valid only when original function is generic %0",
3883-
(DeclName))
3873+
(const ValueDecl *))
38843874
ERROR(differentiable_attr_layout_req_unsupported,none,
38853875
"'@differentiable' attribute does not yet support layout requirements",
38863876
())
@@ -3890,7 +3880,7 @@ NOTE(protocol_witness_missing_differentiable_attr_invalid_context,none,
38903880
"candidate is missing explicit '%0' attribute to satisfy requirement %1 "
38913881
"(in protocol %3); explicit attribute is necessary because candidate is "
38923882
"declared in a different type context or file than the conformance of %2 "
3893-
"to %3", (StringRef, DeclName, Type, Type))
3883+
"to %3", (StringRef, const ValueDecl *, Type, Type))
38943884

38953885
// @derivative
38963886
ERROR(derivative_attr_expected_result_tuple,none,
@@ -3908,11 +3898,12 @@ ERROR(derivative_attr_result_value_not_differentiable,none,
39083898
"'@derivative(of:)' attribute requires function to return a two-element "
39093899
"tuple; first element type %0 must conform to 'Differentiable'", (Type))
39103900
ERROR(derivative_attr_result_func_type_mismatch,none,
3911-
"function result's %0 type does not match %1", (Identifier, DeclName))
3901+
"function result's %0 type does not match %1",
3902+
(Identifier, const ValueDecl *))
39123903
NOTE(derivative_attr_result_func_type_mismatch_note,none,
39133904
"%0 does not have expected type %1", (Identifier, Type))
39143905
NOTE(derivative_attr_result_func_original_note,none,
3915-
"%0 defined here", (DeclName))
3906+
"%0 defined here", (const ValueDecl *))
39163907
ERROR(derivative_attr_not_in_same_file_as_original,none,
39173908
"derivative not in the same file as the original function", ())
39183909
ERROR(derivative_attr_original_stored_property_unsupported,none,
@@ -3934,7 +3925,7 @@ NOTE(derivative_attr_protocol_requirement_unsupported,none,
39343925
"cannot yet register derivative default implementation for protocol "
39353926
"requirements", ())
39363927
ERROR(derivative_attr_original_already_has_derivative,none,
3937-
"a derivative already exists for %0", (DeclName))
3928+
"a derivative already exists for %0", (const ValueDecl *))
39383929
NOTE(derivative_attr_duplicate_note,none,
39393930
"other attribute declared here", ())
39403931
ERROR(derivative_attr_access_level_mismatch,none,
@@ -3943,31 +3934,29 @@ ERROR(derivative_attr_access_level_mismatch,none,
39433934
"%select{private|fileprivate|internal|package|public|open}3, "
39443935
"but original function %0 is "
39453936
"%select{private|fileprivate|internal|package|public|open}1",
3946-
(/*original*/ DeclName, /*original*/ AccessLevel,
3947-
/*derivative*/ DeclName, /*derivative*/ AccessLevel))
3937+
(/*original*/ const ValueDecl *, /*original*/ AccessLevel,
3938+
/*derivative*/ const ValueDecl *, /*derivative*/ AccessLevel))
39483939
NOTE(derivative_attr_fix_access,none,
39493940
"mark the derivative function as "
39503941
"'%select{private|fileprivate|internal|package|@usableFromInline|@usableFromInline}0' "
39513942
"to match the original function", (AccessLevel))
39523943
ERROR(derivative_attr_static_method_mismatch_original,none,
39533944
"unexpected derivative function declaration; "
39543945
"%0 requires the derivative function %1 to be %select{an instance|a 'static'}2 method",
3955-
(/*original*/DeclName, /*derivative*/ DeclName,
3946+
(/*original*/const ValueDecl *, /*derivative*/ const ValueDecl *,
39563947
/*originalIsStatic*/bool))
39573948
NOTE(derivative_attr_static_method_mismatch_original_note,none,
39583949
"original function %0 is %select{an instance|a 'static'}1 method",
3959-
(/*original*/ DeclName, /*originalIsStatic*/bool))
3950+
(/*original*/ const ValueDecl *, /*originalIsStatic*/bool))
39603951
NOTE(derivative_attr_static_method_mismatch_fix,none,
39613952
"make derivative function %0 %select{an instance|a 'static'}1 method",
3962-
(/*derivative*/ DeclName, /*mustBeStatic*/bool))
3953+
(/*derivative*/ const ValueDecl *, /*mustBeStatic*/bool))
39633954

39643955
// @transpose
39653956
ERROR(transpose_attr_invalid_linearity_parameter_or_result,none,
39663957
"cannot transpose with respect to original %select{result|parameter}1 "
39673958
"'%0' that does not conform to 'Differentiable' and satisfy "
39683959
"'%0 == %0.TangentVector'", (StringRef, /*isParameter*/ bool))
3969-
ERROR(transpose_attr_overload_not_found,none,
3970-
"could not find function %0 with expected type %1", (DeclName, Type))
39713960
ERROR(transpose_attr_cannot_use_named_wrt_params,none,
39723961
"cannot use named 'wrt' parameters in '@transpose(of:)' attribute, found "
39733962
"%0", (Identifier))
@@ -3980,14 +3969,14 @@ NOTE(transpose_attr_wrt_self_self_type_mismatch_note,none,
39803969
ERROR(transpose_attr_static_method_mismatch_original,none,
39813970
"unexpected transpose function declaration; "
39823971
"%0 requires the transpose function %1 to be %select{an instance|a 'static'}2 method",
3983-
(/*original*/DeclName, /*transpose*/ DeclName,
3972+
(/*original*/const ValueDecl *, /*transpose*/ const ValueDecl *,
39843973
/*originalIsStatic*/bool))
39853974
NOTE(transpose_attr_static_method_mismatch_original_note,none,
39863975
"original function %0 is %select{an instance|a 'static'}1 method",
3987-
(/*original*/ DeclName, /*originalIsStatic*/bool))
3976+
(/*original*/ const ValueDecl *, /*originalIsStatic*/bool))
39883977
NOTE(transpose_attr_static_method_mismatch_fix,none,
39893978
"make transpose function %0 %select{an instance|a 'static'}1 method",
3990-
(/*transpose*/ DeclName, /*mustBeStatic*/bool))
3979+
(/*transpose*/ const ValueDecl *, /*mustBeStatic*/bool))
39913980

39923981
// Automatic differentiation attributes
39933982
ERROR(autodiff_attr_original_decl_ambiguous,none,
@@ -4019,7 +4008,8 @@ ERROR(autodiff_attr_opaque_result_type_unsupported,none,
40194008

40204009
// differentiation `wrt` parameters clause
40214010
ERROR(diff_function_no_parameters,none,
4022-
"%0 has no parameters to differentiate with respect to", (DeclName))
4011+
"%0 has no parameters to differentiate with respect to",
4012+
(const ValueDecl *))
40234013
ERROR(diff_params_clause_param_name_unknown,none,
40244014
"unknown parameter name %0", (Identifier))
40254015
ERROR(diff_params_clause_self_instance_method_only,none,

lib/Sema/TypeCheckAttr.cpp

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4855,8 +4855,7 @@ static IndexSubset *computeDifferentiabilityParameters(
48554855
// If function is not an instance method, diagnose immediately.
48564856
if (!isInstanceMethod) {
48574857
diags
4858-
.diagnose(attrLoc, diag::diff_function_no_parameters,
4859-
function->getName())
4858+
.diagnose(attrLoc, diag::diff_function_no_parameters, function)
48604859
.highlight(function->getSignatureSourceRange());
48614860
return nullptr;
48624861
}
@@ -4870,8 +4869,7 @@ static IndexSubset *computeDifferentiabilityParameters(
48704869
selfType = function->mapTypeIntoContext(selfType);
48714870
if (!conformsToDifferentiable(selfType, module)) {
48724871
diags
4873-
.diagnose(attrLoc, diag::diff_function_no_parameters,
4874-
function->getName())
4872+
.diagnose(attrLoc, diag::diff_function_no_parameters, function)
48754873
.highlight(function->getSignatureSourceRange());
48764874
return nullptr;
48774875
}
@@ -5447,7 +5445,7 @@ bool resolveDifferentiableAttrDerivativeGenericSignature(
54475445
.diagnose(
54485446
attr->getLocation(),
54495447
diag::differentiable_attr_where_clause_for_nongeneric_original,
5450-
original->getName())
5448+
original)
54515449
.highlight(whereClause->getSourceRange());
54525450
attr->setInvalid();
54535451
return true;
@@ -6061,16 +6059,15 @@ static bool typeCheckDerivativeAttr(DerivativeAttr *attr) {
60616059
diags
60626060
.diagnose(attr->getOriginalFunctionName().Loc.getBaseNameLoc(),
60636061
diag::derivative_attr_static_method_mismatch_original,
6064-
originalAFD->getName(), derivative->getName(),
6065-
derivativeMustBeStatic)
6062+
originalAFD, derivative, derivativeMustBeStatic)
60666063
.highlight(attr->getOriginalFunctionName().Loc.getSourceRange());
60676064
diags.diagnose(originalAFD->getNameLoc(),
60686065
diag::derivative_attr_static_method_mismatch_original_note,
6069-
originalAFD->getName(), derivativeMustBeStatic);
6066+
originalAFD, derivativeMustBeStatic);
60706067
auto fixItDiag =
60716068
diags.diagnose(derivative->getStartLoc(),
60726069
diag::derivative_attr_static_method_mismatch_fix,
6073-
derivative->getName(), derivativeMustBeStatic);
6070+
derivative, derivativeMustBeStatic);
60746071
if (derivativeMustBeStatic) {
60756072
fixItDiag.fixItInsert(derivative->getStartLoc(), "static ");
60766073
} else {
@@ -6098,8 +6095,8 @@ static bool typeCheckDerivativeAttr(DerivativeAttr *attr) {
60986095
derivative->getFormalAccessScope().accessLevelForDiagnostics();
60996096
diags.diagnose(originalName.Loc,
61006097
diag::derivative_attr_access_level_mismatch,
6101-
originalAFD->getName(), originalAccess,
6102-
derivative->getName(), derivativeAccess);
6098+
originalAFD, originalAccess,
6099+
derivative, derivativeAccess);
61036100
auto fixItDiag =
61046101
derivative->diagnose(diag::derivative_attr_fix_access, originalAccess);
61056102
// If original access is public, suggest adding `@usableFromInline` to
@@ -6207,7 +6204,7 @@ static bool typeCheckDerivativeAttr(DerivativeAttr *attr) {
62076204
// Emit differential/pullback type mismatch error on attribute.
62086205
diags.diagnose(attr->getLocation(),
62096206
diag::derivative_attr_result_func_type_mismatch,
6210-
funcResultElt.getName(), originalAFD->getName());
6207+
funcResultElt.getName(), originalAFD);
62116208
// Emit note with expected differential/pullback type on actual type
62126209
// location.
62136210
auto *tupleReturnTypeRepr =
@@ -6222,7 +6219,7 @@ static bool typeCheckDerivativeAttr(DerivativeAttr *attr) {
62226219
if (originalAFD->getLoc().isValid())
62236220
diags.diagnose(originalAFD->getLoc(),
62246221
diag::derivative_attr_result_func_original_note,
6225-
originalAFD->getName());
6222+
originalAFD);
62266223
return true;
62276224
}
62286225

@@ -6233,7 +6230,7 @@ static bool typeCheckDerivativeAttr(DerivativeAttr *attr) {
62336230
if (derivativeAttrs.size() > 1) {
62346231
diags.diagnose(attr->getLocation(),
62356232
diag::derivative_attr_original_already_has_derivative,
6236-
originalAFD->getName());
6233+
originalAFD);
62376234
for (auto *duplicateAttr : derivativeAttrs) {
62386235
if (duplicateAttr == attr)
62396236
continue;
@@ -6595,15 +6592,14 @@ void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) {
65956592
bool transposeMustBeStatic = !transpose->isStatic();
65966593
diagnose(attr->getOriginalFunctionName().Loc.getBaseNameLoc(),
65976594
diag::transpose_attr_static_method_mismatch_original,
6598-
originalAFD->getName(), transpose->getName(),
6599-
transposeMustBeStatic)
6595+
originalAFD, transpose, transposeMustBeStatic)
66006596
.highlight(attr->getOriginalFunctionName().Loc.getSourceRange());
66016597
diagnose(originalAFD->getNameLoc(),
66026598
diag::transpose_attr_static_method_mismatch_original_note,
6603-
originalAFD->getName(), transposeMustBeStatic);
6599+
originalAFD, transposeMustBeStatic);
66046600
auto fixItDiag = diagnose(transpose->getStartLoc(),
66056601
diag::transpose_attr_static_method_mismatch_fix,
6606-
transpose->getName(), transposeMustBeStatic);
6602+
transpose, transposeMustBeStatic);
66076603
if (transposeMustBeStatic) {
66086604
fixItDiag.fixItInsert(transpose->getStartLoc(), "static ");
66096605
} else {

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2827,9 +2827,8 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance,
28272827
diags
28282828
.diagnose(
28292829
witness,
2830-
diag::
2831-
protocol_witness_missing_differentiable_attr_invalid_context,
2832-
reqDiffAttrString, req->getName(), conformance->getType(),
2830+
diag::protocol_witness_missing_differentiable_attr_invalid_context,
2831+
reqDiffAttrString, req, conformance->getType(),
28332832
conformance->getProtocol()->getDeclaredInterfaceType())
28342833
.fixItInsert(match.Witness->getStartLoc(), reqDiffAttrString + ' ');
28352834
break;

test/AutoDiff/Sema/derivative_attr_type_checking.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,7 @@ extension Struct where T: Differentiable & AdditiveArithmetic {
504504
return (1, { _ in .zero })
505505
}
506506

507-
// expected-error @+2 {{a derivative already exists for '_'}}
507+
// expected-error @+2 {{a derivative already exists for getter for 'subscript()'}}
508508
// expected-note @-6 {{other attribute declared here}}
509509
@derivative(of: subscript)
510510
func vjpSubscript() -> (value: Float, pullback: (Float) -> TangentVector) {
@@ -521,7 +521,7 @@ extension Struct where T: Differentiable & AdditiveArithmetic {
521521
return (1, { _ in .zero })
522522
}
523523

524-
// expected-error @+2 {{a derivative already exists for '_'}}
524+
// expected-error @+2 {{a derivative already exists for getter for 'subscript(float:)'}}
525525
// expected-note @-6 {{other attribute declared here}}
526526
@derivative(of: subscript(float:), wrt: self)
527527
func vjpSubscriptLabeled(float: Float) -> (value: Float, pullback: (Float) -> TangentVector) {
@@ -538,7 +538,7 @@ extension Struct where T: Differentiable & AdditiveArithmetic {
538538
return (x, { _ in .zero })
539539
}
540540

541-
// expected-error @+2 {{a derivative already exists for '_'}}
541+
// expected-error @+2 {{a derivative already exists for getter for 'subscript(_:)'}}
542542
// expected-note @-6 {{other attribute declared here}}
543543
@derivative(of: subscript(_:), wrt: self)
544544
func vjpSubscriptGeneric<U: Differentiable>(x: U) -> (value: U, pullback: (U.TangentVector) -> TangentVector) {
@@ -619,7 +619,7 @@ extension Class where T: Differentiable {
619619
return (1, { _ in .zero })
620620
}
621621

622-
// expected-error @+2 {{a derivative already exists for '_'}}
622+
// expected-error @+2 {{a derivative already exists for getter for 'subscript()'}}
623623
// expected-note @-6 {{other attribute declared here}}
624624
@derivative(of: subscript)
625625
func vjpSubscript() -> (value: Float, pullback: (Float) -> TangentVector) {

test/AutoDiff/Sema/differentiable_attr_type_checking.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@ let globalConst: Float = 1
1818
var globalVar: Float = 1
1919

2020
func testLocalVariables() {
21-
// expected-error @+1 {{'_' has no parameters to differentiate with respect to}}
21+
// expected-error @+1 {{getter for 'getter' has no parameters to differentiate with respect to}}
2222
@differentiable(reverse)
2323
var getter: Float {
2424
return 1
2525
}
2626

27-
// expected-error @+1 {{'_' has no parameters to differentiate with respect to}}
27+
// expected-error @+1 {{getter for 'getterSetter' has no parameters to differentiate with respect to}}
2828
@differentiable(reverse)
2929
var getterSetter: Float {
3030
get { return 1 }

0 commit comments

Comments
 (0)