Skip to content

Commit 8adee85

Browse files
committed
[Macros] Add support for explicit generic arguments of macros.
Enable type checking support for explicitly specifying generic arguments to a macro, e.g., `#stringify<Double>(1 + 2)`. To do so, introduce a new kind of constraint that performs explicit argument matching against the generic parameters of a macro only after the overload is chosen.
1 parent afbc4a5 commit 8adee85

File tree

8 files changed

+167
-7
lines changed

8 files changed

+167
-7
lines changed

include/swift/Sema/Constraint.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,13 @@ enum class ConstraintKind : char {
226226
BindTupleOfFunctionParams,
227227
/// The first type is a type pack, and the second type is its reduced shape.
228228
ShapeOf,
229+
/// Represents explicit generic arguments provided for a reference to
230+
/// a declaration.
231+
///
232+
/// The first type is the type variable describing the bound type of
233+
/// an overload. The second type is a PackType containing the explicit
234+
/// generic arguments.
235+
ExplicitGenericArguments,
229236
};
230237

231238
/// Classification of the different kinds of constraints.
@@ -708,6 +715,7 @@ class Constraint final : public llvm::ilist_node<Constraint>,
708715
case ConstraintKind::Defaultable:
709716
case ConstraintKind::BindTupleOfFunctionParams:
710717
case ConstraintKind::ShapeOf:
718+
case ConstraintKind::ExplicitGenericArguments:
711719
return ConstraintClassification::TypeProperty;
712720

713721
case ConstraintKind::Disjunction:

include/swift/Sema/ConstraintSystem.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5611,6 +5611,13 @@ class ConstraintSystem {
56115611
Type type1, Type type2, TypeMatchOptions flags,
56125612
ConstraintLocatorBuilder locator);
56135613

5614+
/// Simplify an explicit generic argument constraint by equating the
5615+
/// opened generic types of the bound left-hand type variable to the
5616+
/// pack type on the right-hand side.
5617+
SolutionKind simplifyExplicitGenericArgumentsConstraint(
5618+
Type type1, Type type2, TypeMatchOptions flags,
5619+
ConstraintLocatorBuilder locator);
5620+
56145621
public: // FIXME: Public for use by static functions.
56155622
/// Simplify a conversion constraint with a fix applied to it.
56165623
SolutionKind simplifyFixConstraint(ConstraintFix *fix, Type type1, Type type2,

lib/Sema/CSBindings.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1406,6 +1406,7 @@ void PotentialBindings::infer(Constraint *constraint) {
14061406
case ConstraintKind::BindTupleOfFunctionParams:
14071407
case ConstraintKind::PackElementOf:
14081408
case ConstraintKind::ShapeOf:
1409+
case ConstraintKind::ExplicitGenericArguments:
14091410
// Constraints from which we can't do anything.
14101411
break;
14111412

lib/Sema/CSGen.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1642,6 +1642,30 @@ namespace {
16421642
expr->getOuterAlternatives());
16431643
}
16441644

1645+
/// Given a set of specialization arguments, resolve those arguments and
1646+
/// introduce them as an explicit generic arguments constraint.
1647+
void addSpecializationConstraint(
1648+
ConstraintLocator *locator, Type boundType,
1649+
ArrayRef<TypeRepr *> specializationArgs) {
1650+
// Resolve each type.
1651+
SmallVector<Type, 2> specializationArgTypes;
1652+
const auto options =
1653+
TypeResolutionOptions(TypeResolverContext::InExpression);
1654+
for (auto specializationArg : specializationArgs) {
1655+
const auto result = TypeResolution::resolveContextualType(
1656+
specializationArg, CurDC, options,
1657+
// Introduce type variables for unbound generics.
1658+
OpenUnboundGenericType(CS, locator),
1659+
HandlePlaceholderType(CS, locator));
1660+
specializationArgTypes.push_back(result);
1661+
}
1662+
1663+
CS.addConstraint(
1664+
ConstraintKind::ExplicitGenericArguments, boundType,
1665+
PackType::get(CS.getASTContext(), specializationArgTypes),
1666+
locator);
1667+
}
1668+
16451669
Type visitUnresolvedSpecializeExpr(UnresolvedSpecializeExpr *expr) {
16461670
auto baseTy = CS.getType(expr->getSubExpr());
16471671

@@ -3678,6 +3702,13 @@ namespace {
36783702
auto macroRefType = Type(CS.createTypeVariable(locator, 0));
36793703
CS.addOverloadSet(macroRefType, macros, CurDC, locator);
36803704

3705+
// Add explicit generic arguments, if there were any.
3706+
if (expr->getGenericArgsRange().isValid()) {
3707+
addSpecializationConstraint(
3708+
CS.getConstraintLocator(expr), macroRefType,
3709+
expr->getGenericArgs());
3710+
}
3711+
36813712
// For non-calls, the type variable is the result.
36823713
if (!isCall)
36833714
return macroRefType;

lib/Sema/CSSimplify.cpp

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2306,6 +2306,7 @@ ConstraintSystem::matchTupleTypes(TupleType *tuple1, TupleType *tuple2,
23062306
case ConstraintKind::BindTupleOfFunctionParams:
23072307
case ConstraintKind::PackElementOf:
23082308
case ConstraintKind::ShapeOf:
2309+
case ConstraintKind::ExplicitGenericArguments:
23092310
llvm_unreachable("Bad constraint kind in matchTupleTypes()");
23102311
}
23112312

@@ -2666,6 +2667,7 @@ static bool matchFunctionRepresentations(FunctionType::ExtInfo einfo1,
26662667
case ConstraintKind::BindTupleOfFunctionParams:
26672668
case ConstraintKind::PackElementOf:
26682669
case ConstraintKind::ShapeOf:
2670+
case ConstraintKind::ExplicitGenericArguments:
26692671
return true;
26702672
}
26712673

@@ -3084,6 +3086,7 @@ ConstraintSystem::matchFunctionTypes(FunctionType *func1, FunctionType *func2,
30843086
case ConstraintKind::BindTupleOfFunctionParams:
30853087
case ConstraintKind::PackElementOf:
30863088
case ConstraintKind::ShapeOf:
3089+
case ConstraintKind::ExplicitGenericArguments:
30873090
llvm_unreachable("Not a relational constraint");
30883091
}
30893092

@@ -6459,6 +6462,7 @@ ConstraintSystem::matchTypes(Type type1, Type type2, ConstraintKind kind,
64596462
case ConstraintKind::BindTupleOfFunctionParams:
64606463
case ConstraintKind::PackElementOf:
64616464
case ConstraintKind::ShapeOf:
6465+
case ConstraintKind::ExplicitGenericArguments:
64626466
llvm_unreachable("Not a relational constraint");
64636467
}
64646468
}
@@ -12644,6 +12648,95 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyShapeOfConstraint(
1264412648
return SolutionKind::Solved;
1264512649
}
1264612650

12651+
ConstraintSystem::SolutionKind
12652+
ConstraintSystem::simplifyExplicitGenericArgumentsConstraint(
12653+
Type type1, Type type2, TypeMatchOptions flags,
12654+
ConstraintLocatorBuilder locator) {
12655+
auto formUnsolved = [&]() {
12656+
// If we're supposed to generate constraints, do so.
12657+
if (flags.contains(TMF_GenerateConstraints)) {
12658+
auto *shapeOf = Constraint::create(
12659+
*this, ConstraintKind::ShapeOf, type1, type2,
12660+
getConstraintLocator(locator));
12661+
12662+
addUnsolvedConstraint(shapeOf);
12663+
return SolutionKind::Solved;
12664+
}
12665+
12666+
return SolutionKind::Unsolved;
12667+
};
12668+
12669+
// Bail out if we haven't selected an overload yet.
12670+
auto simplifiedBoundType = simplifyType(type1, flags);
12671+
if (simplifiedBoundType->isTypeVariableOrMember())
12672+
return formUnsolved();
12673+
12674+
// Determine the overload locator for this constraint.
12675+
ConstraintLocator *overloadLocator = nullptr;
12676+
if (auto anchorExpr = locator.getAnchor().dyn_cast<Expr *>()) {
12677+
if (auto expansion = dyn_cast<MacroExpansionExpr>(anchorExpr)) {
12678+
overloadLocator = getConstraintLocator(expansion);
12679+
} else if (auto specialize =
12680+
dyn_cast<UnresolvedSpecializeExpr>(anchorExpr)) {
12681+
overloadLocator = getConstraintLocator(
12682+
specialize->getSubExpr()->getSemanticsProvidingExpr());
12683+
}
12684+
} else if (auto anchorDecl = locator.getAnchor().dyn_cast<Decl *>()) {
12685+
if (auto expansion = dyn_cast<MacroExpansionDecl>(anchorDecl)) {
12686+
overloadLocator = getConstraintLocator(expansion);
12687+
}
12688+
}
12689+
assert(overloadLocator && "Specialize expression has the wrong form");
12690+
12691+
// If the overload hasn't been resolved, we can't simplify this constraint.
12692+
auto resolvedOverloadIter = getResolvedOverloads().find(overloadLocator);
12693+
if (resolvedOverloadIter == getResolvedOverloads().end())
12694+
return formUnsolved();
12695+
12696+
auto selectedOverload = resolvedOverloadIter->second;
12697+
auto overloadChoice = selectedOverload.choice;
12698+
if (!overloadChoice.isDecl()) {
12699+
return SolutionKind::Error;
12700+
}
12701+
12702+
auto decl = overloadChoice.getDecl();
12703+
auto genericContext = decl->getAsGenericContext();
12704+
if (!genericContext)
12705+
return SolutionKind::Error;
12706+
12707+
auto genericParams = genericContext->getGenericParams();
12708+
if (!genericParams || genericParams->size() == 0) {
12709+
// FIXME: Record an error here that we're ignoring the parameters.
12710+
return SolutionKind::Solved;
12711+
}
12712+
12713+
// Map the generic parameters we have over to their opened types.
12714+
SmallVector<Type, 2> openedGenericParams;
12715+
auto genericParamDepth = genericParams->getParams()[0]->getDepth();
12716+
for (const auto &openedType : getOpenedTypes(overloadLocator)) {
12717+
if (openedType.first->getDepth() == genericParamDepth) {
12718+
openedGenericParams.push_back(Type(openedType.second));
12719+
}
12720+
}
12721+
assert(openedGenericParams.size() == genericParams->size());
12722+
12723+
// Match the opened generic parameters to the specialized arguments.
12724+
auto specializedArgs = type2->castTo<PackType>()->getElementTypes();
12725+
PackMatcher matcher(openedGenericParams, specializedArgs, getASTContext());
12726+
if (matcher.match())
12727+
return SolutionKind::Error;
12728+
12729+
// Bind the opened generic parameters to the specialization arguments.
12730+
for (const auto &pair : matcher.pairs) {
12731+
addConstraint(
12732+
ConstraintKind::Bind, pair.lhs, pair.rhs,
12733+
getConstraintLocator(
12734+
locator, LocatorPathElt::GenericArgument(pair.idx)));
12735+
}
12736+
12737+
return SolutionKind::Solved;
12738+
}
12739+
1264712740
static llvm::PointerIntPair<Type, 3, unsigned>
1264812741
getBaseTypeForPointer(TypeBase *type) {
1264912742
unsigned unwrapCount = 0;
@@ -13993,6 +14086,10 @@ ConstraintSystem::addConstraintImpl(ConstraintKind kind, Type first,
1399314086
case ConstraintKind::ShapeOf:
1399414087
return simplifyShapeOfConstraint(first, second, subflags, locator);
1399514088

14089+
case ConstraintKind::ExplicitGenericArguments:
14090+
return simplifyExplicitGenericArgumentsConstraint(
14091+
first, second, subflags, locator);
14092+
1399614093
case ConstraintKind::ValueMember:
1399714094
case ConstraintKind::UnresolvedValueMember:
1399814095
case ConstraintKind::ValueWitness:
@@ -14583,6 +14680,11 @@ ConstraintSystem::simplifyConstraint(const Constraint &constraint) {
1458314680
return simplifyShapeOfConstraint(
1458414681
constraint.getFirstType(), constraint.getSecondType(), /*flags*/ None,
1458514682
constraint.getLocator());
14683+
14684+
case ConstraintKind::ExplicitGenericArguments:
14685+
return simplifyExplicitGenericArgumentsConstraint(
14686+
constraint.getFirstType(), constraint.getSecondType(),
14687+
/*flags*/ None, constraint.getLocator());
1458614688
}
1458714689

1458814690
llvm_unreachable("Unhandled ConstraintKind in switch.");

lib/Sema/Constraint.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ Constraint::Constraint(ConstraintKind Kind, Type First, Type Second,
8181
case ConstraintKind::BindTupleOfFunctionParams:
8282
case ConstraintKind::PackElementOf:
8383
case ConstraintKind::ShapeOf:
84+
case ConstraintKind::ExplicitGenericArguments:
8485
assert(!First.isNull());
8586
assert(!Second.isNull());
8687
break;
@@ -169,6 +170,7 @@ Constraint::Constraint(ConstraintKind Kind, Type First, Type Second, Type Third,
169170
case ConstraintKind::BindTupleOfFunctionParams:
170171
case ConstraintKind::PackElementOf:
171172
case ConstraintKind::ShapeOf:
173+
case ConstraintKind::ExplicitGenericArguments:
172174
llvm_unreachable("Wrong constructor");
173175

174176
case ConstraintKind::KeyPath:
@@ -316,6 +318,7 @@ Constraint *Constraint::clone(ConstraintSystem &cs) const {
316318
case ConstraintKind::BindTupleOfFunctionParams:
317319
case ConstraintKind::PackElementOf:
318320
case ConstraintKind::ShapeOf:
321+
case ConstraintKind::ExplicitGenericArguments:
319322
return create(cs, getKind(), getFirstType(), getSecondType(), getLocator());
320323

321324
case ConstraintKind::ApplicableFunction:
@@ -560,6 +563,10 @@ void Constraint::print(llvm::raw_ostream &Out, SourceManager *sm, unsigned inden
560563
Out << " shape of ";
561564
break;
562565

566+
case ConstraintKind::ExplicitGenericArguments:
567+
Out << " explicit generic argument binding ";
568+
break;
569+
563570
case ConstraintKind::Disjunction:
564571
llvm_unreachable("disjunction handled above");
565572
case ConstraintKind::Conjunction:
@@ -726,6 +733,7 @@ gatherReferencedTypeVars(Constraint *constraint,
726733
case ConstraintKind::BindTupleOfFunctionParams:
727734
case ConstraintKind::PackElementOf:
728735
case ConstraintKind::ShapeOf:
736+
case ConstraintKind::ExplicitGenericArguments:
729737
constraint->getFirstType()->getTypeVariables(typeVars);
730738
constraint->getSecondType()->getTypeVariables(typeVars);
731739
break;

lib/Sema/TypeCheckMacros.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "swift/AST/ASTContext.h"
2020
#include "swift/AST/CompilerPlugin.h"
2121
#include "swift/AST/Expr.h"
22+
#include "swift/AST/NameLookupRequests.h"
2223
#include "swift/AST/PrettyStackTrace.h"
2324
#include "swift/AST/SourceFile.h"
2425
#include "swift/AST/TypeCheckRequests.h"
@@ -100,7 +101,7 @@ getMacroSignatureContextBuffer(
100101

101102
/// Compute the macro signature for a macro given the source code for its
102103
/// generic signature and type signature.
103-
static Optional<std::pair<GenericSignature, Type>>
104+
static Optional<std::tuple<GenericParamList *, GenericSignature, Type>>
104105
getMacroSignature(
105106
ModuleDecl *mod, Identifier macroName,
106107
Optional<StringRef> genericSignature,
@@ -161,8 +162,9 @@ getMacroSignature(
161162
performImportResolution(*macroSourceFile);
162163

163164
auto typealias = cast<TypeAliasDecl>(decl);
164-
return std::make_pair(
165-
typealias->getGenericSignature(), typealias->getUnderlyingType());
165+
return std::make_tuple(
166+
typealias->getGenericParams(), typealias->getGenericSignature(),
167+
typealias->getUnderlyingType());
166168
}
167169

168170
/// Create a macro.
@@ -212,8 +214,9 @@ static MacroDecl *createMacro(
212214
opaqueHandle);
213215

214216
// FIXME: Make these lazily computed.
215-
macro->setGenericSignature(signature->first);
216-
macro->setInterfaceType(signature->second);
217+
GenericParamListRequest{macro}.cacheResult(std::get<0>(*signature));
218+
macro->setGenericSignature(std::get<1>(*signature));
219+
macro->setInterfaceType(std::get<2>(*signature));
217220

218221
return macro;
219222
}

test/Macros/macros.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ func test(a: Int, b: Int) {
1414
// CHECK: tuple_expr type='(() -> Bool, String)' location=Macro expansion of #stringify
1515

1616
let (b2, s3) = #stringify<Double>(1 + 2)
17-
// CHECK: macro_expansion_expr type='(Int, String)'{{.*}}name=stringify
17+
// CHECK: macro_expansion_expr type='(Double, String)'{{.*}}name=stringify
1818
// CHECK-NEXT: argument_list
19-
// CHECK: tuple_expr type='(Int, String)' location=Macro expansion of #stringify
19+
// CHECK: tuple_expr type='(Double, String)' location=Macro expansion of #stringify
2020
}

0 commit comments

Comments
 (0)