Skip to content

Commit 83b50cd

Browse files
authored
[AutoDiff] Add missing withoutDerivative(at:) fix-its. (swiftlang#33660)
Add `withoutDerivative(at:)` fix-its for errors regarding non-differentiable arguments and results.
1 parent 9993875 commit 83b50cd

File tree

4 files changed

+37
-16
lines changed

4 files changed

+37
-16
lines changed

lib/SILOptimizer/Differentiation/JVPCloner.cpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -553,9 +553,14 @@ class JVPCloner::Implementation final
553553
if (!originalFnTy->getParameters()[paramIndex]
554554
.getSILStorageInterfaceType()
555555
.isDifferentiable(getModule())) {
556-
context.emitNondifferentiabilityError(
557-
ai->getArgumentsWithoutIndirectResults()[paramIndex],
558-
invoker, diag::autodiff_nondifferentiable_argument);
556+
auto arg = ai->getArgumentsWithoutIndirectResults()[paramIndex];
557+
auto startLoc = arg.getLoc().getStartSourceLoc();
558+
auto endLoc = arg.getLoc().getEndSourceLoc();
559+
context
560+
.emitNondifferentiabilityError(
561+
arg, invoker, diag::autodiff_nondifferentiable_argument)
562+
.fixItInsert(startLoc, "withoutDerivative(at: ")
563+
.fixItInsertAfter(endLoc, ")");
559564
errorOccurred = true;
560565
return true;
561566
}
@@ -573,9 +578,14 @@ class JVPCloner::Implementation final
573578
.getSILStorageInterfaceType();
574579
}
575580
if (!remappedResultType.isDifferentiable(getModule())) {
576-
context.emitNondifferentiabilityError(
577-
origCallee, invoker,
578-
diag::autodiff_nondifferentiable_result);
581+
auto startLoc = ai->getLoc().getStartSourceLoc();
582+
auto endLoc = ai->getLoc().getEndSourceLoc();
583+
context
584+
.emitNondifferentiabilityError(
585+
origCallee, invoker,
586+
diag::autodiff_nondifferentiable_result)
587+
.fixItInsert(startLoc, "withoutDerivative(at: ")
588+
.fixItInsertAfter(endLoc, ")");
579589
errorOccurred = true;
580590
return true;
581591
}

lib/SILOptimizer/Differentiation/VJPCloner.cpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -457,9 +457,14 @@ class VJPCloner::Implementation final
457457
if (!originalFnTy->getParameters()[paramIndex]
458458
.getSILStorageInterfaceType()
459459
.isDifferentiable(getModule())) {
460-
context.emitNondifferentiabilityError(
461-
ai->getArgumentsWithoutIndirectResults()[paramIndex], invoker,
462-
diag::autodiff_nondifferentiable_argument);
460+
auto arg = ai->getArgumentsWithoutIndirectResults()[paramIndex];
461+
auto startLoc = arg.getLoc().getStartSourceLoc();
462+
auto endLoc = arg.getLoc().getEndSourceLoc();
463+
context
464+
.emitNondifferentiabilityError(
465+
arg, invoker, diag::autodiff_nondifferentiable_argument)
466+
.fixItInsert(startLoc, "withoutDerivative(at: ")
467+
.fixItInsertAfter(endLoc, ")");
463468
errorOccurred = true;
464469
return true;
465470
}
@@ -477,8 +482,14 @@ class VJPCloner::Implementation final
477482
.getSILStorageInterfaceType();
478483
}
479484
if (!remappedResultType.isDifferentiable(getModule())) {
480-
context.emitNondifferentiabilityError(
481-
origCallee, invoker, diag::autodiff_nondifferentiable_result);
485+
auto startLoc = ai->getLoc().getStartSourceLoc();
486+
auto endLoc = ai->getLoc().getEndSourceLoc();
487+
context
488+
.emitNondifferentiabilityError(
489+
origCallee, invoker,
490+
diag::autodiff_nondifferentiable_result)
491+
.fixItInsert(startLoc, "withoutDerivative(at: ")
492+
.fixItInsertAfter(endLoc, ")");
482493
errorOccurred = true;
483494
return true;
484495
}

test/AutoDiff/SILOptimizer/differentiation_control_flow_diagnostics.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ func loop_array(_ array: [Float]) -> Float {
171171
var result: Float = 1
172172
// TODO(TF-957): Improve non-differentiability errors for for-in loops
173173
// (`Collection.makeIterator` and `IteratorProtocol.next`).
174-
// expected-note @+1 {{cannot differentiate through a non-differentiable result; do you want to use 'withoutDerivative(at:)'?}}
174+
// expected-note @+1 {{cannot differentiate through a non-differentiable result; do you want to use 'withoutDerivative(at:)'?}} {{12-12=withoutDerivative(at: }} {{17-17=)}}
175175
for x in array {
176176
result = result * x
177177
}

test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -301,14 +301,14 @@ struct TF_687<T> : Differentiable {
301301
}
302302
}
303303
// expected-error @+2 {{function is not differentiable}}
304-
// expected-note @+1 {{cannot differentiate through a non-differentiable argument; do you want to use 'withoutDerivative(at:)'?}}
304+
// expected-note @+1 {{cannot differentiate through a non-differentiable argument; do you want to use 'withoutDerivative(at:)'?}} {{78-78=withoutDerivative(at: }} {{79-79=)}}
305305
let _: @differentiable (Float) -> TF_687<Any> = { x in TF_687<Any>(x, dummy: x) }
306306

307307
// expected-error @+1 {{function is not differentiable}}
308308
@differentiable
309309
// expected-note @+1 {{when differentiating this function definition}}
310310
func roundingGivesError(x: Float) -> Float {
311-
// expected-note @+1 {{cannot differentiate through a non-differentiable result; do you want to use 'withoutDerivative(at:)'?}}
311+
// expected-note @+1 {{cannot differentiate through a non-differentiable result; do you want to use 'withoutDerivative(at:)'?}} {{16-16=withoutDerivative(at: }} {{22-22=)}}
312312
return Float(Int(x))
313313
}
314314

@@ -688,7 +688,7 @@ func differentiableProjectedValueAccess(_ s: Struct) -> Float {
688688
// expected-note @+2 {{when differentiating this function definition}}
689689
@differentiable
690690
func projectedValueAccess(_ s: Struct) -> Float {
691-
// expected-note @+1 {{cannot differentiate through a non-differentiable result; do you want to use 'withoutDerivative(at:)'?}}
691+
// expected-note @+1 {{cannot differentiate through a non-differentiable result; do you want to use 'withoutDerivative(at:)'?}} {{3-3=withoutDerivative(at: }} {{7-7=)}}
692692
s.$y.wrappedValue
693693
}
694694

@@ -714,7 +714,7 @@ func modify(_ s: Struct, _ x: Float) -> Float {
714714
func tupleArrayLiteralInitialization(_ x: Float, _ y: Float) -> Float {
715715
// `Array<(Float, Float)>` does not conform to `Differentiable`.
716716
let array = [(x * y, x * y)]
717-
// expected-note @+1 {{cannot differentiate through a non-differentiable argument; do you want to use 'withoutDerivative(at:)'?}}
717+
// expected-note @+1 {{cannot differentiate through a non-differentiable argument; do you want to use 'withoutDerivative(at:)'?}} {{10-10=withoutDerivative(at: }} {{15-15=)}}
718718
return array[0].0
719719
}
720720

0 commit comments

Comments
 (0)