Skip to content

Commit 4b3d2f4

Browse files
committed
[Typed throws] Handle function conversions involving different thrown errors
Teach the constraint solver about the subtyping rule that permits converting one function type to another when the effective thrown error type of one is a subtype of the effective thrown error type of the other, using `any Error` for untyped throws and `Never` for non-throwing. With minor other fixes, this allows us to use typed throws for generic functions that carry a typed error from their arguments through to themselves, which is in effect a typed `rethrows`: ```swift func mapArray<T, U, E: Error>(_ array: [T], body: (T) throws(E) -> U) throws(E) -> [U] { var resultArray: [U] = .init() for value in array { resultArray.append(try body(value)) } return resultArray } ```
1 parent d58943a commit 4b3d2f4

14 files changed

+239
-12
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,9 @@ ERROR(throws_functiontype_mismatch,none,
615615
"invalid conversion from throwing function of type %0 to "
616616
"non-throwing function type %1", (Type, Type))
617617

618+
ERROR(thrown_error_type_mismatch,none,
619+
"invalid conversion of thrown error type %0 to %1", (Type, Type))
620+
618621
ERROR(async_functiontype_mismatch,none,
619622
"invalid conversion from 'async' function of type %0 to "
620623
"synchronous function type %1", (Type, Type))

include/swift/Sema/CSFix.h

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,9 @@ enum class FixKind : uint8_t {
375375
/// `throws` attribute from the source function.
376376
DropThrowsAttribute,
377377

378+
/// Ignore a mismatch in the thrown error type.
379+
IgnoreThrownErrorMismatch,
380+
378381
/// Fix conversion from async to sync function by removing explicit
379382
/// `async` attribute from the source function.
380383
DropAsyncAttribute,
@@ -1005,7 +1008,6 @@ class DropThrowsAttribute final : public ContextualMismatch {
10051008
FunctionType *toType, ConstraintLocator *locator)
10061009
: ContextualMismatch(cs, FixKind::DropThrowsAttribute, fromType, toType,
10071010
locator) {
1008-
assert(fromType->isThrowing() != toType->isThrowing());
10091011
}
10101012

10111013
public:
@@ -1023,6 +1025,30 @@ class DropThrowsAttribute final : public ContextualMismatch {
10231025
}
10241026
};
10251027

1028+
/// This is a contextual mismatch between the thrown error types of two
1029+
/// function types, which could be repaired by fixing one of the types.
1030+
class IgnoreThrownErrorMismatch final : public ContextualMismatch {
1031+
IgnoreThrownErrorMismatch(ConstraintSystem &cs, Type fromErrorType,
1032+
Type toErrorType, ConstraintLocator *locator)
1033+
: ContextualMismatch(cs, FixKind::IgnoreThrownErrorMismatch,
1034+
fromErrorType, toErrorType, locator) {
1035+
assert(!fromErrorType->isEqual(toErrorType));
1036+
}
1037+
1038+
public:
1039+
std::string getName() const override { return "drop 'throws' attribute"; }
1040+
1041+
bool diagnose(const Solution &solution, bool asNote = false) const override;
1042+
1043+
static IgnoreThrownErrorMismatch *create(ConstraintSystem &cs,
1044+
Type fromErrorType,
1045+
Type toErrorType,
1046+
ConstraintLocator *locator);
1047+
1048+
static bool classof(const ConstraintFix *fix) {
1049+
return fix->getKind() == FixKind::IgnoreThrownErrorMismatch;
1050+
}
1051+
};
10261052
/// This is a contextual mismatch between async and non-async
10271053
/// function types, repair it by dropping `async` attribute.
10281054
class DropAsyncAttribute final : public ContextualMismatch {

include/swift/Sema/ConstraintLocatorPathElts.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,9 @@ ABSTRACT_LOCATOR_PATH_ELT(PatternDecl)
275275
/// A function type global actor.
276276
SIMPLE_LOCATOR_PATH_ELT(GlobalActorType)
277277

278+
/// The thrown error of a function type.
279+
SIMPLE_LOCATOR_PATH_ELT(ThrownErrorType)
280+
278281
/// A type coercion operand.
279282
SIMPLE_LOCATOR_PATH_ELT(CoercionOperand)
280283

lib/AST/TypeWalker.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,11 @@ class Traversal : public TypeVisitor<Traversal, bool>
115115
return true;
116116
}
117117

118+
if (Type thrownError = ty->getThrownError()) {
119+
if (doIt(thrownError))
120+
return true;
121+
}
122+
118123
return doIt(ty->getResult());
119124
}
120125

lib/Sema/CSDiagnostics.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7096,6 +7096,12 @@ bool ThrowingFunctionConversionFailure::diagnoseAsError() {
70967096
return true;
70977097
}
70987098

7099+
bool ThrownErrorTypeConversionFailure::diagnoseAsError() {
7100+
emitDiagnostic(diag::thrown_error_type_mismatch, getFromType(),
7101+
getToType());
7102+
return true;
7103+
}
7104+
70997105
bool AsyncFunctionConversionFailure::diagnoseAsError() {
71007106
auto *locator = getLocator();
71017107

lib/Sema/CSDiagnostics.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -933,6 +933,24 @@ class ThrowingFunctionConversionFailure final : public ContextualFailure {
933933
bool diagnoseAsError() override;
934934
};
935935

936+
/// Diagnose failures related to conversion between the thrown error type
937+
/// of two function types, e.g.,
938+
///
939+
/// ```swift
940+
/// func foo<T>(_ t: T) throws(MyError) -> Void {}
941+
/// let _: (Int) throws (OtherError)-> Void = foo
942+
/// // `MyError` can't be implicitly converted to `OtherError`
943+
/// ```
944+
class ThrownErrorTypeConversionFailure final : public ContextualFailure {
945+
public:
946+
ThrownErrorTypeConversionFailure(const Solution &solution, Type fromType,
947+
Type toType, ConstraintLocator *locator)
948+
: ContextualFailure(solution, fromType, toType, locator) {
949+
}
950+
951+
bool diagnoseAsError() override;
952+
};
953+
936954
/// Diagnose failures related to conversion between 'async' function type
937955
/// and a synchronous one e.g.
938956
///

lib/Sema/CSFix.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1627,6 +1627,21 @@ DropThrowsAttribute *DropThrowsAttribute::create(ConstraintSystem &cs,
16271627
DropThrowsAttribute(cs, fromType, toType, locator);
16281628
}
16291629

1630+
bool IgnoreThrownErrorMismatch::diagnose(const Solution &solution,
1631+
bool asNote) const {
1632+
ThrownErrorTypeConversionFailure failure(solution, getFromType(),
1633+
getToType(), getLocator());
1634+
return failure.diagnose(asNote);
1635+
}
1636+
1637+
IgnoreThrownErrorMismatch *IgnoreThrownErrorMismatch::create(ConstraintSystem &cs,
1638+
Type fromErrorType,
1639+
Type toErrorType,
1640+
ConstraintLocator *locator) {
1641+
return new (cs.getAllocator())
1642+
IgnoreThrownErrorMismatch(cs, fromErrorType, toErrorType, locator);
1643+
}
1644+
16301645
bool DropAsyncAttribute::diagnose(const Solution &solution,
16311646
bool asNote) const {
16321647
AsyncFunctionConversionFailure failure(solution, getFromType(),

lib/Sema/CSGen.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2453,7 +2453,6 @@ namespace {
24532453
auto resultLocator =
24542454
CS.getConstraintLocator(closure, ConstraintLocator::ClosureResult);
24552455

2456-
// FIXME: Need a better locator.
24572456
auto thrownErrorLocator =
24582457
CS.getConstraintLocator(closure, ConstraintLocator::ClosureThrownError);
24592458

lib/Sema/CSSimplify.cpp

Lines changed: 111 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2945,14 +2945,99 @@ bool ConstraintSystem::hasPreconcurrencyCallee(
29452945
return calleeOverload->choice.getDecl()->preconcurrency();
29462946
}
29472947

2948+
namespace {
2949+
/// Classifies a thrown error kind as Never, a specific type, or 'any Error'.
2950+
enum class ThrownErrorKind {
2951+
Never,
2952+
Specific,
2953+
AnyError,
2954+
};
2955+
2956+
ThrownErrorKind getThrownErrorKind(Type type) {
2957+
if (type->isNever())
2958+
return ThrownErrorKind::Never;
2959+
2960+
if (type->isExistentialType()) {
2961+
Type anyError = type->getASTContext().getErrorExistentialType();
2962+
if (anyError->isEqual(type))
2963+
return ThrownErrorKind::AnyError;
2964+
}
2965+
2966+
return ThrownErrorKind::Specific;
2967+
}
2968+
}
2969+
29482970
ConstraintSystem::TypeMatchResult
29492971
ConstraintSystem::matchFunctionTypes(FunctionType *func1, FunctionType *func2,
29502972
ConstraintKind kind, TypeMatchOptions flags,
29512973
ConstraintLocatorBuilder locator) {
2952-
// A non-throwing function can be a subtype of a throwing function.
2953-
if (func1->isThrowing() != func2->isThrowing()) {
2954-
// Cannot drop 'throws'.
2955-
if (func1->isThrowing() || kind < ConstraintKind::Subtype) {
2974+
// A function type that throws the error type E1 is a subtype of a function
2975+
// that throws error type E2 when E1 is a subtype of E2. For the purpose
2976+
// of this comparison, a non-throwing function has thrown error type 'Never',
2977+
// and an untyped throwing function has thrown error type 'any Error'.
2978+
Type neverType = getASTContext().getNeverType();
2979+
Type thrownError1 = func1->getEffectiveThrownInterfaceType().value_or(neverType);
2980+
Type thrownError2 = func2->getEffectiveThrownInterfaceType().value_or(neverType);
2981+
if (!thrownError1->isEqual(thrownError2)) {
2982+
auto thrownErrorKind1 = getThrownErrorKind(thrownError1);
2983+
auto thrownErrorKind2 = getThrownErrorKind(thrownError2);
2984+
2985+
bool mustUnify = false;
2986+
bool dropThrows = false;
2987+
2988+
switch (thrownErrorKind1) {
2989+
case ThrownErrorKind::Specific:
2990+
// If the specific thrown error contains no type variables and we're
2991+
// going to try to convert it to \c Never, treat this as dropping throws.
2992+
if (thrownErrorKind2 == ThrownErrorKind::Never &&
2993+
!thrownError1->hasTypeVariable()) {
2994+
dropThrows = true;
2995+
} else {
2996+
// We need to unify the thrown error types.
2997+
mustUnify = true;
2998+
}
2999+
break;
3000+
3001+
case ThrownErrorKind::Never:
3002+
switch (thrownErrorKind2) {
3003+
case ThrownErrorKind::Specific:
3004+
// We need to unify the thrown error types.
3005+
mustUnify = true;
3006+
break;
3007+
3008+
case ThrownErrorKind::Never:
3009+
llvm_unreachable("The thrown error types should have been equal");
3010+
break;
3011+
3012+
case ThrownErrorKind::AnyError:
3013+
// We have a subtype. If we're not allowed to do the subtype,
3014+
// then we need to drop "throws".
3015+
if (kind < ConstraintKind::Subtype)
3016+
dropThrows = true;
3017+
break;
3018+
}
3019+
break;
3020+
3021+
case ThrownErrorKind::AnyError:
3022+
switch (thrownErrorKind2) {
3023+
case ThrownErrorKind::Specific:
3024+
// We need to unify the thrown error types.
3025+
mustUnify = true;
3026+
break;
3027+
3028+
case ThrownErrorKind::Never:
3029+
// We're going to have to drop the "throws" entirely.
3030+
dropThrows = true;
3031+
break;
3032+
3033+
case ThrownErrorKind::AnyError:
3034+
llvm_unreachable("The thrown error types should have been equal");
3035+
}
3036+
break;
3037+
}
3038+
3039+
// If we know we need to drop 'throws', try it now.
3040+
if (dropThrows) {
29563041
if (!shouldAttemptFixes())
29573042
return getTypeMatchFailure(locator);
29583043

@@ -2961,6 +3046,20 @@ ConstraintSystem::matchFunctionTypes(FunctionType *func1, FunctionType *func2,
29613046
if (recordFix(fix))
29623047
return getTypeMatchFailure(locator);
29633048
}
3049+
3050+
// If we need to unify the thrown error types, do so now.
3051+
if (mustUnify) {
3052+
ConstraintKind subKind = (kind < ConstraintKind::Subtype)
3053+
? ConstraintKind::Equal
3054+
: ConstraintKind::Subtype;
3055+
const auto subflags = getDefaultDecompositionOptions(flags);
3056+
auto result = matchTypes(
3057+
thrownError1, thrownError2,
3058+
subKind, subflags,
3059+
locator.withPathElement(LocatorPathElt::ThrownErrorType()));
3060+
if (result == SolutionKind::Error)
3061+
return getTypeMatchFailure(locator);
3062+
}
29643063
}
29653064

29663065
// A synchronous function can be a subtype of an 'async' function.
@@ -5152,6 +5251,13 @@ bool ConstraintSystem::repairFailures(
51525251
getConstraintLocator(locator)))
51535252
return true;
51545253

5254+
if (locator.endsWith<LocatorPathElt::ThrownErrorType>()) {
5255+
conversionsOrFixes.push_back(
5256+
IgnoreThrownErrorMismatch::create(*this, lhs, rhs,
5257+
getConstraintLocator(locator)));
5258+
return true;
5259+
}
5260+
51555261
if (path.empty()) {
51565262
if (!anchor)
51575263
return false;
@@ -14999,6 +15105,7 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyFixConstraint(
1499915105
case FixKind::AllowAssociatedValueMismatch:
1500015106
case FixKind::GenericArgumentsMismatch:
1500115107
case FixKind::AllowConcreteTypeSpecialization:
15108+
case FixKind::IgnoreThrownErrorMismatch:
1500215109
case FixKind::IgnoreGenericSpecializationArityMismatch: {
1500315110
return recordFix(fix) ? SolutionKind::Error : SolutionKind::Solved;
1500415111
}

lib/Sema/ConstraintLocator.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ unsigned LocatorPathElt::getNewSummaryFlags() const {
109109
case ConstraintLocator::GlobalActorType:
110110
case ConstraintLocator::CoercionOperand:
111111
case ConstraintLocator::PackExpansionType:
112+
case ConstraintLocator::ThrownErrorType:
112113
return 0;
113114

114115
case ConstraintLocator::FunctionArgument:
@@ -519,6 +520,10 @@ void LocatorPathElt::dump(raw_ostream &out) const {
519520
<< expansionElt.getOpenedType()->getString(PO) << ")";
520521
break;
521522
}
523+
case ConstraintLocator::ThrownErrorType: {
524+
out << "thrown error type";
525+
break;
526+
}
522527
}
523528
}
524529

0 commit comments

Comments
 (0)