Skip to content

Commit ab7a95c

Browse files
committed
[Sema] Delay function representation check until placeholders are bound
1 parent 8529126 commit ab7a95c

11 files changed

+201
-107
lines changed

lib/Sema/CSApply.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6644,6 +6644,19 @@ Expr *ExprRewriter::coerceToType(Expr *expr, Type toType,
66446644
Optional<Pattern*> typeFromPattern) {
66456645
auto &ctx = cs.getASTContext();
66466646

6647+
// Diagnose conversions to invalid function types that couldn't be performed
6648+
// beforehand because of placeholders.
6649+
if (auto *fnTy = toType->getAs<FunctionType>()) {
6650+
auto contextTy = cs.getContextualType(expr);
6651+
if (cs.getConstraintLocator(locator)->isForContextualType() && contextTy &&
6652+
contextTy->hasPlaceholder()) {
6653+
bool hadError = TypeChecker::diagnoseInvalidFunctionType(
6654+
fnTy, expr->getLoc(), None, dc, None);
6655+
if (hadError)
6656+
return nullptr;
6657+
}
6658+
}
6659+
66476660
// The type we're converting from.
66486661
Type fromType = cs.getType(expr);
66496662

lib/Sema/CSGen.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2079,7 +2079,9 @@ namespace {
20792079
extInfo = extInfo.withGlobalActor(getExplicitGlobalActor(closure));
20802080
}
20812081

2082-
return FunctionType::get(closureParams, resultTy, extInfo);
2082+
auto *fnTy = FunctionType::get(closureParams, resultTy, extInfo);
2083+
return CS.replaceInferableTypesWithTypeVars(
2084+
fnTy, CS.getConstraintLocator(closure))->castTo<FunctionType>();
20832085
}
20842086

20852087
/// Produces a type for the given pattern, filling in any missing

lib/Sema/TypeCheckConstraints.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,7 @@ bool TypeChecker::typeCheckBinding(
446446
// Assign error types to the pattern and its variables, to prevent it from
447447
// being referenced by the constraint system.
448448
if (patternType->hasUnresolvedType() ||
449+
patternType->hasPlaceholder() ||
449450
patternType->hasUnboundGenericType()) {
450451
pattern->setType(ErrorType::get(Context));
451452
}

lib/Sema/TypeCheckStorage.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ PatternBindingEntryRequest::evaluate(Evaluator &eval,
265265
// If the pattern contains some form of unresolved type, we'll need to
266266
// check the initializer.
267267
if (patternType->hasUnresolvedType() ||
268+
patternType->hasPlaceholder() ||
268269
patternType->hasUnboundGenericType()) {
269270
if (TypeChecker::typeCheckPatternBinding(binding, entryNumber,
270271
patternType)) {

lib/Sema/TypeCheckType.cpp

Lines changed: 4 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1896,11 +1896,6 @@ namespace {
18961896
NeverNullType resolveOpaqueReturnType(TypeRepr *repr, StringRef mangledName,
18971897
unsigned ordinal,
18981898
TypeResolutionOptions options);
1899-
1900-
/// Returns true if the given type conforms to `Differentiable` in the
1901-
/// module of `DC`. If `tangentVectorEqualsSelf` is true, returns true iff
1902-
/// the given type additionally satisfies `Self == Self.TangentVector`.
1903-
bool isDifferentiable(Type type, bool tangentVectorEqualsSelf = false);
19041899
};
19051900
} // end anonymous namespace
19061901

@@ -2774,50 +2769,6 @@ TypeResolver::resolveASTFunctionTypeParams(TupleTypeRepr *inputRepr,
27742769
inputRepr->getElementName(i));
27752770
}
27762771

2777-
// All non-`@noDerivative` parameters of `@differentiable` function types must
2778-
// be differentiable.
2779-
if (diffKind != DifferentiabilityKind::NonDifferentiable &&
2780-
resolution.getStage() != TypeResolutionStage::Structural) {
2781-
bool isLinear = diffKind == DifferentiabilityKind::Linear;
2782-
// Emit `@noDerivative` fixit only if there is at least one valid
2783-
// differentiability parameter. Otherwise, adding `@noDerivative` produces
2784-
// an ill-formed function type.
2785-
auto hasValidDifferentiabilityParam =
2786-
llvm::find_if(elements, [&](AnyFunctionType::Param param) {
2787-
if (param.isNoDerivative())
2788-
return false;
2789-
return isDifferentiable(param.getPlainType(),
2790-
/*tangentVectorEqualsSelf*/ isLinear);
2791-
}) != elements.end();
2792-
bool alreadyDiagnosedOneParam = false;
2793-
for (unsigned i = 0, end = inputRepr->getNumElements(); i != end; ++i) {
2794-
auto *eltTypeRepr = inputRepr->getElementType(i);
2795-
auto param = elements[i];
2796-
if (param.isNoDerivative())
2797-
continue;
2798-
auto paramType = param.getPlainType();
2799-
if (isDifferentiable(paramType, isLinear))
2800-
continue;
2801-
auto paramTypeString = paramType->getString();
2802-
auto diagnostic =
2803-
diagnose(eltTypeRepr->getLoc(),
2804-
diag::differentiable_function_type_invalid_parameter,
2805-
paramTypeString, isLinear, hasValidDifferentiabilityParam);
2806-
alreadyDiagnosedOneParam = true;
2807-
if (hasValidDifferentiabilityParam)
2808-
diagnostic.fixItInsert(eltTypeRepr->getLoc(), "@noDerivative ");
2809-
}
2810-
// Reject the case where all parameters have '@noDerivative'.
2811-
if (!alreadyDiagnosedOneParam && !hasValidDifferentiabilityParam) {
2812-
diagnose(
2813-
inputRepr->getLoc(),
2814-
diag::
2815-
differentiable_function_type_no_differentiability_parameters,
2816-
isLinear)
2817-
.highlight(inputRepr->getSourceRange());
2818-
}
2819-
}
2820-
28212772
return elements;
28222773
}
28232774

@@ -2946,59 +2897,14 @@ NeverNullType TypeResolver::resolveASTFunctionType(
29462897
if (fnTy->hasError())
29472898
return fnTy;
29482899

2949-
// If the type is a block or C function pointer, it must be representable in
2950-
// ObjC.
2951-
switch (representation) {
2952-
case AnyFunctionType::Representation::Block:
2953-
case AnyFunctionType::Representation::CFunctionPointer:
2954-
if (!fnTy->isRepresentableIn(ForeignLanguage::ObjectiveC,
2955-
getDeclContext())) {
2956-
StringRef strName =
2957-
(representation == AnyFunctionType::Representation::Block)
2958-
? "block"
2959-
: "c";
2960-
auto extInfo2 =
2961-
extInfo.withRepresentation(AnyFunctionType::Representation::Swift);
2962-
auto simpleFnTy = FunctionType::get(params, outputTy, extInfo2);
2963-
diagnose(repr->getStartLoc(), diag::objc_convention_invalid,
2964-
simpleFnTy, strName);
2965-
}
2966-
break;
2967-
2968-
case AnyFunctionType::Representation::Thin:
2969-
case AnyFunctionType::Representation::Swift:
2970-
break;
2971-
}
2972-
2973-
// `@differentiable` function types must return a differentiable type.
2974-
if (extInfo.isDifferentiable() &&
2975-
resolution.getStage() != TypeResolutionStage::Structural) {
2976-
bool isLinear = diffKind == DifferentiabilityKind::Linear;
2977-
if (!isDifferentiable(outputTy, /*tangentVectorEqualsSelf*/ isLinear)) {
2978-
diagnose(repr->getResultTypeRepr()->getLoc(),
2979-
diag::differentiable_function_type_invalid_result,
2980-
outputTy->getString(), isLinear)
2981-
.highlight(repr->getResultTypeRepr()->getSourceRange());
2982-
}
2983-
}
2900+
if (TypeChecker::diagnoseInvalidFunctionType(fnTy, repr->getLoc(), repr,
2901+
getDeclContext(),
2902+
resolution.getStage()))
2903+
return ErrorType::get(fnTy);
29842904

29852905
return fnTy;
29862906
}
29872907

2988-
bool TypeResolver::isDifferentiable(Type type, bool tangentVectorEqualsSelf) {
2989-
if (resolution.getStage() != TypeResolutionStage::Contextual)
2990-
type = getDeclContext()->mapTypeIntoContext(type);
2991-
auto tanSpace = type->getAutoDiffTangentSpace(
2992-
LookUpConformanceInModule(getDeclContext()->getParentModule()));
2993-
if (!tanSpace)
2994-
return false;
2995-
// If no `Self == Self.TangentVector` requirement, return true.
2996-
if (!tangentVectorEqualsSelf)
2997-
return true;
2998-
// Otherwise, return true if `Self == Self.TangentVector`.
2999-
return type->getCanonicalType() == tanSpace->getCanonicalType();
3000-
}
3001-
30022908
NeverNullType TypeResolver::resolveSILBoxType(SILBoxTypeRepr *repr,
30032909
TypeResolutionOptions options) {
30042910
// Resolve the field types.

lib/Sema/TypeChecker.cpp

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,3 +490,136 @@ TypeChecker::getDeclTypeCheckingSemantics(ValueDecl *decl) {
490490
}
491491
return DeclTypeCheckingSemantics::Normal;
492492
}
493+
494+
bool TypeChecker::isDifferentiable(Type type, bool tangentVectorEqualsSelf,
495+
DeclContext *dc,
496+
Optional<TypeResolutionStage> stage) {
497+
if (stage && stage != TypeResolutionStage::Contextual)
498+
type = dc->mapTypeIntoContext(type);
499+
auto tanSpace = type->getAutoDiffTangentSpace(
500+
LookUpConformanceInModule(dc->getParentModule()));
501+
if (!tanSpace)
502+
return false;
503+
// If no `Self == Self.TangentVector` requirement, return true.
504+
if (!tangentVectorEqualsSelf)
505+
return true;
506+
// Otherwise, return true if `Self == Self.TangentVector`.
507+
return type->getCanonicalType() == tanSpace->getCanonicalType();
508+
}
509+
510+
bool TypeChecker::diagnoseInvalidFunctionType(FunctionType *fnTy, SourceLoc loc,
511+
Optional<FunctionTypeRepr *>repr,
512+
DeclContext *dc,
513+
Optional<TypeResolutionStage> stage) {
514+
// If the type has a placeholder, don't try to diagnose anything now since
515+
// we'll produce a better diagnostic when (if) the expression successfully
516+
// typechecks.
517+
if (fnTy->hasPlaceholder())
518+
return false;
519+
520+
// If the type is a block or C function pointer, it must be representable in
521+
// ObjC.
522+
auto representation = fnTy->getRepresentation();
523+
auto extInfo = fnTy->getExtInfo();
524+
auto &ctx = dc->getASTContext();
525+
526+
bool hadAnyError = false;
527+
528+
switch (representation) {
529+
case AnyFunctionType::Representation::Block:
530+
case AnyFunctionType::Representation::CFunctionPointer:
531+
if (!fnTy->isRepresentableIn(ForeignLanguage::ObjectiveC, dc)) {
532+
StringRef strName =
533+
(representation == AnyFunctionType::Representation::Block)
534+
? "block"
535+
: "c";
536+
auto extInfo2 =
537+
extInfo.withRepresentation(AnyFunctionType::Representation::Swift);
538+
auto simpleFnTy = FunctionType::get(fnTy->getParams(), fnTy->getResult(),
539+
extInfo2);
540+
ctx.Diags.diagnose(loc, diag::objc_convention_invalid,
541+
simpleFnTy, strName);
542+
hadAnyError = true;
543+
}
544+
break;
545+
546+
case AnyFunctionType::Representation::Thin:
547+
case AnyFunctionType::Representation::Swift:
548+
break;
549+
}
550+
551+
// `@differentiable` function types must return a differentiable type and have
552+
// differentiable (or `@noDerivative`) parameters.
553+
if (extInfo.isDifferentiable() &&
554+
stage != TypeResolutionStage::Structural) {
555+
auto result = fnTy->getResult();
556+
auto params = fnTy->getParams();
557+
auto diffKind = extInfo.getDifferentiabilityKind();
558+
bool isLinear = diffKind == DifferentiabilityKind::Linear;
559+
560+
// Check the params.
561+
562+
// Emit `@noDerivative` fixit only if there is at least one valid
563+
// differentiability parameter. Otherwise, adding `@noDerivative` produces
564+
// an ill-formed function type.
565+
auto hasValidDifferentiabilityParam =
566+
llvm::find_if(params, [&](AnyFunctionType::Param param) {
567+
if (param.isNoDerivative())
568+
return false;
569+
return TypeChecker::isDifferentiable(param.getPlainType(),
570+
/*tangentVectorEqualsSelf*/ isLinear,
571+
dc, stage);
572+
}) != params.end();
573+
bool alreadyDiagnosedOneParam = false;
574+
for (unsigned i = 0, end = fnTy->getNumParams(); i != end; ++i) {
575+
auto param = params[i];
576+
if (param.isNoDerivative())
577+
continue;
578+
auto paramType = param.getPlainType();
579+
if (TypeChecker::isDifferentiable(paramType, isLinear, dc, stage))
580+
continue;
581+
auto diagLoc =
582+
repr ? (*repr)->getArgsTypeRepr()->getElement(i).Type->getLoc() : loc;
583+
auto paramTypeString = paramType->getString();
584+
auto diagnostic = ctx.Diags.diagnose(
585+
diagLoc, diag::differentiable_function_type_invalid_parameter,
586+
paramTypeString, isLinear, hasValidDifferentiabilityParam);
587+
alreadyDiagnosedOneParam = true;
588+
hadAnyError = true;
589+
if (hasValidDifferentiabilityParam)
590+
diagnostic.fixItInsert(diagLoc, "@noDerivative ");
591+
}
592+
// Reject the case where all parameters have '@noDerivative'.
593+
if (!alreadyDiagnosedOneParam && !hasValidDifferentiabilityParam) {
594+
auto diagLoc = repr ? (*repr)->getArgsTypeRepr()->getLoc() : loc;
595+
auto diag = ctx.Diags.diagnose(
596+
diagLoc,
597+
diag::differentiable_function_type_no_differentiability_parameters,
598+
isLinear);
599+
hadAnyError = true;
600+
601+
if (repr) {
602+
diag.highlight((*repr)->getSourceRange());
603+
}
604+
}
605+
606+
// Check the result
607+
bool differentiable = isDifferentiable(result,
608+
/*tangentVectorEqualsSelf*/ isLinear,
609+
dc, stage);
610+
if (!differentiable) {
611+
auto diagLoc = repr ? (*repr)->getResultTypeRepr()->getLoc() : loc;
612+
auto resultStr = fnTy->getResult()->getString();
613+
auto diag = ctx.Diags.diagnose(
614+
diagLoc, diag::differentiable_function_type_invalid_result, resultStr,
615+
isLinear);
616+
hadAnyError = true;
617+
618+
if (repr) {
619+
diag.highlight((*repr)->getResultTypeRepr()->getSourceRange());
620+
}
621+
}
622+
}
623+
624+
return hadAnyError;
625+
}

lib/Sema/TypeChecker.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1164,6 +1164,19 @@ bool typeSupportsBuilderOp(Type builderType, DeclContext *dc, Identifier fnName,
11641164
/// once.
11651165
void applyAccessNote(ValueDecl *VD);
11661166

1167+
/// Returns true if the given type conforms to `Differentiable` in the
1168+
/// module of `dc`. If `tangentVectorEqualsSelf` is true, returns true iff
1169+
/// the given type additionally satisfies `Self == Self.TangentVector`.
1170+
bool isDifferentiable(Type type, bool tangentVectorEqualsSelf, DeclContext *dc,
1171+
Optional<TypeResolutionStage> stage);
1172+
1173+
/// Emits diagnostics if the given function type's parameter/result types are
1174+
/// not compatible with the ext info. Returns whether an error was diagnosed.
1175+
bool diagnoseInvalidFunctionType(FunctionType *fnTy, SourceLoc loc,
1176+
Optional<FunctionTypeRepr *>repr,
1177+
DeclContext *dc,
1178+
Optional<TypeResolutionStage> stage);
1179+
11671180
}; // namespace TypeChecker
11681181

11691182
/// Returns the protocol requirement kind of the given declaration.

test/AutoDiff/Sema/differentiable_attr_type_checking.swift

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,16 @@ let _: @differentiable(reverse) (Float, Float) -> TF_521<Float> = { r, i in
400400
TF_521(real: r, imaginary: i)
401401
}
402402

403+
// expected-error @+1 {{result type 'TF_521<Float>' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
404+
let _: @differentiable(reverse) (_, _) -> TF_521<Float> = { (r: Float, i: Float) in
405+
TF_521(real: r, imaginary: i)
406+
}
407+
408+
// expected-error @+1 {{result type 'TF_521<Float>' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
409+
let _: @differentiable(reverse) (Float, Float) -> _ = { r, i in
410+
TF_521(real: r, imaginary: i)
411+
}
412+
403413
// TF-296: Infer `@differentiable` wrt parameters to be to all parameters that conform to `Differentiable`.
404414

405415
@differentiable(reverse)

test/AutoDiff/Sema/differentiable_func_type.swift

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,7 @@ let _: @differentiable(_linear) (Float) -> Float
4545
func test1<T: Differentiable, U: Differentiable>(_: @differentiable(reverse) (T) -> @differentiable(reverse) (U) -> Float) {}
4646
// expected-error @+1 {{result type '(U) -> Float' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
4747
func test2<T: Differentiable, U: Differentiable>(_: @differentiable(reverse) (T) -> (U) -> Float) {}
48-
// expected-error @+2 {{result type 'Int' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
49-
// expected-error @+1 {{result type '@differentiable(reverse) (U) -> Int' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
48+
// expected-error @+1 {{result type 'Int' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
5049
func test3<T: Differentiable, U: Differentiable>(_: @differentiable(reverse) (T) -> @differentiable(reverse) (U) -> Int) {}
5150
// expected-error @+1 {{result type '(U) -> Int' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
5251
func test4<T: Differentiable, U: Differentiable>(_: @differentiable(reverse) (T) -> (U) -> Int) {}
@@ -189,7 +188,6 @@ extension Vector: Differentiable where T: Differentiable {
189188
// expected-note@+1 2 {{found this candidate}}
190189
func inferredConformancesGeneric<T, U>(_: @differentiable(reverse) (Vector<T>) -> Vector<U>) {}
191190

192-
// expected-note @+5 2 {{found this candidate}}
193191
// expected-error @+4 {{generic signature requires types 'Vector<T>' and 'Vector<T>.TangentVector' to be the same}}
194192
// expected-error @+3 {{generic signature requires types 'Vector<U>' and 'Vector<U>.TangentVector' to be the same}}
195193
// expected-error @+2 {{parameter type 'Vector<T>' does not conform to 'Differentiable' and satisfy 'Vector<T> == Vector<T>.TangentVector', but the enclosing function type is '@differentiable(_linear)'}}
@@ -203,7 +201,8 @@ func nondiff(x: Vector<Int>) -> Vector<Int> {}
203201

204202
// expected-error @+1 {{no exact matches in call to global function 'inferredConformancesGeneric'}}
205203
inferredConformancesGeneric(nondiff)
206-
// expected-error @+1 {{no exact matches in call to global function 'inferredConformancesGenericLinear'}}
204+
// expected-error @+2 {{global function 'inferredConformancesGenericLinear' requires that 'Int' conform to 'Differentiable'}}
205+
// expected-error @+1 {{global function 'inferredConformancesGenericLinear' requires that 'Int' conform to 'Differentiable'}}
207206
inferredConformancesGenericLinear(nondiff)
208207

209208
func diff(x: Vector<Float>) -> Vector<Float> {}
@@ -228,13 +227,15 @@ extension Linear: Differentiable where T: Differentiable, T == T.TangentVector {
228227
// expected-note @+1 2 {{found this candidate}}
229228
func inferredConformancesGeneric<T, U>(_: @differentiable(reverse) (Linear<T>) -> Linear<U>) {}
230229

231-
// expected-note @+1 2 {{found this candidate}}
230+
// expected-note @+2 2 {{where 'T' = 'Int'}}
231+
// expected-note @+1 2 {{where 'U' = 'Int'}}
232232
func inferredConformancesGenericLinear<T, U>(_: @differentiable(_linear) (Linear<T>) -> Linear<U>) {}
233233

234234
func nondiff(x: Linear<Int>) -> Linear<Int> {}
235235
// expected-error @+1 {{no exact matches in call to global function 'inferredConformancesGeneric'}}
236236
inferredConformancesGeneric(nondiff)
237-
// expected-error @+1 {{no exact matches in call to global function 'inferredConformancesGenericLinear'}}
237+
// expected-error @+2 {{global function 'inferredConformancesGenericLinear' requires that 'Int' conform to 'Differentiable'}}
238+
// expected-error @+1 {{global function 'inferredConformancesGenericLinear' requires that 'Int' conform to 'Differentiable'}}
238239
inferredConformancesGenericLinear(nondiff)
239240

240241
func diff(x: Linear<Float>) -> Linear<Float> {}

0 commit comments

Comments
 (0)