Skip to content

Commit 8cfdb99

Browse files
authored
Merge pull request #41436 from xedin/allow-specialization-from-default-expr
[TypeChecker] Allow inference from default expressions in certain scenarios (under a flag)
2 parents 48de949 + eaa737c commit 8cfdb99

25 files changed

+994
-53
lines changed

include/swift/AST/Decl.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5504,6 +5504,7 @@ enum class ParamSpecifier : uint8_t {
55045504
class ParamDecl : public VarDecl {
55055505
friend class DefaultArgumentInitContextRequest;
55065506
friend class DefaultArgumentExprRequest;
5507+
friend class DefaultArgumentTypeRequest;
55075508

55085509
enum class ArgumentNameFlags : uint8_t {
55095510
/// Whether or not this parameter is destructed.
@@ -5524,6 +5525,9 @@ class ParamDecl : public VarDecl {
55245525
struct alignas(1 << StoredDefaultArgumentAlignInBits) StoredDefaultArgument {
55255526
PointerUnion<Expr *, VarDecl *> DefaultArg;
55265527

5528+
/// The type of the default argument expression.
5529+
Type ExprType;
5530+
55275531
/// Stores the context for the default argument as well as a bit to
55285532
/// indicate whether the default expression has been type-checked.
55295533
llvm::PointerIntPair<Initializer *, 1, bool> InitContextAndIsTypeChecked;
@@ -5641,6 +5645,10 @@ class ParamDecl : public VarDecl {
56415645
return nullptr;
56425646
}
56435647

5648+
/// Retrieve the type of the default expression (if any) associated with
5649+
/// this parameter declaration.
5650+
Type getTypeOfDefaultExpr() const;
5651+
56445652
VarDecl *getStoredProperty() const {
56455653
if (auto stored = DefaultValueAndFlags.getPointer())
56465654
return stored->DefaultArg.dyn_cast<VarDecl *>();
@@ -5655,6 +5663,10 @@ class ParamDecl : public VarDecl {
56555663
/// parameter's fully type-checked default argument.
56565664
void setDefaultExpr(Expr *E, bool isTypeChecked);
56575665

5666+
/// Sets a type of default expression associated with this parameter.
5667+
/// This should only be called by deserialization.
5668+
void setDefaultExprType(Type type);
5669+
56585670
void setStoredProperty(VarDecl *var);
56595671

56605672
/// Retrieve the initializer context for the parameter's default argument.

include/swift/AST/DiagnosticsSema.def

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ NOTE(extended_type_declared_here,none,
3838
"extended type declared here", ())
3939
NOTE(opaque_return_type_declared_here,none,
4040
"opaque return type declared here", ())
41+
NOTE(default_value_declared_here,none,
42+
"default value declared here", ())
4143

4244
//------------------------------------------------------------------------------
4345
// MARK: Constraint solver diagnostics
@@ -6240,5 +6242,28 @@ ERROR(type_sequence_on_non_generic_param, none,
62406242
"'@_typeSequence' must appear on a generic parameter",
62416243
())
62426244

6245+
//------------------------------------------------------------------------------
6246+
// MARK: Type inference from default expressions
6247+
//------------------------------------------------------------------------------
6248+
6249+
ERROR(cannot_default_generic_parameter_inferrable_from_another_parameter, none,
6250+
"cannot use default expression for inference of %0 because it "
6251+
"is inferrable from parameters %1",
6252+
(Type, StringRef))
6253+
6254+
ERROR(cannot_default_generic_parameter_inferrable_through_same_type, none,
6255+
"cannot use default expression for inference of %0 because it "
6256+
"is inferrable through same-type requirement: '%1'",
6257+
(Type, StringRef))
6258+
6259+
ERROR(cannot_default_generic_parameter_invalid_requirement, none,
6260+
"cannot use default expression for inference of %0 because "
6261+
"requirement '%1' refers to other generic parameters",
6262+
(Type, StringRef))
6263+
6264+
ERROR(cannot_convert_default_value_type_to_argument_type, none,
6265+
"cannot convert default value of type %0 to expected argument type %1 for parameter #%2",
6266+
(Type, Type, unsigned))
6267+
62436268
#define UNDEFINE_DIAGNOSTIC_MACROS
62446269
#include "DefineDiagnosticMacros.h"

include/swift/AST/TypeCheckRequests.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class AccessorDecl;
4141
enum class AccessorKind;
4242
class ContextualPattern;
4343
class DefaultArgumentExpr;
44+
class DefaultArgumentType;
4445
class ClosureExpr;
4546
class GenericParamList;
4647
class PrecedenceGroupDecl;
@@ -2665,6 +2666,26 @@ class DefaultArgumentExprRequest
26652666
void cacheResult(Expr *expr) const;
26662667
};
26672668

2669+
/// Computes the type of the default expression for a given parameter.
2670+
class DefaultArgumentTypeRequest
2671+
: public SimpleRequest<DefaultArgumentTypeRequest, Type(ParamDecl *),
2672+
RequestFlags::SeparatelyCached> {
2673+
public:
2674+
using SimpleRequest::SimpleRequest;
2675+
2676+
private:
2677+
friend SimpleRequest;
2678+
2679+
// Evaluation.
2680+
Type evaluate(Evaluator &evaluator, ParamDecl *param) const;
2681+
2682+
public:
2683+
// Separate caching.
2684+
bool isCached() const { return true; }
2685+
Optional<Type> getCachedResult() const;
2686+
void cacheResult(Type type) const;
2687+
};
2688+
26682689
/// Computes the fully type-checked caller-side default argument within the
26692690
/// context of the call site that it will be inserted into.
26702691
class CallerSideDefaultArgExprRequest

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ SWIFT_REQUEST(TypeChecker, CustomAttrTypeRequest,
5959
SeparatelyCached, NoLocationInfo)
6060
SWIFT_REQUEST(TypeChecker, DefaultArgumentExprRequest,
6161
Expr *(ParamDecl *), SeparatelyCached, NoLocationInfo)
62+
SWIFT_REQUEST(TypeChecker, DefaultArgumentTypeRequest,
63+
Type(ParamDecl *), SeparatelyCached, NoLocationInfo)
6264
SWIFT_REQUEST(TypeChecker, DefaultArgumentInitContextRequest,
6365
Initializer *(ParamDecl *), SeparatelyCached, NoLocationInfo)
6466
SWIFT_REQUEST(TypeChecker, DefaultDefinitionTypeRequest,

include/swift/Basic/LangOptions.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,10 @@ namespace swift {
718718
/// closures.
719719
bool EnableMultiStatementClosureInference = false;
720720

721+
/// Enable experimental support for generic parameter inference in
722+
/// parameter positions from associated default expressions.
723+
bool EnableTypeInferenceFromDefaultArguments = false;
724+
721725
/// See \ref FrontendOptions.PrintFullConvention
722726
bool PrintFullConvention = false;
723727
};

include/swift/Option/FrontendOptions.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -861,6 +861,10 @@ def experimental_multi_statement_closures :
861861
Flag<["-"], "experimental-multi-statement-closures">,
862862
HelpText<"Enable experimental support for type inference in multi-statement closures">;
863863

864+
def experimental_type_inference_from_defaults :
865+
Flag<["-"], "enable-experimental-type-inference-from-defaults">,
866+
HelpText<"Enable experimental support for generic parameter inference from default values">;
867+
864868
def prebuilt_module_cache_path :
865869
Separate<["-"], "prebuilt-module-cache-path">,
866870
HelpText<"Directory of prebuilt modules for loading module interfaces">;

include/swift/Sema/CSFix.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,10 @@ enum class FixKind : uint8_t {
382382

383383
/// Produce an error for not getting a compile-time constant
384384
NotCompileTimeConst,
385+
386+
/// Ignore a type mismatch while trying to infer generic parameter type
387+
/// from default expression.
388+
IgnoreDefaultExprTypeMismatch,
385389
};
386390

387391
class ConstraintFix {
@@ -2896,6 +2900,29 @@ class AllowSwiftToCPointerConversion final : public ConstraintFix {
28962900
ConstraintLocator *locator);
28972901
};
28982902

2903+
class IgnoreDefaultExprTypeMismatch : public AllowArgumentMismatch {
2904+
protected:
2905+
IgnoreDefaultExprTypeMismatch(ConstraintSystem &cs, Type argType,
2906+
Type paramType, ConstraintLocator *locator)
2907+
: AllowArgumentMismatch(cs, FixKind::IgnoreDefaultExprTypeMismatch,
2908+
argType, paramType, locator) {}
2909+
2910+
public:
2911+
std::string getName() const override {
2912+
return "allow default expression conversion mismatch";
2913+
}
2914+
2915+
bool diagnose(const Solution &solution, bool asNote = false) const override;
2916+
2917+
static IgnoreDefaultExprTypeMismatch *create(ConstraintSystem &cs,
2918+
Type argType, Type paramType,
2919+
ConstraintLocator *locator);
2920+
2921+
static bool classof(const ConstraintFix *fix) {
2922+
return fix->getKind() == FixKind::IgnoreDefaultExprTypeMismatch;
2923+
}
2924+
};
2925+
28992926
} // end namespace constraints
29002927
} // end namespace swift
29012928

include/swift/Sema/ConstraintSystem.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ Optional<constraints::SolutionApplicationTarget>
8080
typeCheckExpression(constraints::SolutionApplicationTarget &target,
8181
OptionSet<TypeCheckExprFlags> options);
8282

83+
Type typeCheckParameterDefault(Expr *&, DeclContext *, Type, bool);
84+
8385
} // end namespace TypeChecker
8486

8587
} // end namespace swift
@@ -3045,6 +3047,10 @@ class ConstraintSystem {
30453047
swift::TypeChecker::typeCheckExpression(
30463048
SolutionApplicationTarget &target, OptionSet<TypeCheckExprFlags> options);
30473049

3050+
friend Type swift::TypeChecker::typeCheckParameterDefault(Expr *&,
3051+
DeclContext *, Type,
3052+
bool);
3053+
30483054
/// Emit the fixes computed as part of the solution, returning true if we were
30493055
/// able to emit an error message, or false if none of the fixits worked out.
30503056
bool applySolutionFixes(const Solution &solution);
@@ -4167,6 +4173,13 @@ class ConstraintSystem {
41674173
OpenedTypeMap &replacements,
41684174
ConstraintLocatorBuilder locator);
41694175

4176+
/// Open a generic parameter into a type variable and record
4177+
/// it in \c replacements.
4178+
TypeVariableType *openGenericParameter(DeclContext *outerDC,
4179+
GenericTypeParamType *parameter,
4180+
OpenedTypeMap &replacements,
4181+
ConstraintLocatorBuilder locator);
4182+
41704183
/// Given generic signature open its generic requirements,
41714184
/// using substitution function, and record them in the
41724185
/// constraint system for further processing.
@@ -4176,6 +4189,14 @@ class ConstraintSystem {
41764189
ConstraintLocatorBuilder locator,
41774190
llvm::function_ref<Type(Type)> subst);
41784191

4192+
// Record the given requirement in the constraint system.
4193+
void openGenericRequirement(DeclContext *outerDC,
4194+
unsigned index,
4195+
const Requirement &requirement,
4196+
bool skipProtocolSelfConstraint,
4197+
ConstraintLocatorBuilder locator,
4198+
llvm::function_ref<Type(Type)> subst);
4199+
41794200
/// Record the set of opened types for the given locator.
41804201
void recordOpenedTypes(
41814202
ConstraintLocatorBuilder locator,

lib/AST/Decl.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6992,6 +6992,18 @@ Expr *ParamDecl::getTypeCheckedDefaultExpr() const {
69926992
return new (ctx) ErrorExpr(getSourceRange(), ErrorType::get(ctx));
69936993
}
69946994

6995+
Type ParamDecl::getTypeOfDefaultExpr() const {
6996+
auto &ctx = getASTContext();
6997+
6998+
if (Type type = evaluateOrDefault(
6999+
ctx.evaluator,
7000+
DefaultArgumentTypeRequest{const_cast<ParamDecl *>(this)}, nullptr)) {
7001+
return type;
7002+
}
7003+
7004+
return Type();
7005+
}
7006+
69957007
void ParamDecl::setDefaultExpr(Expr *E, bool isTypeChecked) {
69967008
if (!DefaultValueAndFlags.getPointer()) {
69977009
if (!E) return;
@@ -7009,9 +7021,20 @@ void ParamDecl::setDefaultExpr(Expr *E, bool isTypeChecked) {
70097021
"Can't overwrite type-checked default with un-type-checked default");
70107022
}
70117023
defaultInfo->DefaultArg = E;
7024+
defaultInfo->ExprType = E->getType();
70127025
defaultInfo->InitContextAndIsTypeChecked.setInt(isTypeChecked);
70137026
}
70147027

7028+
void ParamDecl::setDefaultExprType(Type type) {
7029+
if (!DefaultValueAndFlags.getPointer()) {
7030+
DefaultValueAndFlags.setPointer(
7031+
getASTContext().Allocate<StoredDefaultArgument>());
7032+
}
7033+
7034+
auto *defaultInfo = DefaultValueAndFlags.getPointer();
7035+
defaultInfo->ExprType = type;
7036+
}
7037+
70157038
void ParamDecl::setStoredProperty(VarDecl *var) {
70167039
if (!DefaultValueAndFlags.getPointer()) {
70177040
if (!var) return;

lib/AST/TypeCheckRequests.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,6 +1206,27 @@ void DefaultArgumentExprRequest::cacheResult(Expr *expr) const {
12061206
param->setDefaultExpr(expr, /*isTypeChecked*/ true);
12071207
}
12081208

1209+
//----------------------------------------------------------------------------//
1210+
// DefaultArgumentTypeRequest computation.
1211+
//----------------------------------------------------------------------------//
1212+
1213+
Optional<Type> DefaultArgumentTypeRequest::getCachedResult() const {
1214+
auto *param = std::get<0>(getStorage());
1215+
auto *defaultInfo = param->DefaultValueAndFlags.getPointer();
1216+
if (!defaultInfo)
1217+
return None;
1218+
1219+
if (!defaultInfo->InitContextAndIsTypeChecked.getInt())
1220+
return None;
1221+
1222+
return defaultInfo->ExprType;
1223+
}
1224+
1225+
void DefaultArgumentTypeRequest::cacheResult(Type type) const {
1226+
auto *param = std::get<0>(getStorage());
1227+
param->setDefaultExprType(type);
1228+
}
1229+
12091230
//----------------------------------------------------------------------------//
12101231
// CallerSideDefaultArgExprRequest computation.
12111232
//----------------------------------------------------------------------------//

0 commit comments

Comments
 (0)