Skip to content

Commit ea7c30b

Browse files
committed
Simulate template argument pack expansion
1 parent 8e012c8 commit ea7c30b

File tree

11 files changed

+208
-5
lines changed

11 files changed

+208
-5
lines changed

src/parser/cxx/ast_rewriter.cc

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
// cxx
2424
#include <cxx/ast.h>
25+
#include <cxx/ast_cursor.h>
2526
#include <cxx/binder.h>
2627
#include <cxx/control.h>
2728
#include <cxx/decl.h>
@@ -31,6 +32,8 @@
3132
#include <cxx/type_checker.h>
3233
#include <cxx/types.h>
3334

35+
#include <format>
36+
3437
namespace cxx {
3538

3639
ASTRewriter::ASTRewriter(TypeChecker* typeChcker,
@@ -54,6 +57,29 @@ void ASTRewriter::setRestrictedToDeclarations(bool restrictedToDeclarations) {
5457
restrictedToDeclarations_ = restrictedToDeclarations;
5558
}
5659

60+
auto ASTRewriter::getParameterPack(ExpressionAST* ast) -> ParameterPackSymbol* {
61+
for (auto cursor = ASTCursor{ast, {}}; cursor; ++cursor) {
62+
const auto& current = *cursor;
63+
if (!std::holds_alternative<AST*>(current.node)) continue;
64+
65+
auto id = ast_cast<IdExpressionAST>(std::get<AST*>(current.node));
66+
if (!id) continue;
67+
68+
auto param = symbol_cast<NonTypeParameterSymbol>(id->symbol);
69+
if (!param) continue;
70+
71+
if (param->depth() != 0) continue;
72+
73+
auto arg = templateArguments_[param->index()];
74+
auto argSymbol = std::get<Symbol*>(arg);
75+
76+
auto parameterPack = symbol_cast<ParameterPackSymbol>(argSymbol);
77+
if (parameterPack) return parameterPack;
78+
}
79+
80+
return nullptr;
81+
}
82+
5783
struct ASTRewriter::UnitVisitor {
5884
ASTRewriter& rewrite;
5985
[[nodiscard]] auto translationUnit() const -> TranslationUnit* {
@@ -2355,16 +2381,27 @@ auto ASTRewriter::ExpressionVisitor::operator()(IdExpressionAST* ast)
23552381

23562382
copy->symbol = ast->symbol;
23572383

2358-
if (auto x = symbol_cast<NonTypeParameterSymbol>(copy->symbol);
2359-
x && x->depth() == 0 && x->index() < rewrite.templateArguments_.size()) {
2384+
if (auto param = symbol_cast<NonTypeParameterSymbol>(copy->symbol);
2385+
param && param->depth() == 0 &&
2386+
param->index() < rewrite.templateArguments_.size()) {
23602387
auto initializerPtr =
2361-
std::get_if<Symbol*>(&rewrite.templateArguments_[x->index()]);
2388+
std::get_if<Symbol*>(&rewrite.templateArguments_[param->index()]);
2389+
23622390
if (!initializerPtr) {
23632391
cxx_runtime_error("expected initializer for non-type template parameter");
23642392
}
23652393

23662394
copy->symbol = *initializerPtr;
23672395
copy->type = copy->symbol->type();
2396+
2397+
auto parameterPack = symbol_cast<ParameterPackSymbol>(copy->symbol);
2398+
if (parameterPack && rewrite.elementIndex_.has_value()) {
2399+
// ### TODO: check that idx is related to this non-type parameter.
2400+
auto idx = rewrite.elementIndex_.value();
2401+
auto element = parameterPack->elements()[idx];
2402+
copy->symbol = element;
2403+
copy->type = element->type();
2404+
}
23682405
}
23692406

23702407
copy->isTemplateIntroduced = ast->isTemplateIntroduced;
@@ -2473,6 +2510,33 @@ auto ASTRewriter::ExpressionVisitor::operator()(RightFoldExpressionAST* ast)
24732510

24742511
auto ASTRewriter::ExpressionVisitor::operator()(LeftFoldExpressionAST* ast)
24752512
-> ExpressionAST* {
2513+
if (auto parameterPack = rewrite.getParameterPack(ast->expression)) {
2514+
std::vector<ExpressionAST*> instantiations;
2515+
ExpressionAST* current = nullptr;
2516+
int n = 0;
2517+
for (auto element : parameterPack->elements()) {
2518+
std::optional<int> index{n};
2519+
std::swap(rewrite.elementIndex_, index);
2520+
auto expression = rewrite(ast->expression);
2521+
if (!current) {
2522+
current = expression;
2523+
} else {
2524+
auto binop = make_node<BinaryExpressionAST>(arena());
2525+
binop->valueCategory = current->valueCategory;
2526+
binop->type = current->type;
2527+
binop->leftExpression = current;
2528+
binop->op = ast->op;
2529+
binop->opLoc = ast->opLoc;
2530+
binop->rightExpression = expression;
2531+
current = binop;
2532+
}
2533+
std::swap(rewrite.elementIndex_, index);
2534+
++n;
2535+
}
2536+
2537+
return current;
2538+
}
2539+
24762540
auto copy = make_node<LeftFoldExpressionAST>(arena());
24772541

24782542
copy->valueCategory = ast->valueCategory;

src/parser/cxx/ast_rewriter.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,12 @@ class ASTRewriter {
151151
private:
152152
[[nodiscard]] auto rewriter() -> ASTRewriter* { return this; }
153153

154+
[[nodiscard]] auto getParameterPack(ExpressionAST* ast)
155+
-> ParameterPackSymbol*;
156+
154157
TypeChecker* typeChecker_ = nullptr;
155158
const std::vector<TemplateArgument>& templateArguments_;
159+
std::optional<int> elementIndex_;
156160
TranslationUnit* unit_ = nullptr;
157161
Binder binder_;
158162
bool restrictedToDeclarations_ = false;

src/parser/cxx/control.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ struct Control::Private {
143143
std::forward_list<VariableSymbol> variableSymbols;
144144
std::forward_list<FieldSymbol> fieldSymbols;
145145
std::forward_list<ParameterSymbol> parameterSymbols;
146+
std::forward_list<ParameterPackSymbol> parameterPackSymbols;
146147
std::forward_list<TypeParameterSymbol> typeParameterSymbols;
147148
std::forward_list<NonTypeParameterSymbol> nonTypeParameterSymbols;
148149
std::forward_list<TemplateTypeParameterSymbol> templateTypeParameterSymbols;
@@ -592,6 +593,13 @@ auto Control::newParameterSymbol(Scope* enclosingScope, SourceLocation loc)
592593
return symbol;
593594
}
594595

596+
auto Control::newParameterPackSymbol(Scope* enclosingScope, SourceLocation loc)
597+
-> ParameterPackSymbol* {
598+
auto symbol = &d->parameterPackSymbols.emplace_front(enclosingScope);
599+
symbol->setLocation(loc);
600+
return symbol;
601+
}
602+
595603
auto Control::newTypeParameterSymbol(Scope* enclosingScope, SourceLocation loc)
596604
-> TypeParameterSymbol* {
597605
auto symbol = &d->typeParameterSymbols.emplace_front(enclosingScope);

src/parser/cxx/control.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,9 @@ class Control {
202202
[[nodiscard]] auto newParameterSymbol(Scope* enclosingScope,
203203
SourceLocation sourceLocation)
204204
-> ParameterSymbol*;
205+
[[nodiscard]] auto newParameterPackSymbol(Scope* enclosingScope,
206+
SourceLocation sourceLocation)
207+
-> ParameterPackSymbol*;
205208
[[nodiscard]] auto newTypeParameterSymbol(Scope* enclosingScope,
206209
SourceLocation sourceLocation)
207210
-> TypeParameterSymbol*;

src/parser/cxx/external_name_encoder.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ struct ExternalNameEncoder::SymbolVisitor {
156156

157157
void operator()(ParameterSymbol* symbol) {}
158158

159+
void operator()(ParameterPackSymbol* symbol) {}
160+
159161
void operator()(EnumeratorSymbol* symbol) {}
160162

161163
void operator()(FunctionParametersSymbol* symbol) {}

src/parser/cxx/symbol_instantiation.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ struct SymbolInstantiation::VisitSymbol {
6262
[[nodiscard]] auto operator()(VariableSymbol* symbol) -> Symbol*;
6363
[[nodiscard]] auto operator()(FieldSymbol* symbol) -> Symbol*;
6464
[[nodiscard]] auto operator()(ParameterSymbol* symbol) -> Symbol*;
65+
[[nodiscard]] auto operator()(ParameterPackSymbol* symbol) -> Symbol*;
6566
[[nodiscard]] auto operator()(EnumeratorSymbol* symbol) -> Symbol*;
6667
[[nodiscard]] auto operator()(FunctionParametersSymbol* symbol) -> Symbol*;
6768
[[nodiscard]] auto operator()(TemplateParametersSymbol* symbol) -> Symbol*;
@@ -312,6 +313,12 @@ auto SymbolInstantiation::VisitSymbol::operator()(ParameterSymbol* symbol)
312313
return newSymbol;
313314
}
314315

316+
auto SymbolInstantiation::VisitSymbol::operator()(ParameterPackSymbol* symbol)
317+
-> Symbol* {
318+
auto newSymbol = self.replacement(symbol);
319+
return newSymbol;
320+
}
321+
315322
auto SymbolInstantiation::VisitSymbol::operator()(EnumeratorSymbol* symbol)
316323
-> Symbol* {
317324
auto newSymbol = self.replacement(symbol);

src/parser/cxx/symbol_printer.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,12 @@ struct DumpSymbols {
269269
to_string(symbol->type(), symbol->name()));
270270
}
271271

272+
void operator()(ParameterPackSymbol* symbol) {
273+
indent();
274+
out << std::format("parameter pack {}\n",
275+
to_string(symbol->type(), symbol->name()));
276+
}
277+
272278
void operator()(TypeParameterSymbol* symbol) {
273279
std::string_view pack = symbol->isParameterPack() ? "..." : "";
274280
indent();

src/parser/cxx/symbols.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,19 @@ ParameterSymbol::ParameterSymbol(Scope* enclosingScope)
615615

616616
ParameterSymbol::~ParameterSymbol() {}
617617

618+
ParameterPackSymbol::ParameterPackSymbol(Scope* enclosingScope)
619+
: Symbol(Kind, enclosingScope) {}
620+
621+
ParameterPackSymbol::~ParameterPackSymbol() {}
622+
623+
auto ParameterPackSymbol::elements() const -> const std::vector<Symbol*>& {
624+
return elements_;
625+
}
626+
627+
void ParameterPackSymbol::addElement(Symbol* element) {
628+
elements_.push_back(element);
629+
}
630+
618631
TypeParameterSymbol::TypeParameterSymbol(Scope* enclosingScope)
619632
: Symbol(Kind, enclosingScope) {}
620633

src/parser/cxx/symbols.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -624,6 +624,20 @@ class ParameterSymbol final : public Symbol {
624624
~ParameterSymbol() override;
625625
};
626626

627+
class ParameterPackSymbol final : public Symbol {
628+
public:
629+
constexpr static auto Kind = SymbolKind::kParameterPack;
630+
631+
explicit ParameterPackSymbol(Scope* enclosingScope);
632+
~ParameterPackSymbol() override;
633+
634+
[[nodiscard]] auto elements() const -> const std::vector<Symbol*>&;
635+
void addElement(Symbol* element);
636+
637+
private:
638+
std::vector<Symbol*> elements_;
639+
};
640+
627641
class TypeParameterSymbol final : public Symbol {
628642
public:
629643
constexpr static auto Kind = SymbolKind::kTypeParameter;

src/parser/cxx/symbols_fwd.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ namespace cxx {
3737
V(Variable) \
3838
V(Field) \
3939
V(Parameter) \
40+
V(ParameterPack) \
4041
V(Enumerator) \
4142
V(FunctionParameters) \
4243
V(TemplateParameters) \

0 commit comments

Comments
 (0)