Skip to content

Commit 51b2481

Browse files
xedinhamishknight
authored andcommitted
[AST] Expand TypeJoin expression to support joining over a type
(cherry picked from commit 4efc35a)
1 parent 05a8aa4 commit 51b2481

File tree

6 files changed

+61
-15
lines changed

6 files changed

+61
-15
lines changed

include/swift/AST/Expr.h

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6016,11 +6016,24 @@ class TypeJoinExpr final : public Expr,
60166016
return { getTrailingObjects<Expr *>(), getNumElements() };
60176017
}
60186018

6019-
TypeJoinExpr(DeclRefExpr *var, ArrayRef<Expr *> elements);
6019+
TypeJoinExpr(llvm::PointerUnion<DeclRefExpr *, TypeBase *> result,
6020+
ArrayRef<Expr *> elements);
6021+
6022+
static TypeJoinExpr *
6023+
createImpl(ASTContext &ctx,
6024+
llvm::PointerUnion<DeclRefExpr *, TypeBase *> varOrType,
6025+
ArrayRef<Expr *> elements);
60206026

60216027
public:
60226028
static TypeJoinExpr *create(ASTContext &ctx, DeclRefExpr *var,
6023-
ArrayRef<Expr *> exprs);
6029+
ArrayRef<Expr *> exprs) {
6030+
return createImpl(ctx, var, exprs);
6031+
}
6032+
6033+
static TypeJoinExpr *create(ASTContext &ctx, Type joinType,
6034+
ArrayRef<Expr *> exprs) {
6035+
return createImpl(ctx, joinType.getPointer(), exprs);
6036+
}
60246037

60256038
SourceLoc getLoc() const { return SourceLoc(); }
60266039
SourceRange getSourceRange() const { return SourceRange(); }

lib/AST/ASTDumper.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3006,8 +3006,12 @@ class PrintExpr : public ExprVisitor<PrintExpr> {
30063006
void visitTypeJoinExpr(TypeJoinExpr *E) {
30073007
printCommon(E, "type_join_expr");
30083008

3009-
PrintWithColorRAII(OS, DeclColor) << " var=";
3010-
printRec(E->getVar());
3009+
if (auto *var = E->getVar()) {
3010+
PrintWithColorRAII(OS, DeclColor) << " var=";
3011+
printRec(var);
3012+
OS << '\n';
3013+
}
3014+
30113015
OS << '\n';
30123016

30133017
for (auto *member : E->getElements()) {

lib/AST/ASTWalker.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1265,10 +1265,12 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
12651265
}
12661266

12671267
Expr *visitTypeJoinExpr(TypeJoinExpr *E) {
1268-
if (auto *newVar = dyn_cast<DeclRefExpr>(doIt(E->getVar()))) {
1269-
E->setVar(newVar);
1270-
} else {
1271-
return nullptr;
1268+
if (auto *var = E->getVar()) {
1269+
if (auto *newVar = dyn_cast<DeclRefExpr>(doIt(var))) {
1270+
E->setVar(newVar);
1271+
} else {
1272+
return nullptr;
1273+
}
12721274
}
12731275

12741276
for (unsigned i = 0, e = E->getNumElements(); i != e; ++i) {

lib/AST/Expr.cpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2519,21 +2519,31 @@ RegexLiteralExpr::createParsed(ASTContext &ctx, SourceLoc loc,
25192519
/*implicit*/ false);
25202520
}
25212521

2522-
TypeJoinExpr::TypeJoinExpr(DeclRefExpr *varRef, ArrayRef<Expr *> elements)
2523-
: Expr(ExprKind::TypeJoin, /*implicit=*/true, Type()), Var(varRef) {
2524-
assert(Var);
2522+
TypeJoinExpr::TypeJoinExpr(llvm::PointerUnion<DeclRefExpr *, TypeBase *> result,
2523+
ArrayRef<Expr *> elements)
2524+
: Expr(ExprKind::TypeJoin, /*implicit=*/true, Type()), Var(nullptr) {
2525+
2526+
if (auto *varRef = result.dyn_cast<DeclRefExpr *>()) {
2527+
assert(varRef);
2528+
Var = varRef;
2529+
} else {
2530+
auto joinType = Type(result.get<TypeBase *>());
2531+
assert(joinType && "expected non-null type");
2532+
setType(joinType);
2533+
}
25252534

25262535
Bits.TypeJoinExpr.NumElements = elements.size();
25272536
// Copy elements.
25282537
std::uninitialized_copy(elements.begin(), elements.end(),
25292538
getTrailingObjects<Expr *>());
25302539
}
25312540

2532-
TypeJoinExpr *TypeJoinExpr::create(ASTContext &ctx, DeclRefExpr *var,
2533-
ArrayRef<Expr *> elements) {
2541+
TypeJoinExpr *TypeJoinExpr::createImpl(
2542+
ASTContext &ctx, llvm::PointerUnion<DeclRefExpr *, TypeBase *> varOrType,
2543+
ArrayRef<Expr *> elements) {
25342544
size_t size = totalSizeToAlloc<Expr *>(elements.size());
25352545
void *mem = ctx.Allocate(size, alignof(TypeJoinExpr));
2536-
return new (mem) TypeJoinExpr(var, elements);
2546+
return new (mem) TypeJoinExpr(varOrType, elements);
25372547
}
25382548

25392549
SourceRange MacroExpansionExpr::getSourceRange() const {

lib/Sema/CSGen.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3747,7 +3747,16 @@ namespace {
37473747
CS.getConstraintLocator(element));
37483748
}
37493749

3750-
auto resultTy = CS.getType(expr->getVar());
3750+
Type resultTy;
3751+
3752+
if (auto *var = expr->getVar()) {
3753+
resultTy = CS.getType(var);
3754+
} else {
3755+
resultTy = expr->getType();
3756+
}
3757+
3758+
assert(resultTy);
3759+
37513760
// The type of a join expression is obtained by performing
37523761
// a "join-meet" operation on deduced types of its elements
37533762
// and the underlying variable.

lib/Sema/CSSyntacticElement.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,14 @@ class TypeVariableRefFinder : public ASTWalker {
6161
ClosureDCs.push_back(closure);
6262
}
6363

64+
if (auto *joinExpr = dyn_cast<TypeJoinExpr>(expr)) {
65+
// If this join is over a known type, let's
66+
// analyze it too because it can contain type
67+
// variables.
68+
if (!joinExpr->getVar())
69+
inferVariables(joinExpr->getType());
70+
}
71+
6472
if (auto *DRE = dyn_cast<DeclRefExpr>(expr)) {
6573
auto *decl = DRE->getDecl();
6674

0 commit comments

Comments
 (0)